1
0

NatCheck.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. # Written by Bram Cohen
  2. # see LICENSE.txt for license information
  3. from cStringIO import StringIO
  4. from socket import error as socketerror
  5. from traceback import print_exc
  6. from BitTornado.BTcrypto import Crypto, CRYPTO_OK
  7. try:
  8. True
  9. except:
  10. True = 1
  11. False = 0
  12. CHECK_PEER_ID_ENCRYPTED = True
  13. protocol_name = 'BitTorrent protocol'
# Handshake layout: header, reserved bytes, download id, sender's peer id,
# followed by [length, message] frames.
  15. class NatCheck:
  16. def __init__(self, resultfunc, downloadid, peerid, ip, port, rawserver,
  17. encrypted = False):
  18. self.resultfunc = resultfunc
  19. self.downloadid = downloadid
  20. self.peerid = peerid
  21. self.ip = ip
  22. self.port = port
  23. self.encrypted = encrypted
  24. self.closed = False
  25. self.buffer = ''
  26. self.read = self._read
  27. self.write = self._write
  28. try:
  29. self.connection = rawserver.start_connection((ip, port), self)
  30. if encrypted:
  31. self._dc = not(CRYPTO_OK and CHECK_PEER_ID_ENCRYPTED)
  32. self.encrypter = Crypto(True, disable_crypto = self._dc)
  33. self.write(self.encrypter.pubkey+self.encrypter.padding())
  34. else:
  35. self.encrypter = None
  36. self.write(chr(len(protocol_name)) + protocol_name +
  37. (chr(0) * 8) + downloadid)
  38. except socketerror:
  39. self.answer(False)
  40. except IOError:
  41. self.answer(False)
  42. self.next_len, self.next_func = 1+len(protocol_name), self.read_header
  43. def answer(self, result):
  44. self.closed = True
  45. try:
  46. self.connection.close()
  47. except AttributeError:
  48. pass
  49. self.resultfunc(result, self.downloadid, self.peerid, self.ip, self.port)
  50. def _read_header(self, s):
  51. if s == chr(len(protocol_name))+protocol_name:
  52. return 8, self.read_options
  53. return None
  54. def read_header(self, s):
  55. if self._read_header(s):
  56. if self.encrypted:
  57. return None
  58. return 8, self.read_options
  59. if not self.encrypted:
  60. return None
  61. self._write_buffer(s)
  62. return self.encrypter.keylength, self.read_crypto_header
  63. ################## ENCRYPTION SUPPORT ######################
  64. def _start_crypto(self):
  65. self.encrypter.setrawaccess(self._read,self._write)
  66. self.write = self.encrypter.write
  67. self.read = self.encrypter.read
  68. if self.buffer:
  69. self.buffer = self.encrypter.decrypt(self.buffer)
  70. def read_crypto_header(self, s):
  71. self.encrypter.received_key(s)
  72. self.encrypter.set_skey(self.downloadid)
  73. cryptmode = '\x00\x00\x00\x02' # full stream encryption
  74. padc = self.encrypter.padding()
  75. self.write( self.encrypter.block3a
  76. + self.encrypter.block3b
  77. + self.encrypter.encrypt(
  78. ('\x00'*8) # VC
  79. + cryptmode # acceptable crypto modes
  80. + tobinary16(len(padc))
  81. + padc # PadC
  82. + '\x00\x00' ) ) # no initial payload data
  83. self._max_search = 520
  84. return 1, self.read_crypto_block4a
  85. def _search_for_pattern(self, s, pat):
  86. p = s.find(pat)
  87. if p < 0:
  88. if len(s) >= len(pat):
  89. self._max_search -= len(s)+1-len(pat)
  90. if self._max_search < 0:
  91. self.close()
  92. return False
  93. self._write_buffer(s[1-len(pat):])
  94. return False
  95. self._write_buffer(s[p+len(pat):])
  96. return True
  97. ### OUTGOING CONNECTION ###
  98. def read_crypto_block4a(self, s):
  99. if not self._search_for_pattern(s,self.encrypter.VC_pattern()):
  100. return -1, self.read_crypto_block4a # wait for more data
  101. if self._dc: # can't or won't go any further
  102. self.answer(True)
  103. return None
  104. self._start_crypto()
  105. return 6, self.read_crypto_block4b
  106. def read_crypto_block4b(self, s):
  107. self.cryptmode = toint(s[:4]) % 4
  108. if self.cryptmode != 2:
  109. return None # unknown encryption
  110. padlen = (ord(s[4])<<8)+ord(s[5])
  111. if padlen > 512:
  112. return None
  113. if padlen:
  114. return padlen, self.read_crypto_pad4
  115. return self.read_crypto_block4done()
  116. def read_crypto_pad4(self, s):
  117. # discard data
  118. return self.read_crypto_block4done()
  119. def read_crypto_block4done(self):
  120. if DEBUG:
  121. self._log_start()
  122. if self.cryptmode == 1: # only handshake encryption
  123. if not self.buffer: # oops; check for exceptions to this
  124. return None
  125. self._end_crypto()
  126. self.write(chr(len(protocol_name)) + protocol_name +
  127. option_pattern + self.Encoder.download_id)
  128. return 1+len(protocol_name), self.read_encrypted_header
  129. ### START PROTOCOL OVER ENCRYPTED CONNECTION ###
  130. def read_encrypted_header(self, s):
  131. return self._read_header(s)
  132. ################################################
  133. def read_options(self, s):
  134. return 20, self.read_download_id
  135. def read_download_id(self, s):
  136. if s != self.downloadid:
  137. return None
  138. return 20, self.read_peer_id
  139. def read_peer_id(self, s):
  140. if s != self.peerid:
  141. return None
  142. self.answer(True)
  143. return None
  144. def _write(self, message):
  145. if not self.closed:
  146. self.connection.write(message)
  147. def data_came_in(self, connection, s):
  148. self.read(s)
  149. def _write_buffer(self, s):
  150. self.buffer = s+self.buffer
  151. def _read(self, s):
  152. self.buffer += s
  153. while True:
  154. if self.closed:
  155. return
  156. # self.next_len = # of characters function expects
  157. # or 0 = all characters in the buffer
  158. # or -1 = wait for next read, then all characters in the buffer
  159. # not compatible w/ keepalives, switch out after all negotiation complete
  160. if self.next_len <= 0:
  161. m = self.buffer
  162. self.buffer = ''
  163. elif len(self.buffer) >= self.next_len:
  164. m = self.buffer[:self.next_len]
  165. self.buffer = self.buffer[self.next_len:]
  166. else:
  167. return
  168. try:
  169. x = self.next_func(m)
  170. except:
  171. if not self.closed:
  172. self.answer(False)
  173. return
  174. if x is None:
  175. if not self.closed:
  176. self.answer(False)
  177. return
  178. self.next_len, self.next_func = x
  179. if self.next_len < 0: # already checked buffer
  180. return # wait for additional data
  181. if self.bufferlen is not None:
  182. self._read2('')
  183. return
  184. def connection_lost(self, connection):
  185. if not self.closed:
  186. self.closed = True
  187. self.resultfunc(False, self.downloadid, self.peerid, self.ip, self.port)
  188. def connection_flushed(self, connection):
  189. pass