SocketHandler.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. # Written by Bram Cohen
  2. # see LICENSE.txt for license information
  3. import socket
  4. from errno import EWOULDBLOCK, ECONNREFUSED, EHOSTUNREACH
  5. try:
  6. from select import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
  7. timemult = 1000
  8. except ImportError:
  9. from selectpoll import poll, error, POLLIN, POLLOUT, POLLERR, POLLHUP
  10. timemult = 1
  11. from time import sleep
  12. from clock import clock
  13. import sys
  14. from random import shuffle, randrange
  15. from natpunch import UPnP_open_port, UPnP_close_port
  16. # from BT1.StreamCheck import StreamCheck
  17. # import inspect
  18. try:
  19. True
  20. except:
  21. True = 1
  22. False = 0
  23. all = POLLIN | POLLOUT
  24. UPnP_ERROR = "unable to forward port via UPnP"
  25. class SingleSocket:
  26. def __init__(self, socket_handler, sock, handler, ip = None):
  27. self.socket_handler = socket_handler
  28. self.socket = sock
  29. self.handler = handler
  30. self.buffer = []
  31. self.last_hit = clock()
  32. self.fileno = sock.fileno()
  33. self.connected = False
  34. self.skipped = 0
  35. # self.check = StreamCheck()
  36. try:
  37. self.ip = self.socket.getpeername()[0]
  38. except:
  39. if ip is None:
  40. self.ip = 'unknown'
  41. else:
  42. self.ip = ip
  43. def get_ip(self, real=False):
  44. if real:
  45. try:
  46. self.ip = self.socket.getpeername()[0]
  47. except:
  48. pass
  49. return self.ip
  50. def close(self):
  51. '''
  52. for x in xrange(5,0,-1):
  53. try:
  54. f = inspect.currentframe(x).f_code
  55. print (f.co_filename,f.co_firstlineno,f.co_name)
  56. del f
  57. except:
  58. pass
  59. print ''
  60. '''
  61. assert self.socket
  62. self.connected = False
  63. sock = self.socket
  64. self.socket = None
  65. self.buffer = []
  66. del self.socket_handler.single_sockets[self.fileno]
  67. self.socket_handler.poll.unregister(sock)
  68. sock.close()
  69. def shutdown(self, val):
  70. self.socket.shutdown(val)
  71. def is_flushed(self):
  72. return not self.buffer
  73. def write(self, s):
  74. # self.check.write(s)
  75. assert self.socket is not None
  76. self.buffer.append(s)
  77. if len(self.buffer) == 1:
  78. self.try_write()
  79. def try_write(self):
  80. if self.connected:
  81. dead = False
  82. try:
  83. while self.buffer:
  84. buf = self.buffer[0]
  85. amount = self.socket.send(buf)
  86. if amount == 0:
  87. self.skipped += 1
  88. break
  89. self.skipped = 0
  90. if amount != len(buf):
  91. self.buffer[0] = buf[amount:]
  92. break
  93. del self.buffer[0]
  94. except socket.error, e:
  95. try:
  96. dead = e[0] != EWOULDBLOCK
  97. except:
  98. dead = True
  99. self.skipped += 1
  100. if self.skipped >= 3:
  101. dead = True
  102. if dead:
  103. self.socket_handler.dead_from_write.append(self)
  104. return
  105. if self.buffer:
  106. self.socket_handler.poll.register(self.socket, all)
  107. else:
  108. self.socket_handler.poll.register(self.socket, POLLIN)
  109. def set_handler(self, handler):
  110. self.handler = handler
  111. class SocketHandler:
  112. def __init__(self, timeout, ipv6_enable, readsize = 100000):
  113. self.timeout = timeout
  114. self.ipv6_enable = ipv6_enable
  115. self.readsize = readsize
  116. self.poll = poll()
  117. # {socket: SingleSocket}
  118. self.single_sockets = {}
  119. self.dead_from_write = []
  120. self.max_connects = 1000
  121. self.port_forwarded = None
  122. self.servers = {}
  123. def scan_for_timeouts(self):
  124. t = clock() - self.timeout
  125. tokill = []
  126. for s in self.single_sockets.values():
  127. if s.last_hit < t:
  128. tokill.append(s)
  129. for k in tokill:
  130. if k.socket is not None:
  131. self._close_socket(k)
  132. def bind(self, port, bind = '', reuse = False, ipv6_socket_style = 1, upnp = 0):
  133. port = int(port)
  134. addrinfos = []
  135. self.servers = {}
  136. self.interfaces = []
  137. # if bind != "" thread it as a comma seperated list and bind to all
  138. # addresses (can be ips or hostnames) else bind to default ipv6 and
  139. # ipv4 address
  140. if bind:
  141. if self.ipv6_enable:
  142. socktype = socket.AF_UNSPEC
  143. else:
  144. socktype = socket.AF_INET
  145. bind = bind.split(',')
  146. for addr in bind:
  147. if sys.version_info < (2,2):
  148. addrinfos.append((socket.AF_INET, None, None, None, (addr, port)))
  149. else:
  150. addrinfos.extend(socket.getaddrinfo(addr, port,
  151. socktype, socket.SOCK_STREAM))
  152. else:
  153. if self.ipv6_enable:
  154. addrinfos.append([socket.AF_INET6, None, None, None, ('', port)])
  155. if not addrinfos or ipv6_socket_style != 0:
  156. addrinfos.append([socket.AF_INET, None, None, None, ('', port)])
  157. for addrinfo in addrinfos:
  158. try:
  159. server = socket.socket(addrinfo[0], socket.SOCK_STREAM)
  160. if reuse:
  161. server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  162. server.setblocking(0)
  163. server.bind(addrinfo[4])
  164. self.servers[server.fileno()] = server
  165. if bind:
  166. self.interfaces.append(server.getsockname()[0])
  167. server.listen(64)
  168. self.poll.register(server, POLLIN)
  169. except socket.error, e:
  170. for server in self.servers.values():
  171. try:
  172. server.close()
  173. except:
  174. pass
  175. if self.ipv6_enable and ipv6_socket_style == 0 and self.servers:
  176. raise socket.error('blocked port (may require ipv6_binds_v4 to be set)')
  177. raise socket.error(str(e))
  178. if not self.servers:
  179. raise socket.error('unable to open server port')
  180. if upnp:
  181. if not UPnP_open_port(port):
  182. for server in self.servers.values():
  183. try:
  184. server.close()
  185. except:
  186. pass
  187. self.servers = None
  188. self.interfaces = None
  189. raise socket.error(UPnP_ERROR)
  190. self.port_forwarded = port
  191. self.port = port
  192. def find_and_bind(self, minport, maxport, bind = '', reuse = False,
  193. ipv6_socket_style = 1, upnp = 0, randomizer = False):
  194. e = 'maxport less than minport - no ports to check'
  195. if maxport-minport < 50 or not randomizer:
  196. portrange = range(minport, maxport+1)
  197. if randomizer:
  198. shuffle(portrange)
  199. portrange = portrange[:20] # check a maximum of 20 ports
  200. else:
  201. portrange = []
  202. while len(portrange) < 20:
  203. listen_port = randrange(minport, maxport+1)
  204. if not listen_port in portrange:
  205. portrange.append(listen_port)
  206. for listen_port in portrange:
  207. try:
  208. self.bind(listen_port, bind,
  209. ipv6_socket_style = ipv6_socket_style, upnp = upnp)
  210. return listen_port
  211. except socket.error, e:
  212. pass
  213. raise socket.error(str(e))
  214. def set_handler(self, handler):
  215. self.handler = handler
  216. def start_connection_raw(self, dns, socktype = socket.AF_INET, handler = None):
  217. if handler is None:
  218. handler = self.handler
  219. sock = socket.socket(socktype, socket.SOCK_STREAM)
  220. sock.setblocking(0)
  221. try:
  222. sock.connect_ex(dns)
  223. except socket.error:
  224. raise
  225. except Exception, e:
  226. raise socket.error(str(e))
  227. self.poll.register(sock, POLLIN)
  228. s = SingleSocket(self, sock, handler, dns[0])
  229. self.single_sockets[sock.fileno()] = s
  230. return s
  231. def start_connection(self, dns, handler = None, randomize = False):
  232. if handler is None:
  233. handler = self.handler
  234. if sys.version_info < (2,2):
  235. s = self.start_connection_raw(dns,socket.AF_INET,handler)
  236. else:
  237. if self.ipv6_enable:
  238. socktype = socket.AF_UNSPEC
  239. else:
  240. socktype = socket.AF_INET
  241. try:
  242. addrinfos = socket.getaddrinfo(dns[0], int(dns[1]),
  243. socktype, socket.SOCK_STREAM)
  244. except socket.error, e:
  245. raise
  246. except Exception, e:
  247. raise socket.error(str(e))
  248. if randomize:
  249. shuffle(addrinfos)
  250. for addrinfo in addrinfos:
  251. try:
  252. s = self.start_connection_raw(addrinfo[4],addrinfo[0],handler)
  253. break
  254. except:
  255. pass
  256. else:
  257. raise socket.error('unable to connect')
  258. return s
  259. def _sleep(self):
  260. sleep(1)
  261. def handle_events(self, events):
  262. for sock, event in events:
  263. s = self.servers.get(sock)
  264. if s:
  265. if event & (POLLHUP | POLLERR) != 0:
  266. self.poll.unregister(s)
  267. s.close()
  268. del self.servers[sock]
  269. print "lost server socket"
  270. elif len(self.single_sockets) < self.max_connects:
  271. try:
  272. newsock, addr = s.accept()
  273. newsock.setblocking(0)
  274. nss = SingleSocket(self, newsock, self.handler)
  275. self.single_sockets[newsock.fileno()] = nss
  276. self.poll.register(newsock, POLLIN)
  277. self.handler.external_connection_made(nss)
  278. except socket.error:
  279. self._sleep()
  280. else:
  281. s = self.single_sockets.get(sock)
  282. if not s:
  283. continue
  284. s.connected = True
  285. if (event & (POLLHUP | POLLERR)):
  286. self._close_socket(s)
  287. continue
  288. if (event & POLLIN):
  289. try:
  290. s.last_hit = clock()
  291. data = s.socket.recv(self.readsize)
  292. if not data:
  293. self._close_socket(s)
  294. else:
  295. s.handler.data_came_in(s, data)
  296. except socket.error, e:
  297. code, msg = e
  298. if code != EWOULDBLOCK:
  299. self._close_socket(s)
  300. continue
  301. if (event & POLLOUT) and s.socket and not s.is_flushed():
  302. s.try_write()
  303. if s.is_flushed():
  304. s.handler.connection_flushed(s)
  305. def close_dead(self):
  306. while self.dead_from_write:
  307. old = self.dead_from_write
  308. self.dead_from_write = []
  309. for s in old:
  310. if s.socket:
  311. self._close_socket(s)
  312. def _close_socket(self, s):
  313. s.close()
  314. s.handler.connection_lost(s)
  315. def do_poll(self, t):
  316. r = self.poll.poll(t*timemult)
  317. if r is None:
  318. connects = len(self.single_sockets)
  319. to_close = int(connects*0.05)+1 # close 5% of sockets
  320. self.max_connects = connects-to_close
  321. closelist = self.single_sockets.values()
  322. shuffle(closelist)
  323. closelist = closelist[:to_close]
  324. for sock in closelist:
  325. self._close_socket(sock)
  326. return []
  327. return r
  328. def get_stats(self):
  329. return { 'interfaces': self.interfaces,
  330. 'port': self.port,
  331. 'upnp': self.port_forwarded is not None }
  332. def shutdown(self):
  333. for ss in self.single_sockets.values():
  334. try:
  335. ss.close()
  336. except:
  337. pass
  338. for server in self.servers.values():
  339. try:
  340. server.close()
  341. except:
  342. pass
  343. if self.port_forwarded is not None:
  344. UPnP_close_port(self.port_forwarded)