NewRateLimiter.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
# This was built to be like SFQ but turned out like round-robin.
# (Why didn't you just use Deficit Round Robin?) --Dave
# (Because of unitsize) --Greg
#
# I call it Hierarchical Round Robin Bucket Percentage Style
#
# by Greg Hazel
  8. import time
  9. import traceback
  10. from BTL.platform import bttime
  11. from BTL.DictWithLists import OrderedDictWithLists
  12. # these are for logging and such
  13. class GlobalRate(object):
  14. def __init__(self):
  15. self.total = 0.0
  16. self.start_time = bttime()
  17. self.last_time = self.start_time
  18. def print_rate(self, size):
  19. self.total += size
  20. this_time = bttime()
  21. start_delta = this_time - self.start_time
  22. this_delta = this_time - self.last_time
  23. if start_delta > 0 and this_delta > 0:
  24. print "UPLOAD: This:", size / this_delta, "Total:", self.total / start_delta
  25. self.last_time = this_time
  26. global_rate = GlobalRate()
  27. # very simple.
  28. # every call gives you the duration since the last call in tokens
  29. class DeltaTokens(object):
  30. def __init__(self, rate):
  31. self.set_rate(rate)
  32. def set_rate(self, rate):
  33. self.rate = rate
  34. # clear the history since the rate has changed and it could be way off
  35. self.last_time = bttime()
  36. # return the number of tokens you can have since the last call
  37. def __call__(self):
  38. new_time = bttime()
  39. delta_time = new_time - self.last_time
  40. # if last time was more than a second ago, we can't give a clear
  41. # approximation since rate is in tokens per second.
  42. delta_time = min(delta_time, 1.0)
  43. if delta_time <= 0:
  44. return 0
  45. tokens = self.rate * delta_time
  46. self.last_time = new_time
  47. return tokens
  48. # allows you to subtract tokens from DeltaTokens to compensate
  49. def remove_tokens(self, x):
  50. if self.rate == 0:
  51. # shit, I don't know.
  52. self.last_time += x
  53. else:
  54. self.last_time += x / self.rate
  55. # returns the time until you'll get tokens again
  56. def get_remaining_time(self):
  57. return max(0, self.last_time - bttime())
