1
0

ConnectionRateLimitReactor.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. # usage:
  2. #
  3. # from twisted.internet import reactor
  4. # from ConnectionRateLimitReactor import connectionRateLimitReactor
  5. # connectionRateLimitReactor(reactor, max_incomplete=10)
  6. #
  7. # The contents of this file are subject to the Python Software Foundation
  8. # License Version 2.3 (the License). You may not copy or use this file, in
  9. # either source code or executable form, except in compliance with the License.
  10. # You may obtain a copy of the License at http://www.python.org/license.
  11. #
  12. # Software distributed under the License is distributed on an AS IS basis,
  13. # WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
  14. # for the specific language governing rights and limitations under the
  15. # License.
  16. #
  17. # by Greg Hazel
  18. import random
  19. import threading
  20. from twisted.python import failure
  21. from twisted.python import threadable
  22. from twisted.internet import error, address, abstract
  23. from BTL.circular_list import CircularList
  24. from BTL.Lists import QList
  25. from BTL.decorate import decorate_func
  26. debug = False
  27. class HookedFactory(object):
  28. def __init__(self, connector, factory):
  29. self.connector = connector
  30. self.factory = factory
  31. def clientConnectionFailed(self, connector, reason):
  32. if self.connector._started:
  33. self.connector.complete()
  34. return self.factory.clientConnectionFailed(connector, reason)
  35. def buildProtocol(self, addr):
  36. p = self.factory.buildProtocol(addr)
  37. p.connectionMade = decorate_func(self.connector.complete,
  38. p.connectionMade)
  39. return p
  40. def __getattr__(self, attr):
  41. return getattr(self.factory, attr)
  42. class IRobotConnector(object):
  43. # I did this to be nice, but zope sucks.
  44. ##implements(interfaces.IConnector)
  45. def __init__(self, reactor, protocol, host, port, factory, owner, urgent,
  46. *a, **kw):
  47. self.reactor = reactor
  48. self.protocol = protocol
  49. assert self.protocol in ('INET', 'SSL')
  50. self.host = host
  51. self.port = port
  52. self.owner = owner
  53. self.urgent = urgent
  54. self.a = a
  55. self.kw = kw
  56. self.connector = None
  57. self._started = False
  58. self.preempted = False
  59. self.factory = HookedFactory(self, factory)
  60. def started(self):
  61. if self._started:
  62. raise ValueError("Connector is already started!")
  63. self._started = True
  64. self.reactor.add_pending_connection(self.host, self)
  65. def disconnect(self):
  66. if self._started:
  67. return self.connector.disconnect()
  68. return self.stopConnecting()
  69. def _cleanup(self):
  70. if hasattr(self, 'a'):
  71. del self.a
  72. if hasattr(self, 'kw'):
  73. del self.kw
  74. if hasattr(self, 'factory'):
  75. del self.factory
  76. if hasattr(self, 'connector'):
  77. del self.connector
  78. def stopConnecting(self):
  79. if self._started:
  80. self.connector.stopConnecting()
  81. self._cleanup()
  82. return
  83. self.reactor.drop_postponed(self)
  84. # for accuracy
  85. self.factory.startedConnecting(self)
  86. abort = failure.Failure(error.UserError(string="Connection preempted"))
  87. self.factory.clientConnectionFailed(self, abort)
  88. self._cleanup()
  89. def connect(self):
  90. if debug: print 'connecting', self.host, self.port
  91. self.started()
  92. try:
  93. if self.protocol == 'SSL':
  94. self.connector = self.reactor.old_connectSSL(self.host,
  95. self.port,
  96. self.factory,
  97. *self.a, **self.kw)
  98. else:
  99. self.connector = self.reactor.old_connectTCP(self.host,
  100. self.port,
  101. self.factory,
  102. *self.a, **self.kw)
  103. # because other callbacks use this one
  104. self.connector.wasPreempted = self.wasPreempted
  105. except:
  106. # make sure failures get removed before we raise
  107. self.complete()
  108. raise
  109. # if connect is re-called on the connector, we want to restart
  110. self.connector.connect = decorate_func(self.started,
  111. self.connector.connect)
  112. return self
  113. def wasPreempted(self):
  114. return self.preempted
  115. def complete(self):
  116. if not self._started:
  117. return
  118. self._started = False
  119. self.reactor._remove_pending_connection(self.host, self)
  120. self._cleanup()
  121. def getDestination(self):
  122. return address.IPv4Address('TCP', self.host, self.port, self.protocol)
  123. class Postponed(CircularList):
  124. def __init__(self):
  125. CircularList.__init__(self)
  126. self.it = iter(self)
  127. self.preempt = QList()
  128. self.cm_to_list = {}
  129. def __len__(self):
  130. l = 0
  131. for k, v in self.cm_to_list.iteritems():
  132. l += len(v)
  133. l += len(self.preempt)
  134. return l
  135. def append_preempt(self, c):
  136. return self.preempt.append(c)
  137. def add_connection(self, keyable, c):
  138. if keyable not in self.cm_to_list:
  139. self.cm_to_list[keyable] = QList()
  140. self.prepend(keyable)
  141. self.cm_to_list[keyable].append(c)
  142. def pop_connection(self):
  143. if self.preempt:
  144. return self.preempt.popleft()
  145. keyable = self.it.next()
  146. l = self.cm_to_list[keyable]
  147. c = l.popleft()
  148. if len(l) == 0:
  149. self.remove(keyable)
  150. del self.cm_to_list[keyable]
  151. return c
  152. def remove_connection(self, keyable, c):
  153. # hmmm
  154. if c.urgent:
  155. self.preempt.remove(c)
  156. return
  157. l = self.cm_to_list[keyable]
  158. l.remove(c)
  159. if len(l) == 0:
  160. self.remove(keyable)
  161. del self.cm_to_list[keyable]
  162. class ConnectionRateLimiter(object):
  163. def __init__(self, reactor, max_incomplete):
  164. self.reactor = reactor
  165. self.postponed = Postponed()
  166. self.max_incomplete = max_incomplete
  167. # this can go away when urllib does
  168. self.halfopen_hosts_lock = threading.RLock()
  169. self.halfopen_hosts = {}
  170. self.old_connectTCP = self.reactor.connectTCP
  171. self.old_connectSSL = self.reactor.connectSSL
  172. if debug:
  173. from twisted.internet import task
  174. def p():
  175. print len(self.postponed), [ (k, len(v)) for k, v in self.halfopen_hosts.iteritems() ]
  176. assert len(self.halfopen_hosts) <= self.max_incomplete
  177. task.LoopingCall(p).start(1)
  178. # safe from any thread
  179. def add_pending_connection(self, host, connector=None):
  180. if debug: print 'adding', host, 'IOthread', threadable.isInIOThread()
  181. self.halfopen_hosts_lock.acquire()
  182. self.halfopen_hosts.setdefault(host, []).append(connector)
  183. self.halfopen_hosts_lock.release()
  184. # thread footwork, because _remove actually starts new connections
  185. def remove_pending_connection(self, host, connector=None):
  186. if not threadable.isInIOThread():
  187. self.reactor.callFromThread(self._remove_pending_connection,
  188. host, connector)
  189. else:
  190. self._remove_pending_connection(host, connector)
  191. def _remove_pending_connection(self, host, connector=None):
  192. if debug: print 'removing', host
  193. self.halfopen_hosts_lock.acquire()
  194. self.halfopen_hosts[host].remove(connector)
  195. if len(self.halfopen_hosts[host]) == 0:
  196. del self.halfopen_hosts[host]
  197. self._push_new_connections()
  198. self.halfopen_hosts_lock.release()
  199. def _push_new_connections(self):
  200. if not self.postponed:
  201. return
  202. c = self.postponed.pop_connection()
  203. self._connect(c)
  204. def drop_postponed(self, c):
  205. self.postponed.remove_connection(c.owner, c)
  206. def _preempt_for(self, c):
  207. if debug: print '\npreempting for', c.host, c.port, '\n'
  208. self.postponed.append_preempt(c)
  209. sorted = []
  210. for connectors in self.halfopen_hosts.itervalues():
  211. # drop hosts with connectors that have no handle (urllib)
  212. # drop hosts with any urgent connectors
  213. can_preempt = True
  214. for s in connectors:
  215. if not s or s.urgent:
  216. can_preempt = False
  217. break
  218. if not can_preempt:
  219. continue
  220. sorted.append((len(connectors), connectors))
  221. if len(sorted) == 0:
  222. # give up. no hosts can be interrupted
  223. return
  224. # find the host with least connectors to interrupt
  225. sorted.sort()
  226. connectors = sorted[0][1]
  227. for s in connectors:
  228. s.preempted = True
  229. if debug: print 'preempting', s.host, s.port
  230. s.disconnect()
  231. def _resolve_then_connect(self, c):
  232. if abstract.isIPAddress(c.host):
  233. self._connect(c)
  234. return c
  235. df = self.reactor.resolve(c.host)
  236. if debug: print 'resolving', c.host
  237. def set_host(ip):
  238. if debug: print 'resolved', c.host, ip
  239. c.host = ip
  240. self._connect(c)
  241. def error(f):
  242. # too lazy to figure out how to fail properly, so just connect
  243. self._connect(c)
  244. df.addCallbacks(set_host, error)
  245. return c
  246. def _connect(self, c):
  247. # the XP connection rate limiting is unique at the IP level
  248. if (len(self.halfopen_hosts) >= self.max_incomplete and
  249. c.host not in self.halfopen_hosts):
  250. if debug: print 'postponing', c.host, c.port
  251. if c.urgent:
  252. self._preempt_for(c)
  253. else:
  254. self.postponed.add_connection(c.owner, c)
  255. else:
  256. c.connect()
  257. return c
  258. def connectTCP(self, host, port, factory,
  259. timeout=30, bindAddress=None, owner=None, urgent=True):
  260. c = IRobotConnector(self, 'INET', host, port, factory, owner, urgent,
  261. timeout, bindAddress)
  262. self._resolve_then_connect(c)
  263. return c
  264. def connectSSL(self, host, port, factory, contextFactory,
  265. timeout=30, bindAddress=None, owner=None, urgent=True):
  266. c = IRobotConnector(self, 'SSL', host, port, factory, owner, urgent,
  267. contextFactory, timeout, bindAddress)
  268. self._resolve_then_connect(c)
  269. return c
  270. def connectionRateLimitReactor(reactor, max_incomplete):
  271. if (hasattr(reactor, 'limiter') and
  272. reactor.limiter.max_incomplete != max_incomplete):
  273. print 'Changing max_incomplete for ConnectionRateLimiterReactor!'
  274. reactor.limiter.max_incomplete = max_incomplete
  275. else:
  276. limiter = ConnectionRateLimiter(reactor, max_incomplete)
  277. reactor.connectTCP = limiter.connectTCP
  278. reactor.connectSSL = limiter.connectSSL
  279. reactor.add_pending_connection = limiter.add_pending_connection
  280. reactor.remove_pending_connection = limiter.remove_pending_connection
  281. reactor.limiter = limiter