1
0

ktable.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
# The contents of this file are subject to the BitTorrent Open Source License
# Version 1.1 (the License). You may not copy or use this file, in either
# source code or executable form, except in compliance with the License. You
# may obtain a copy of the License at http://www.bittorrent.com/license/.
#
# Software distributed under the License is distributed on an AS IS basis,
# WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
# for the specific language governing rights and limitations under the
# License.
  10. from BTL.platform import bttime as time
  11. from bisect import *
  12. from types import *
  13. import khash as hash
  14. import const
  15. from const import K, HASH_LENGTH, NULL_ID, MAX_FAILURES, MIN_PING_INTERVAL
  16. from node import Node
  17. def ls(a, b):
  18. return cmp(a.lastSeen, b.lastSeen)
  19. class KTable(object):
  20. __slots__ = ('node', 'buckets')
  21. """local routing table for a kademlia like distributed hash table"""
  22. def __init__(self, node):
  23. # this is the root node, a.k.a. US!
  24. self.node = node
  25. self.buckets = [KBucket([], 0L, 2L**HASH_LENGTH)]
  26. self.insertNode(node)
  27. def _bucketIndexForInt(self, num):
  28. """the index of the bucket that should hold int"""
  29. return bisect_left(self.buckets, num)
  30. def bucketForInt(self, num):
  31. return self.buckets[self._bucketIndexForInt(num)]
  32. def findNodes(self, id, invalid=True):
  33. """
  34. return K nodes in our own local table closest to the ID.
  35. """
  36. if isinstance(id, str):
  37. num = hash.intify(id)
  38. elif isinstance(id, Node):
  39. num = id.num
  40. elif isinstance(id, int) or isinstance(id, long):
  41. num = id
  42. else:
  43. raise TypeError, "findNodes requires an int, string, or Node"
  44. nodes = []
  45. i = self._bucketIndexForInt(num)
  46. # if this node is already in our table then return it
  47. try:
  48. node = self.buckets[i].getNodeWithInt(num)
  49. except ValueError:
  50. pass
  51. else:
  52. return [node]
  53. # don't have the node, get the K closest nodes
  54. nodes = nodes + self.buckets[i].l
  55. if not invalid:
  56. nodes = [a for a in nodes if not a.invalid]
  57. if len(nodes) < K:
  58. # need more nodes
  59. min = i - 1
  60. max = i + 1
  61. while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
  62. #ASw: note that this requires K be even
  63. if min >= 0:
  64. nodes = nodes + self.buckets[min].l
  65. if max < len(self.buckets):
  66. nodes = nodes + self.buckets[max].l
  67. min = min - 1
  68. max = max + 1
  69. if not invalid:
  70. nodes = [a for a in nodes if not a.invalid]
  71. nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
  72. return nodes[:K]
  73. def _splitBucket(self, a):
  74. diff = (a.max - a.min) / 2
  75. b = KBucket([], a.max - diff, a.max)
  76. self.buckets.insert(self.buckets.index(a.min) + 1, b)
  77. a.max = a.max - diff
  78. # transfer nodes to new bucket
  79. for anode in a.l[:]:
  80. if anode.num >= a.max:
  81. a.removeNode(anode)
  82. b.addNode(anode)
  83. def replaceStaleNode(self, stale, new):
  84. """this is used by clients to replace a node returned by insertNode after
  85. it fails to respond to a Pong message"""
  86. i = self._bucketIndexForInt(stale.num)
  87. if self.buckets[i].hasNode(stale):
  88. self.buckets[i].removeNode(stale)
  89. if new and self.buckets[i].hasNode(new):
  90. self.buckets[i].seenNode(new)
  91. elif new:
  92. self.buckets[i].addNode(new)
  93. return
  94. def insertNode(self, node, contacted=1, nocheck=False):
  95. """
  96. this insert the node, returning None if successful, returns the oldest node in the bucket if it's full
  97. the caller responsible for pinging the returned node and calling replaceStaleNode if it is found to be stale!!
  98. contacted means that yes, we contacted THEM and we know the node is reachable
  99. """
  100. if node.id == NULL_ID or node.id == self.node.id:
  101. return
  102. if contacted:
  103. node.updateLastSeen()
  104. # get the bucket for this node
  105. i = self._bucketIndexForInt(node.num)
  106. # check to see if node is in the bucket already
  107. if self.buckets[i].hasNode(node):
  108. it = self.buckets[i].l.index(node.num)
  109. xnode = self.buckets[i].l[it]
  110. if contacted:
  111. node.age = xnode.age
  112. self.buckets[i].seenNode(node)
  113. elif xnode.lastSeen != 0 and xnode.port == node.port and xnode.host == node.host:
  114. xnode.updateLastSeen()
  115. return
  116. # we don't have this node, check to see if the bucket is full
  117. if not self.buckets[i].bucketFull():
  118. # no, append this node and return
  119. self.buckets[i].addNode(node)
  120. return
  121. # full bucket, check to see if any nodes are invalid
  122. t = time()
  123. invalid = [x for x in self.buckets[i].invalid.values() if x.invalid]
  124. if len(invalid) and not nocheck:
  125. invalid.sort(ls)
  126. while invalid and not self.buckets[i].hasNode(invalid[0]):
  127. del(self.buckets[i].invalid[invalid[0].num])
  128. invalid = invalid[1:]
  129. if invalid and (invalid[0].lastSeen == 0 and invalid[0].fails < MAX_FAILURES):
  130. return invalid[0]
  131. elif invalid:
  132. self.replaceStaleNode(invalid[0], node)
  133. return
  134. stale = [n for n in self.buckets[i].l if (t - n.lastSeen) > MIN_PING_INTERVAL]
  135. if len(stale) and not nocheck:
  136. stale.sort(ls)
  137. return stale[0]
  138. # bucket is full and all nodes are valid, check to see if self.node is in the bucket
  139. if not (self.buckets[i].min <= self.node < self.buckets[i].max):
  140. return
  141. # this bucket is full and contains our node, split the bucket
  142. if len(self.buckets) >= HASH_LENGTH:
  143. # our table is FULL, this is really unlikely
  144. print "Hash Table is FULL! Increase K!"
  145. return
  146. self._splitBucket(self.buckets[i])
  147. # now that the bucket is split and balanced, try to insert the node again
  148. return self.insertNode(node, contacted)
  149. def justSeenNode(self, id):
  150. """call this any time you get a message from a node
  151. it will update it in the table if it's there """
  152. try:
  153. n = self.findNodes(id)[0]
  154. except IndexError:
  155. return None
  156. else:
  157. if n.id != id:
  158. return None
  159. tstamp = n.lastSeen
  160. n.updateLastSeen()
  161. bucket = self.bucketForInt(n.num)
  162. bucket.seenNode(n)
  163. return tstamp
  164. def invalidateNode(self, n):
  165. """
  166. forget about node n - use when you know that node is invalid
  167. """
  168. n.invalid = True
  169. bucket = self.bucketForInt(n.num)
  170. bucket.invalidateNode(n)
  171. def nodeFailed(self, node):
  172. """ call this when a node fails to respond to a message, to invalidate that node """
  173. try:
  174. n = self.findNodes(node.num)[0]
  175. except IndexError:
  176. return None
  177. else:
  178. if n.id != node.id:
  179. return None
  180. if n.msgFailed() >= const.MAX_FAILURES:
  181. self.invalidateNode(n)
  182. def numPeers(self):
  183. """ estimated number of connectable nodes in global table """
  184. return 8 * (2 ** (len(self.buckets) - 1))
  185. class KBucket(object):
  186. __slots__ = ('min', 'max', 'lastAccessed', 'l', 'index', 'invalid')
  187. def __init__(self, contents, min, max):
  188. self.l = contents
  189. self.index = {}
  190. self.invalid = {}
  191. self.min = min
  192. self.max = max
  193. self.lastAccessed = time()
  194. def touch(self):
  195. self.lastAccessed = time()
  196. def lacmp(self, a, b):
  197. if a.lastSeen > b.lastSeen:
  198. return 1
  199. elif b.lastSeen > a.lastSeen:
  200. return -1
  201. return 0
  202. def sort(self):
  203. self.l.sort(self.lacmp)
  204. def getNodeWithInt(self, num):
  205. try:
  206. node = self.index[num]
  207. except KeyError:
  208. raise ValueError
  209. return node
  210. def addNode(self, node):
  211. if len(self.l) >= K:
  212. return
  213. if self.index.has_key(node.num):
  214. return
  215. self.l.append(node)
  216. self.index[node.num] = node
  217. self.touch()
  218. def removeNode(self, node):
  219. assert self.index.has_key(node.num)
  220. del(self.l[self.l.index(node.num)])
  221. del(self.index[node.num])
  222. try:
  223. del(self.invalid[node.num])
  224. except KeyError:
  225. pass
  226. def invalidateNode(self, node):
  227. self.invalid[node.num] = node
  228. def seenNode(self, node):
  229. try:
  230. del(self.invalid[node.num])
  231. except KeyError:
  232. pass
  233. it = self.l.index(node.num)
  234. del(self.l[it])
  235. self.l.append(node)
  236. self.index[node.num] = node
  237. def hasNode(self, node):
  238. return self.index.has_key(node.num)
  239. def bucketFull(self):
  240. return len(self.l) >= K
  241. def __repr__(self):
  242. return "<KBucket %d items (%d to %d)>" % (len(self.l), self.min, self.max)
  243. ## Comparators
  244. # necessary for bisecting list of buckets with a hash expressed as an integer or a distance
  245. # compares integer or node object with the bucket's range
  246. def __lt__(self, a):
  247. if isinstance(a, Node): a = a.num
  248. return self.max <= a
  249. def __le__(self, a):
  250. if isinstance(a, Node): a = a.num
  251. return self.min < a
  252. def __gt__(self, a):
  253. if isinstance(a, Node): a = a.num
  254. return self.min > a
  255. def __ge__(self, a):
  256. if isinstance(a, Node): a = a.num
  257. return self.max >= a
  258. def __eq__(self, a):
  259. if isinstance(a, Node): a = a.num
  260. return self.min <= a and self.max > a
  261. def __ne__(self, a):
  262. if isinstance(a, Node): a = a.num
  263. return self.min >= a or self.max < a
  264. ### UNIT TESTS ###
  265. import unittest
  266. class TestKTable(unittest.TestCase):
  267. def setUp(self):
  268. self.a = Node().init(hash.newID(), 'localhost', 2002)
  269. self.t = KTable(self.a)
  270. def testAddNode(self):
  271. self.b = Node().init(hash.newID(), 'localhost', 2003)
  272. self.t.insertNode(self.b)
  273. self.assertEqual(len(self.t.buckets[0].l), 1)
  274. self.assertEqual(self.t.buckets[0].l[0], self.b)
  275. def testRemove(self):
  276. self.testAddNode()
  277. self.t.invalidateNode(self.b)
  278. self.assertEqual(len(self.t.buckets[0].l), 0)
  279. def testFail(self):
  280. self.testAddNode()
  281. for i in range(const.MAX_FAILURES - 1):
  282. self.t.nodeFailed(self.b)
  283. self.assertEqual(len(self.t.buckets[0].l), 1)
  284. self.assertEqual(self.t.buckets[0].l[0], self.b)
  285. self.t.nodeFailed(self.b)
  286. self.assertEqual(len(self.t.buckets[0].l), 0)
  287. if __name__ == "__main__":
  288. unittest.main()