Apparently, non-blocking connect doesn't work on windows if you use 0
timeout in the select call...
This commit is contained in:
parent
63e3d7c5b0
commit
e57f9e9964
7 changed files with 106 additions and 100 deletions
|
@ -35,6 +35,13 @@ local sending = newset()
|
|||
-- context for connections and servers
|
||||
local context = {}
|
||||
|
||||
function wait(who, what)
|
||||
if what == "input" then receiving:insert(who)
|
||||
else sending:insert(who) end
|
||||
context[who].last = socket.gettime()
|
||||
coroutine.yield()
|
||||
end
|
||||
|
||||
-- initializes the forward server
|
||||
function init()
|
||||
if table.getn(arg) < 1 then
|
||||
|
@ -63,145 +70,142 @@ function init()
|
|||
end
|
||||
|
||||
-- starts a connection in a non-blocking way
|
||||
function nbkcon(host, port)
|
||||
local peer, err = socket.tcp()
|
||||
if not peer then return nil, err end
|
||||
peer:settimeout(0)
|
||||
local ret, err = peer:connect(host, port)
|
||||
if ret then return peer end
|
||||
if err ~= "timeout" then
|
||||
peer:close()
|
||||
return nil, err
|
||||
function connect(who, host, port)
|
||||
who:settimeout(0.1)
|
||||
print("trying to connect peer", who, host, port)
|
||||
local ret, err = who:connect(host, port)
|
||||
if not ret and err == "timeout" then
|
||||
print("got timeout, will wait", who)
|
||||
wait(who, "output")
|
||||
ret, err = who:connected()
|
||||
print("connection results arrived", who, ret, err)
|
||||
end
|
||||
if not ret then
|
||||
print("connection failed", who)
|
||||
kick(who)
|
||||
kick(context[who].peer)
|
||||
else
|
||||
return forward(who)
|
||||
end
|
||||
return peer
|
||||
end
|
||||
|
||||
-- gets rid of a client
|
||||
-- gets rid of a client and its peer
|
||||
function kick(who)
|
||||
if context[who] then
|
||||
if who and context[who] then
|
||||
sending:remove(who)
|
||||
receiving:remove(who)
|
||||
local peer = context[who].peer
|
||||
context[who] = nil
|
||||
who:close()
|
||||
end
|
||||
end
|
||||
|
||||
-- decides what to do with a thread based on coroutine return
|
||||
function route(who, status, what)
|
||||
if status and what then
|
||||
if what == "receiving" then receiving:insert(who) end
|
||||
if what == "sending" then sending:insert(who) end
|
||||
else kick(who) end
|
||||
end
|
||||
|
||||
-- loops accepting connections and creating new threads to deal with them
|
||||
function accept(server)
|
||||
while true do
|
||||
-- accept a new connection and start a new coroutine to deal with it
|
||||
local client = server:accept()
|
||||
print("accepted ", client)
|
||||
if client then
|
||||
-- start a new connection, non-blockingly, to the forwarding address
|
||||
local ohost = context[server].ohost
|
||||
local oport = context[server].oport
|
||||
local peer = nbkcon(ohost, oport)
|
||||
-- create contexts for client and peer.
|
||||
local peer, err = socket.tcp()
|
||||
if peer then
|
||||
context[client] = {
|
||||
last = socket.gettime(),
|
||||
-- client goes straight to forwarding loop
|
||||
thread = coroutine.create(forward),
|
||||
peer = peer,
|
||||
}
|
||||
-- make sure peer will be tested for writing in the next select
|
||||
-- round, which means the connection attempt has finished
|
||||
sending:insert(peer)
|
||||
context[peer] = {
|
||||
last = socket.gettime(),
|
||||
peer = client,
|
||||
thread = coroutine.create(chkcon),
|
||||
-- peer first tries to connect to forwarding address
|
||||
thread = coroutine.create(connect),
|
||||
last = socket.gettime()
|
||||
}
|
||||
-- put both in non-blocking mode
|
||||
client:settimeout(0)
|
||||
peer:settimeout(0)
|
||||
-- resume peer and client so they can do their thing
|
||||
local ohost = context[server].ohost
|
||||
local oport = context[server].oport
|
||||
coroutine.resume(context[peer].thread, peer, ohost, oport)
|
||||
coroutine.resume(context[client].thread, client)
|
||||
else
|
||||
-- otherwise just dump the client
|
||||
client:close()
|
||||
print(err)
|
||||
client:close()
|
||||
end
|
||||
end
|
||||
-- tell scheduler we are done for now
|
||||
coroutine.yield("receiving")
|
||||
wait(server, "input")
|
||||
end
|
||||
end
|
||||
|
||||
-- forwards all data arriving to the appropriate peer
|
||||
function forward(who)
|
||||
print("starting to foward", who)
|
||||
who:settimeout(0)
|
||||
while true do
|
||||
-- wait until we have something to read
|
||||
wait(who, "input")
|
||||
-- try to read as much as possible
|
||||
local data, rec_err, partial = who:receive("*a")
|
||||
-- if we had an error other than timeout, abort
|
||||
if rec_err and rec_err ~= "timeout" then return error(rec_err) end
|
||||
if rec_err and rec_err ~= "timeout" then return kick(who) end
|
||||
-- if we got a timeout, we probably have partial results to send
|
||||
data = data or partial
|
||||
-- renew our timestamp so scheduler sees we are active
|
||||
context[who].last = socket.gettime()
|
||||
-- forward what we got right away
|
||||
local peer = context[who].peer
|
||||
while true do
|
||||
-- tell scheduler we need to wait until we can send something
|
||||
coroutine.yield("sending")
|
||||
wait(who, "output")
|
||||
local ret, snd_err
|
||||
local start = 0
|
||||
ret, snd_err, start = peer:send(data, start+1)
|
||||
if ret then break
|
||||
elseif snd_err ~= "timeout" then return error(snd_err) end
|
||||
-- renew our timestamp so scheduler sees we are active
|
||||
context[who].last = socket.gettime()
|
||||
elseif snd_err ~= "timeout" then return kick(who) end
|
||||
end
|
||||
-- if we are done receiving, we are done with this side of the
|
||||
-- connection
|
||||
if not rec_err then return nil end
|
||||
-- otherwise tell schedule we have to wait for more data to arrive
|
||||
coroutine.yield("receiving")
|
||||
-- if we are done receiving, we are done
|
||||
if not rec_err then return kick(who) end
|
||||
end
|
||||
end
|
||||
|
||||
-- checks if a connection completed successfully and if it did, starts
|
||||
-- forwarding all data
|
||||
function chkcon(who)
|
||||
local ret, err = who:connected()
|
||||
if ret then
|
||||
receiving:insert(context[who].peer)
|
||||
context[who].last = socket.gettime()
|
||||
coroutine.yield("receiving")
|
||||
return forward(who)
|
||||
else return error(err) end
|
||||
end
|
||||
|
||||
-- loop waiting until something happens, restarting the thread to deal with
|
||||
-- what happened, and routing it to wait until something else happens
|
||||
function go()
|
||||
while true do
|
||||
print("will select for reading")
|
||||
for i,v in ipairs(receiving) do
|
||||
print(i, v)
|
||||
end
|
||||
print("will select for sending")
|
||||
for i,v in ipairs(sending) do
|
||||
print(i, v)
|
||||
end
|
||||
-- check which sockets are interesting and act on them
|
||||
readable, writable = socket.select(receiving, sending, 3)
|
||||
-- for all readable connections, resume its thread and route it
|
||||
print("was readable")
|
||||
for i,v in ipairs(readable) do
|
||||
print(i, v)
|
||||
end
|
||||
print("was writable")
|
||||
for i,v in ipairs(writable) do
|
||||
print(i, v)
|
||||
end
|
||||
-- for all readable connections, resume its thread
|
||||
for _, who in ipairs(readable) do
|
||||
receiving:remove(who)
|
||||
if context[who] then
|
||||
route(who, coroutine.resume(context[who].thread, who))
|
||||
end
|
||||
coroutine.resume(context[who].thread, who)
|
||||
end
|
||||
-- for all writable connections, do the same
|
||||
for _, who in ipairs(writable) do
|
||||
sending:remove(who)
|
||||
if context[who] then
|
||||
route(who, coroutine.resume(context[who].thread, who))
|
||||
end
|
||||
coroutine.resume(context[who].thread, who)
|
||||
end
|
||||
-- put all inactive threads in death row
|
||||
local now = socket.gettime()
|
||||
local deathrow
|
||||
for who, data in pairs(context) do
|
||||
if data.last then
|
||||
if data.peer then
|
||||
if now - data.last > TIMEOUT then
|
||||
-- only create table if someone is doomed
|
||||
-- only create table if at least one is doomed
|
||||
deathrow = deathrow or {}
|
||||
deathrow[who] = true
|
||||
end
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue