| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- # usage:
- #
- # from twisted.internet import reactor
- # from ConnectionRateLimitReactor import connectionRateLimitReactor
- # connectionRateLimitReactor(reactor, max_incomplete=10)
- #
- # The contents of this file are subject to the Python Software Foundation
- # License Version 2.3 (the License). You may not copy or use this file, in
- # either source code or executable form, except in compliance with the License.
- # You may obtain a copy of the License at http://www.python.org/license.
- #
- # Software distributed under the License is distributed on an AS IS basis,
- # WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
- # for the specific language governing rights and limitations under the
- # License.
- #
- # by Greg Hazel
- import random
- import threading
- from twisted.python import failure
- from twisted.python import threadable
- from twisted.internet import error, address, abstract
- from BTL.circular_list import CircularList
- from BTL.Lists import QList
- from BTL.decorate import decorate_func
# Module-wide debug switch. When True, the limiter code below prints trace
# messages ('connecting', 'adding', 'removing', 'postponing', ...) and
# ConnectionRateLimiter.__init__ starts a LoopingCall that prints queue
# sizes once per second.
debug = False
class HookedFactory(object):
    """Wrap a client factory so an IRobotConnector learns when its attempt ends.

    A connection attempt is over either when it fails (clientConnectionFailed)
    or when a protocol built by this factory gets connectionMade. Both paths
    call ``connector.complete()`` so the limiter can release the host's
    half-open slot. Every other attribute is looked up on the wrapped factory.
    """

    def __init__(self, connector, factory):
        # connector: the IRobotConnector to notify; factory: the caller's
        # original client factory.
        self.connector = connector
        self.factory = factory

    def clientConnectionFailed(self, connector, reason):
        """Release the half-open slot if still held, then forward the failure.

        The wrapped factory is always notified, and its return value is
        passed back unchanged.
        """
        if self.connector._started:
            self.connector.complete()
        return self.factory.clientConnectionFailed(connector, reason)

    def buildProtocol(self, addr):
        """Build the wrapped factory's protocol and hook its connectionMade.

        If the wrapped factory refuses the connection by returning None, the
        attempt is still finished. The slot is therefore released right away
        and None is passed through, instead of crashing on
        ``None.connectionMade`` and leaking the half-open entry.
        """
        p = self.factory.buildProtocol(addr)
        if p is None:
            self.connector.complete()
            return None
        p.connectionMade = decorate_func(self.connector.complete,
                                         p.connectionMade)
        return p

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. Read 'factory' straight from
        # the instance dict, so an instance without it (e.g. created via
        # __new__ without __init__) raises AttributeError instead of
        # recursing forever through __getattr__('factory').
        try:
            factory = self.__dict__['factory']
        except KeyError:
            raise AttributeError(attr)
        return getattr(factory, attr)
