Upload.py

# The contents of this file are subject to the BitTorrent Open Source License
# Version 1.1 (the License). You may not copy or use this file, in either
# source code or executable form, except in compliance with the License. You
# may obtain a copy of the License at http://www.bittorrent.com/license/.
#
# Software distributed under the License is distributed on an AS IS basis,
# WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
# for the specific language governing rights and limitations under the
# License.
# Written by Bram Cohen, Greg Hazel, and David Harrison

if __name__ == "__main__":
    # for unit-testing.
    import sys
    sys.path.append("..")

from BitTorrent.CurrentRateMeasure import Measure
import BitTorrent.Connector
from BTL.hash import sha
import struct
import logging

logger = logging.getLogger("BitTorrent.Upload")
log = logger.debug

# Maximum number of outstanding requests from a peer.
MAX_REQUESTS = 256

def _compute_allowed_fast_list(infohash, ip, num_fast, num_pieces):
    """Return the 'allowed fast' piece indices for a peer, derived from the
       peer's IP address and the torrent's infohash."""
    # for now, assume IPv4.
    iplist = [int(x) for x in ip.split(".")]

    # mask the address to its /24 prefix.
    iplist = [chr(iplist[0]), chr(iplist[1]), chr(iplist[2]), chr(0)]
    h = "".join(iplist)
    h = "".join([h, infohash])
    fastlist = []
    assert num_pieces < 2**32
    if num_pieces <= num_fast:
        return range(num_pieces)  # <---- this would be bizarre
    while True:
        h = sha(h).digest()  # rehash hash to generate new random string.
        for i in xrange(5):
            j = i * 4
            z = struct.unpack("!L", h[j:j + 4])[0]
            index = int(z % num_pieces)
            if index not in fastlist:
                fastlist.append(index)
                if len(fastlist) >= num_fast:
                    return fastlist
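
# Illustrative call (the values mirror the unit tests at the bottom of this
# file): the helper masks the peer's IP to its /24 prefix, appends the
# torrent's infohash, then repeatedly SHA-1 hashes the result, reading each
# 20-byte digest as five big-endian 32-bit words and mapping each word onto a
# piece index until num_fast distinct indices have been collected:
#
#   fast_set = _compute_allowed_fast_list(infohash, "80.4.4.200",
#                                         num_fast=7, num_pieces=1313)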


class Upload(object):
    """Upload over a single connection."""

    def __init__(self, multidownload, connector, ratelimiter, choker, storage,
                 max_chunk_length, max_rate_period, num_fast, infohash):
        assert isinstance(connector, BitTorrent.Connector.Connector)
        self.multidownload = multidownload
        self.connector = connector
        self.ratelimiter = ratelimiter
        self.infohash = infohash
        self.choker = choker
        self.num_fast = num_fast
        self.storage = storage
        self.max_chunk_length = max_chunk_length
        self.choked = True
        self.unchoke_time = None
        self.interested = False
        self.had_length_error = False
        self.had_max_requests_error = False
        self.buffer = []  # contains piece data about to be sent.
        self.measure = Measure(max_rate_period)
        connector.add_sent_listener(self.measure.update_rate)
        self.allowed_fast_pieces = []
        if connector.uses_fast_extension:
            if storage.get_amount_left() == 0:
                connector.send_have_all()
            elif storage.do_I_have_anything():
                connector.send_bitfield(storage.get_have_list())
            else:
                connector.send_have_none()
            self._send_allowed_fast_list()
        elif storage.do_I_have_anything():
            connector.send_bitfield(storage.get_have_list())

    def _send_allowed_fast_list(self):
        """Computes and sends the 'allowed fast' set."""
        self.allowed_fast_pieces = _compute_allowed_fast_list(
            self.infohash,
            self.connector.ip, self.num_fast,
            self.storage.get_num_pieces())
        for index in self.allowed_fast_pieces:
            self.connector.send_allowed_fast(index)
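
    # Variant of the module-level _compute_allowed_fast_list above that
    # applies a classful mask (/16 for class A or B addresses, /24 otherwise)
    # before hashing.  Note that _send_allowed_fast_list calls the
    # module-level helper, not this method.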
    def _compute_allowed_fast_list(self, infohash, ip, num_fast, num_pieces):
        # for now, assume IPv4.
        iplist = [int(x) for x in ip.split(".")]

        # classful heuristic: class A or B addresses keep only their /16
        # prefix, others keep /24.  ('==' binds tighter than the bitwise
        # operators, hence the parentheses.)
        if (iplist[0] & 0x80) == 0x00 or (iplist[0] & 0xC0) == 0x80:
            iplist = [chr(iplist[0]), chr(iplist[1]), chr(0), chr(0)]
        else:
            iplist = [chr(iplist[0]), chr(iplist[1]), chr(iplist[2]), chr(0)]
        h = "".join(iplist)
        h = "".join([h, infohash])
        fastlist = []
        assert num_pieces < 2**32
        if num_pieces <= num_fast:
            return range(num_pieces)  # <---- this would be bizarre
        while True:
            h = sha(h).digest()  # rehash hash to generate new random string.
            for i in xrange(5):
                j = i * 4
                y = [ord(x) for x in h[j:j + 4]]
                z = (y[0] << 24) + (y[1] << 16) + (y[2] << 8) + y[3]
                index = int(z % num_pieces)
                if index not in fastlist:
                    fastlist.append(index)
                    if len(fastlist) >= num_fast:
                        return fastlist

    def got_not_interested(self):
        if self.interested:
            self.interested = False
            self.choker.not_interested(self.connector)

    def got_interested(self):
        if not self.interested:
            self.interested = True
            self.choker.interested(self.connector)

    def get_upload_chunk(self, index, begin, length):
        df = self.storage.read(index, begin, length)
        df.addCallback(lambda piece: (index, begin, piece))
        df.addErrback(self._failed_get_upload_chunk)
        return df

    def _failed_get_upload_chunk(self, f):
        log("get_upload_chunk failed", exc_info=f.exc_info())
        self.connector.close()
        return f

    def got_request(self, index, begin, length):
        if not self.interested:
            self.connector.protocol_violation("request when not interested")
            self.connector.close()
            return
        if length > self.max_chunk_length:
            if not self.had_length_error:
                m = ("request length %r exceeds max %r" %
                     (length, self.max_chunk_length))
                self.connector.protocol_violation(m)
                self.had_length_error = True
            #self.connector.close()
            # we could still download...
            if self.connector.uses_fast_extension:
                self.connector.send_reject_request(index, begin, length)
            return
        if len(self.buffer) > MAX_REQUESTS:
            if not self.had_max_requests_error:
                m = "max request limit %d" % MAX_REQUESTS
                self.connector.protocol_violation(m)
                self.had_max_requests_error = True
            if self.connector.uses_fast_extension:
                self.connector.send_reject_request(index, begin, length)
            return
        if index in self.allowed_fast_pieces or not self.connector.choke_sent:
            df = self.get_upload_chunk(index, begin, length)
            df.addCallback(self._got_piece)
            df.addErrback(self.multidownload.errorfunc)
        elif self.connector.uses_fast_extension:
            self.connector.send_reject_request(index, begin, length)
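
    # A chunk read from storage may complete after our choke state has
    # changed, so _got_piece re-checks it before queueing: without the fast
    # extension a late chunk is silently dropped; with it, chunks for pieces
    # outside the allowed-fast set are explicitly rejected.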
    def _got_piece(self, piece_info):
        index, begin, piece = piece_info
        if self.connector.closed:
            return
        if self.choked:
            if not self.connector.uses_fast_extension:
                return
            if index not in self.allowed_fast_pieces:
                self.connector.send_reject_request(index, begin, len(piece))
                return
        self.buffer.append(((index, begin, len(piece)), piece))
        if self.connector.next_upload is None and \
               self.connector.connection.is_flushed():
            self.ratelimiter.queue(self.connector)

    def got_cancel(self, index, begin, length):
        req = (index, begin, length)
        for pos, (r, p) in enumerate(self.buffer):
            if r == req:
                del self.buffer[pos]
                if self.connector.uses_fast_extension:
                    self.connector.send_reject_request(*req)
                break

    def choke(self):
        if not self.choked:
            self.choked = True
            self.connector.send_choke()
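
    # Invoked once our choke has actually been sent.  Under the fast extension
    # a choke does not implicitly cancel the peer's outstanding requests, so
    # every buffered request we will no longer serve is explicitly rejected;
    # requests for allowed-fast pieces are kept and can still be served while
    # choked.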
    def sent_choke(self):
        assert self.choked
        if self.connector.uses_fast_extension:
            b2 = []
            for r in self.buffer:
                ((index, begin, length), piecedata) = r
                if index not in self.allowed_fast_pieces:
                    self.connector.send_reject_request(index, begin, length)
                else:
                    b2.append(r)
            self.buffer = b2
        else:
            del self.buffer[:]

    def unchoke(self, time):
        if self.choked:
            self.choked = False
            self.unchoke_time = time
            self.connector.send_unchoke()

    def has_queries(self):
        return len(self.buffer) > 0

    def get_rate(self):
        return self.measure.get_rate()


if __name__ == "__main__":
    # unit tests for allowed fast set generation.
    n_tests = n_tests_passed = 0
    infohash = "".join(['\xaa'] * 20)  # 20-byte string containing all 0xaa.
    ip = "80.4.4.200"
    expected_list = [1059, 431, 808, 1217, 287, 376, 1188]

    n_tests += 1
    fast_list = _compute_allowed_fast_list(
        infohash, ip, num_fast=7, num_pieces=1313)
    if expected_list != fast_list:
        print ("FAIL!! expected list = %s, but got %s" %
               (str(expected_list), str(fast_list)))
    else:
        n_tests_passed += 1

    n_tests += 1
    expected_list.extend([353, 508])
    fast_list = _compute_allowed_fast_list(
        infohash, ip, num_fast=9, num_pieces=1313)
    if expected_list != fast_list:
        print ("FAIL!! expected list = %s, but got %s" %
               (str(expected_list), str(fast_list)))
    else:
        n_tests_passed += 1

    if n_tests == n_tests_passed:
        print "Success. Passed all %d unit tests." % n_tests
    else:
        print "Passed only %d out of %d unit tests." % (n_tests_passed, n_tests)