class Classifer(object):
    # NOTE(review): the name looks like a typo for "Classifier"; kept as-is
    # since external code may reference it.
    """Round-robin store of queued callbacks, one FIFO 'channel' per key."""
    def __init__(self):
        # key -> FIFO of callbacks, ordered by key insertion
        self.channels = OrderedDictWithLists()
    def add_data(self, keyable, func):
        # hmm, this should rotate every 10 seconds or so, but moving over the
        # old data is hard (can't write out-of-order)
        #key = sha.sha(id(o)).hexdigest()[0]
        # this is technically round-robin
        key = keyable
        self.channels.push_to_row(key, func)
    def rem_data(self, key):
        # drop the entire channel for `key`; a no-op if nothing was queued
        try:
            l = self.channels.poprow(key)
            l.clear()
        except KeyError:
            pass
    def rotate_data(self):
        """Pop one callback from the front-most channel, then move that whole
        channel to the back (round-robin across keys)."""
        # this removes the top-most row from the ordereddict
        k = self.channels.iterkeys().next()
        l = self.channels.poprow(k)
        data = l.popleft()
        # this puts the whole row at the bottom of the ordereddict
        # NOTE(review): the row is re-queued even when popleft() emptied it —
        # whether OrderedDictWithLists tolerates empty rows should be verified.
        self.channels.setrow(k, l)
        return data
    def __len__(self):
        # number of channels with queued data, not total queued callbacks
        return len(self.channels)
  84. class Scheduler(object):
  85. def __init__(self, rate, add_task):
  86. """@param rate: rate at which 'tokens' are generated.
  87. @param add_task: callback to schedule an event.
  88. """
  89. self.add_task = add_task
  90. self.classifier = Classifer()
  91. self.delta_tokens = DeltaTokens(rate)
  92. self.task = None
  93. self.children = {}
  94. def set_rate(self, rate, cascade=True):
  95. self.delta_tokens.set_rate(rate)
  96. if cascade:
  97. for child, scale in self.children.iteritems():
  98. child.set_rate(rate * scale)
  99. # the rate changed, so it's possible the loop is
  100. # running slower than it needs to
  101. self.restart_loop(0)
  102. def add_child(self, child, scale):
  103. self.children[child] = scale
  104. child.set_rate(self.delta_tokens.rate * scale)
  105. def remove_child(self, child):
  106. del self.children[child]
  107. def add_data(self, keyable, func):
  108. self.classifier.add_data(keyable, func)
  109. # kick off a loop since we have data now
  110. self.restart_loop(0)
  111. def restart_loop(self, t):
  112. # check for pending loop event
  113. if self.task and not self.task.called:
  114. ## look at when it's scheduled to occur
  115. # we can special case events which have a delta of 0, since they
  116. # should occur asap. no need to check the time.
  117. if self.task.delta == 0:
  118. return
  119. # use time.time since twisted does anyway
  120. s = self.task.getTime() - time.time()
  121. if s > t:
  122. # if it would occur after the time we want, reset it
  123. self.task.reset(t)
  124. self.task.delta = t
  125. else:
  126. if t == 0:
  127. # don't spin the event loop needlessly
  128. self.run()
  129. else:
  130. self.task = self.add_task(t, self.run)
  131. self.task.delta = t
  132. def _write(self, to_write):
  133. amount = 0
  134. each = min(self.delta_tokens.rate, self.unitsize)
  135. if self.children:
  136. for child, scale in self.children.iteritems():
  137. child.set_rate(self.delta_tokens.rate * scale, cascade=False)
  138. i = 0
  139. while amount < to_write and len(self.classifier) > 0:
  140. (func, args) = self.classifier.rotate_data()
  141. # ERROR: func can fill buffers, so use the on_flush technique
  142. try:
  143. amount += func(each)
  144. except:
  145. # don't stop the loop if we hit an error
  146. traceback.print_exc()
  147. i += 1
  148. if i == len(self.children):
  149. break
  150. for child, scale in self.children.iteritems():
  151. # really max, but we happen to know it can't exceed amount
  152. child.set_rate(amount, cascade=False)
  153. while amount < to_write and len(self.classifier) > 0:
  154. func = self.classifier.rotate_data()
  155. # ERROR: func can fill buffers, so use the on_flush technique
  156. try:
  157. amount += func(each)
  158. except:
  159. # don't stop the loop if we hit an error
  160. traceback.print_exc()
  161. return amount
  162. def _run_once(self):
  163. f_to_write = self.delta_tokens()
  164. to_write = int(f_to_write)
  165. if to_write == 0:
  166. written = 0
  167. else:
  168. written = self._write(to_write)
  169. # for debugging
  170. #print "Ideal:", self.delta_tokens.rate, f_to_write
  171. #global_rate.print_rate(written)
  172. self.delta_tokens.remove_tokens(written - f_to_write)
  173. return written
  174. def run(self):
  175. t = 0
  176. while t == 0:
  177. if len(self.classifier) == 0:
  178. return
  179. self._run_once()
  180. t = self.delta_tokens.get_remaining_time()
  181. self.restart_loop(t)
  182. # made to look like the original
  183. class MultiRateLimiter(Scheduler):
  184. # Since data is sent to peers in a round-robin fashion, max one
  185. # full request at a time, setting this higher would send more data
  186. # to peers that use request sizes larger than standard 16 KiB.
  187. # 17000 instead of 16384 to allow room for metadata messages.
  188. max_unitsize = 17000
  189. def __init__(self, sched, parent=None):
  190. Scheduler.__init__(self, rate = 0, add_task = sched)
  191. if parent == None:
  192. self.run()
  193. def set_parameters(self, rate, unitsize=2**500):
  194. self.set_rate(rate)
  195. unitsize = min(unitsize, self.max_unitsize)
  196. self.unitsize = unitsize
  197. def queue(self, conn):
  198. keyable = conn
  199. self.add_data(keyable, conn.send_partial)
  200. def dequeue(self, keyable):
  201. self.classifier.rem_data(keyable)
  202. def increase_offset(self, bytes):
  203. # hackity hack hack
  204. self.delta_tokens.remove_tokens(0 - bytes)
  205. class FakeConnection(object):
  206. def __init__(self, gr):
  207. self.gr = gr
  208. def _use_length_(self, length):
  209. def do():
  210. return length
  211. return self.write(do)
  212. def write(self, fn, *args):
  213. size = fn(*args)
  214. self.gr.print_rate(size)
  215. return size
if __name__ == '__main__':
    # Smoke/stress test: hammer a Scheduler with 500 fake connections under a
    # twisted reactor, randomly changing the rate, optionally under a profiler.
    profile = True
    try:
        from BTL.profile import Profiler, Stats
        prof_file_name = 'NewRateLimiter.prof'
    except ImportError, e:
        print "profiling not available:", e
        profile = False
    import os
    import random
    from RawServer_twisted import RawServer
    from twisted.internet import task
    from BTL.defer import DeferredEvent
    rawserver = RawServer()
    s = Scheduler(4096, add_task = rawserver.add_task)
    # Scheduler does not set unitsize itself; it must be assigned before use
    s.unitsize = 17000
    a = []
    for i in xrange(500):
        keyable = FakeConnection(global_rate)
        a.append(keyable)
    freq = 0.01
    def push():
        # roughly once every 5 seconds (at freq-second intervals), pick a
        # new random rate between 1 KB/s and 100 KB/s
        if random.randint(0, 5 / freq) == 0:
            rate = random.randint(1, 100) * 1000
            print "new rate", rate
            s.set_rate(rate)
        # keep every fake connection's queue topped up
        for c in a:
            s.add_data(c, c._use_length_)
    t = task.LoopingCall(push)
    t.start(freq)
    ## m = MultiRateLimiter(sched=rawserver.add_task)
    ## m.set_parameters(120000000)
    ## class C(object):
    ##     def send_partial(self, size):
    ##         global_rate.print_rate(size)
    ##         rawserver.add_task(0, m.queue, self)
    ##         return size
    ##
    ## m.queue(C())
    if profile:
        # start from a fresh profile file each run
        try:
            os.unlink(prof_file_name)
        except:
            pass
        prof = Profiler()
        prof.enable()
    rawserver.listen_forever()
    # only reached once the reactor stops
    if profile:
        prof.disable()
        st = Stats(prof.getstats())
        st.sort()
        f = open(prof_file_name, 'wb')
        st.dump(file=f)