-
class IRobotConnector(object):
    """Stand-in connector returned by ConnectionRateLimiter.connectTCP/SSL.

    No real Twisted connection attempt exists until connect() runs. The
    limiter may postpone that call until a half-open slot frees up.

    NOTE: ``reactor`` here is the ConnectionRateLimiter instance, which is what
    ConnectionRateLimiter.connectTCP/connectSSL pass in. That is why this class
    calls add_pending_connection, old_connectTCP/old_connectSSL,
    drop_postponed and _remove_pending_connection on it.
    """
    # I did this to be nice, but zope sucks.
    ##implements(interfaces.IConnector)

    def __init__(self, reactor, protocol, host, port, factory, owner, urgent,
                 *a, **kw):
        """Store the connection parameters; nothing is started here.

        protocol -- 'INET' or 'SSL' (asserted below).
        owner    -- key the limiter uses to group postponed connectors.
        urgent   -- when the limit is reached, an urgent connector is queued
                    ahead of the others and may preempt in-flight attempts.
        *a, **kw -- extra arguments forwarded to old_connectTCP/old_connectSSL.
        """
        self.reactor = reactor
        self.protocol = protocol
        assert self.protocol in ('INET', 'SSL')
        self.host = host
        self.port = port
        self.owner = owner
        self.urgent = urgent
        self.a = a
        self.kw = kw
        self.connector = None  # real Twisted connector, set by connect()
        self._started = False  # True while registered as half-open
        self.preempted = False  # set by the limiter when it preempts us
        # Wrap the caller's factory so this object hears about connection
        # failure / connectionMade and can call complete().
        self.factory = HookedFactory(self, factory)

    def started(self):
        """Mark this connector as in flight and register it with the limiter.

        Raises ValueError if it is already marked as started.
        """
        if self._started:
            raise ValueError("Connector is already started!")
        self._started = True
        self.reactor.add_pending_connection(self.host, self)

    def disconnect(self):
        """Abort this connector.

        If an attempt is in flight, delegate to the real connector. Otherwise
        withdraw the connector from the postponed queue (see stopConnecting).

        NOTE(review): once complete() has run, _started is False and this
        object is no longer queued. disconnect() then takes the "postponed"
        path, where drop_postponed will presumably raise. The same applies
        while the host name is still being resolved. Confirm callers never
        disconnect() in those states.
        """
        if self._started:
            return self.connector.disconnect()
        return self.stopConnecting()

    def _cleanup(self):
        # Drop the forwarded arguments, the hooked factory and the real
        # connector. Safe to call more than once (every del is guarded).
        # Presumably done to break reference cycles - TODO confirm.
        if hasattr(self, 'a'):
            del self.a
        if hasattr(self, 'kw'):
            del self.kw
        if hasattr(self, 'factory'):
            del self.factory
        if hasattr(self, 'connector'):
            del self.connector

    def stopConnecting(self):
        """Stop connecting.

        If an attempt is in flight, stop the real connector and clean up.
        Otherwise assume this connector is waiting in the postponed queue.
        In that case, remove it from the queue and report a start followed by
        a failure (UserError "Connection preempted") to the factory.
        """
        if self._started:
            # NOTE(review): complete() is not called here. The half-open entry
            # is presumably released when the real connector reports
            # clientConnectionFailed through HookedFactory - confirm.
            self.connector.stopConnecting()
            self._cleanup()
            return
        self.reactor.drop_postponed(self)
        # for accuracy
        self.factory.startedConnecting(self)
        abort = failure.Failure(error.UserError(string="Connection preempted"))
        self.factory.clientConnectionFailed(self, abort)
        self._cleanup()

    def connect(self):
        """Start the real attempt via the original connectTCP/connectSSL.

        Returns self. If the original connect call raises, the half-open
        registration is released (complete()) and the exception re-raised.
        """
        if debug: print 'connecting', self.host, self.port
        self.started()
        try:
            if self.protocol == 'SSL':
                self.connector = self.reactor.old_connectSSL(self.host,
                                                             self.port,
                                                             self.factory,
                                                             *self.a, **self.kw)
            else:
                self.connector = self.reactor.old_connectTCP(self.host,
                                                             self.port,
                                                             self.factory,
                                                             *self.a, **self.kw)
            # because other callbacks use this one
            self.connector.wasPreempted = self.wasPreempted
        except:
            # make sure failures get removed before we raise
            # (bare except: release the slot whatever went wrong, re-raise)
            self.complete()
            raise
        # if connect is re-called on the connector, we want to restart
        # (decorate_func presumably runs self.started before the original
        # connect - confirm against BTL.decorate).
        # NOTE(review): complete() -> _cleanup() deletes self.connector. If
        # the real connector is later re-connected, started() runs again but
        # disconnect()/stopConnecting() here would then hit a missing
        # self.connector - confirm whether reconnects happen.
        self.connector.connect = decorate_func(self.started,
                                               self.connector.connect)
        return self

    def wasPreempted(self):
        """Return True if the limiter disconnected this connector to make room
        for an urgent one (ConnectionRateLimiter._preempt_for sets it)."""
        return self.preempted

    def complete(self):
        """Finish this attempt; a no-op unless started.

        Otherwise unregister from the limiter's half-open table (which may
        start a postponed connection) and drop references via _cleanup().
        NOTE(review): this calls _remove_pending_connection directly, without
        the thread hop of remove_pending_connection. It presumably always runs
        on the reactor thread - confirm.
        """
        if not self._started:
            return
        self._started = False
        self.reactor._remove_pending_connection(self.host, self)
        self._cleanup()

    def getDestination(self):
        # NOTE(review): self.protocol ('INET'/'SSL') is passed as the fourth
        # positional argument of IPv4Address. Check that this matches the
        # installed Twisted's IPv4Address signature.
        return address.IPv4Address('TCP', self.host, self.port, self.protocol)
class Postponed(CircularList):
    """Connectors waiting for a free half-open slot.

    Urgent connectors wait in the ``preempt`` queue, which is always served
    first. Other connectors are queued per owner key in ``cm_to_list``. The
    owner keys themselves are the elements of this CircularList, and
    ``self.it`` walks over them to choose which owner's queue to take from
    next. The exact visiting order depends on CircularList's iterator.
    """

    def __init__(self):
        CircularList.__init__(self)
        self.it = iter(self)
        self.preempt = QList()
        # owner key -> QList of that owner's postponed connectors
        self.cm_to_list = {}

    def __len__(self):
        """Total number of waiting connectors across both queues."""
        total = 0
        for queue in self.cm_to_list.itervalues():
            total += len(queue)
        return total + len(self.preempt)

    def append_preempt(self, c):
        """Queue an urgent connector ahead of every per-owner queue."""
        return self.preempt.append(c)

    def add_connection(self, keyable, c):
        """Queue ``c`` under owner ``keyable``, prepending new owner keys."""
        try:
            queue = self.cm_to_list[keyable]
        except KeyError:
            queue = self.cm_to_list[keyable] = QList()
            self.prepend(keyable)
        queue.append(c)

    def pop_connection(self):
        """Remove and return the next connector to start.

        Urgent connectors come first; otherwise take the head of the queue of
        the owner that ``self.it`` yields next.
        """
        if self.preempt:
            return self.preempt.popleft()
        keyable = self.it.next()
        queue = self.cm_to_list[keyable]
        c = queue.popleft()
        self._drop_owner_if_empty(keyable, queue)
        return c

    def remove_connection(self, keyable, c):
        """Withdraw ``c``: from ``preempt`` if urgent, else from its owner."""
        if c.urgent:
            self.preempt.remove(c)
        else:
            queue = self.cm_to_list[keyable]
            queue.remove(c)
            self._drop_owner_if_empty(keyable, queue)

    def _drop_owner_if_empty(self, keyable, queue):
        # Forget an owner key once its queue has drained.
        if len(queue) == 0:
            self.remove(keyable)
            del self.cm_to_list[keyable]
- class ConnectionRateLimiter(object):
-
- def __init__(self, reactor, max_incomplete):
- self.reactor = reactor
- self.postponed = Postponed()
- self.max_incomplete = max_incomplete
- # this can go away when urllib does
- self.halfopen_hosts_lock = threading.RLock()
- self.halfopen_hosts = {}
- self.old_connectTCP = self.reactor.connectTCP
- self.old_connectSSL = self.reactor.connectSSL
- if debug:
- from twisted.internet import task
- def p():
- print len(self.postponed), [ (k, len(v)) for k, v in self.halfopen_hosts.iteritems() ]
- assert len(self.halfopen_hosts) <= self.max_incomplete
- task.LoopingCall(p).start(1)
- # safe from any thread
- def add_pending_connection(self, host, connector=None):
- if debug: print 'adding', host, 'IOthread', threadable.isInIOThread()
- self.halfopen_hosts_lock.acquire()
- self.halfopen_hosts.setdefault(host, []).append(connector)
- self.halfopen_hosts_lock.release()
- # thread footwork, because _remove actually starts new connections
- def remove_pending_connection(self, host, connector=None):
- if not threadable.isInIOThread():
- self.reactor.callFromThread(self._remove_pending_connection,
- host, connector)
- else:
- self._remove_pending_connection(host, connector)
- def _remove_pending_connection(self, host, connector=None):
- if debug: print 'removing', host
- self.halfopen_hosts_lock.acquire()
- self.halfopen_hosts[host].remove(connector)
- if len(self.halfopen_hosts[host]) == 0:
- del self.halfopen_hosts[host]
- self._push_new_connections()
- self.halfopen_hosts_lock.release()
- def _push_new_connections(self):
- if not self.postponed:
- return
- c = self.postponed.pop_connection()
- self._connect(c)
- def drop_postponed(self, c):
- self.postponed.remove_connection(c.owner, c)
- def _preempt_for(self, c):
- if debug: print '\npreempting for', c.host, c.port, '\n'
- self.postponed.append_preempt(c)
-
- sorted = []
- for connectors in self.halfopen_hosts.itervalues():
- # drop hosts with connectors that have no handle (urllib)
- # drop hosts with any urgent connectors
- can_preempt = True
- for s in connectors:
- if not s or s.urgent:
- can_preempt = False
- break
- if not can_preempt:
- continue
-
- sorted.append((len(connectors), connectors))
- if len(sorted) == 0:
- # give up. no hosts can be interrupted
- return
- # find the host with least connectors to interrupt
- sorted.sort()
- connectors = sorted[0][1]
-
- for s in connectors:
- s.preempted = True
- if debug: print 'preempting', s.host, s.port
- s.disconnect()
-
- def _resolve_then_connect(self, c):
- if abstract.isIPAddress(c.host):
- self._connect(c)
- return c
- df = self.reactor.resolve(c.host)
- if debug: print 'resolving', c.host
- def set_host(ip):
- if debug: print 'resolved', c.host, ip
- c.host = ip
- self._connect(c)
- def error(f):
- # too lazy to figure out how to fail properly, so just connect
- self._connect(c)
- df.addCallbacks(set_host, error)
- return c
- def _connect(self, c):
- # the XP connection rate limiting is unique at the IP level
- if (len(self.halfopen_hosts) >= self.max_incomplete and
- c.host not in self.halfopen_hosts):
- if debug: print 'postponing', c.host, c.port
- if c.urgent:
- self._preempt_for(c)
- else:
- self.postponed.add_connection(c.owner, c)
- else:
- c.connect()
- return c
- def connectTCP(self, host, port, factory,
- timeout=30, bindAddress=None, owner=None, urgent=True):
- c = IRobotConnector(self, 'INET', host, port, factory, owner, urgent,
- timeout, bindAddress)
- self._resolve_then_connect(c)
- return c
- def connectSSL(self, host, port, factory, contextFactory,
- timeout=30, bindAddress=None, owner=None, urgent=True):
- c = IRobotConnector(self, 'SSL', host, port, factory, owner, urgent,
- contextFactory, timeout, bindAddress)
- self._resolve_then_connect(c)
- return c
def connectionRateLimitReactor(reactor, max_incomplete):
    """Install a ConnectionRateLimiter on ``reactor``, or retune it.

    The first call replaces reactor.connectTCP/connectSSL with the limiter's
    versions. It also exposes add_pending_connection /
    remove_pending_connection and stores the limiter as ``reactor.limiter``.

    Later calls only update ``max_incomplete`` on the installed limiter.
    Previously, a repeat call with an unchanged value installed a second
    limiter whose "original" connectTCP was the first limiter's, which
    stacked the two limits.
    """
    if hasattr(reactor, 'limiter'):
        if reactor.limiter.max_incomplete != max_incomplete:
            print('Changing max_incomplete for ConnectionRateLimiterReactor!')
            reactor.limiter.max_incomplete = max_incomplete
        return
    limiter = ConnectionRateLimiter(reactor, max_incomplete)
    reactor.connectTCP = limiter.connectTCP
    reactor.connectSSL = limiter.connectSSL
    reactor.add_pending_connection = limiter.add_pending_connection
    reactor.remove_pending_connection = limiter.remove_pending_connection
    reactor.limiter = limiter
|