| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219 |
- # Written by Bram Cohen
- # see LICENSE.txt for license information
- from cStringIO import StringIO
- from socket import error as socketerror
- from traceback import print_exc
- from BitTornado.BTcrypto import Crypto, CRYPTO_OK
# Compatibility shim: evaluate the name True and, if that raises, bind integer
# stand-ins for True and False (presumably for interpreters that predate the
# boolean builtins -- TODO confirm the minimum supported version).
try:
    True
except:
    True = 1
    False = 0

# Combined with CRYPTO_OK to decide how far an encrypted check goes.  When
# either is false, NatCheck creates its Crypto object with disable_crypto set
# and reports success as soon as the peer's VC pattern has been found, instead
# of continuing through the encrypted handshake to the peer id comparison.
CHECK_PEER_ID_ENCRYPTED = True

# Protocol identifier sent, prefixed by a single length character, at the
# start of the handshake and expected back from the peer.
protocol_name = 'BitTorrent protocol'
- # header, reserved, download id, my id, [length, message]
class NatCheck:
    """Check whether a peer can be reached by an outgoing connection.

    On construction the object opens a connection to ``(ip, port)`` through
    ``rawserver`` and starts a handshake, either in the clear or (when
    ``encrypted`` is true) through the encrypted-stream negotiation.  The
    outcome is reported through
    ``resultfunc(result, downloadid, peerid, ip, port)`` where ``result`` is
    true only if the peer answered with the expected download id and peer id
    (or, when crypto is disabled, as soon as its VC pattern was found).

    The instance is also the handler object passed to
    ``rawserver.start_connection`` and therefore provides ``data_came_in``,
    ``connection_lost`` and ``connection_flushed``.

    Incoming data is consumed by a small state machine driven by ``_read``:
    ``self.next_len`` is the number of characters the next handler expects
    and ``self.next_func`` is that handler.  A handler returns either a new
    ``(next_len, next_func)`` pair or ``None`` to abort the check.
    """

    def __init__(self, resultfunc, downloadid, peerid, ip, port, rawserver,
                 encrypted = False):
        """Open the connection and send the first handshake message.

        resultfunc -- callback invoked as
                      ``resultfunc(result, downloadid, peerid, ip, port)``
        downloadid -- download id sent to the peer and expected back
        peerid     -- peer id the remote side is expected to announce
        ip, port   -- address to connect to
        rawserver  -- object providing ``start_connection(addr, handler)``
        encrypted  -- if true, start with the encrypted-stream negotiation
                      instead of the plaintext handshake

        A ``socket.error`` or ``IOError`` raised while connecting or writing
        is reported as a failed check rather than propagated.
        """
        self.resultfunc = resultfunc
        self.downloadid = downloadid
        self.peerid = peerid
        self.ip = ip
        self.port = port
        self.encrypted = encrypted
        self.closed = False
        self.buffer = ''
        self.read = self._read
        self.write = self._write
        # Defined up front so later attribute reads never depend on how far
        # the try block below got.
        self.connection = None
        self._dc = False
        try:
            self.connection = rawserver.start_connection((ip, port), self)
            if encrypted:
                # _dc: "disable crypto" -- only look for the peer's VC
                # pattern, do not run the full encrypted handshake.
                self._dc = not(CRYPTO_OK and CHECK_PEER_ID_ENCRYPTED)
                self.encrypter = Crypto(True, disable_crypto = self._dc)
                self.write(self.encrypter.pubkey+self.encrypter.padding())
            else:
                self.encrypter = None
                self.write(chr(len(protocol_name)) + protocol_name +
                           (chr(0) * 8) + downloadid)
        except (socketerror, IOError):
            self.answer(False)
        self.next_len, self.next_func = 1+len(protocol_name), self.read_header

    def answer(self, result):
        """Close the connection and report ``result`` through resultfunc."""
        self.closed = True
        try:
            self.connection.close()
        except AttributeError:
            # No usable connection object (e.g. start_connection failed).
            pass
        self.resultfunc(result, self.downloadid, self.peerid, self.ip, self.port)

    def _read_header(self, s):
        """Return the next state if ``s`` is the plaintext protocol header,
        otherwise ``None``."""
        if s == chr(len(protocol_name))+protocol_name:
            return 8, self.read_options
        return None

    def read_header(self, s):
        """Handle the first ``1+len(protocol_name)`` characters from the peer.

        A plaintext header is accepted only for an unencrypted check.  For
        an encrypted check anything else is treated as the start of the
        peer's key: the characters are pushed back into the buffer and the
        state machine waits for ``encrypter.keylength`` characters.
        """
        if self._read_header(s):
            if self.encrypted:
                return None
            return 8, self.read_options
        if not self.encrypted:
            return None
        self._write_buffer(s)
        return self.encrypter.keylength, self.read_crypto_header

    ################## ENCRYPTION SUPPORT ######################

    def _tobinary16(self, n):
        """Encode the low 16 bits of ``n`` as two characters, high byte first."""
        return chr((n >> 8) & 0xFF) + chr(n & 0xFF)

    def _toint(self, s):
        """Decode the characters of ``s`` as one integer, first character
        most significant."""
        value = 0
        for c in s:
            value = (value << 8) + ord(c)
        return value

    def _start_crypto(self):
        """Route all further reads and writes through the encrypter and
        decrypt whatever is already buffered."""
        self.encrypter.setrawaccess(self._read,self._write)
        self.write = self.encrypter.write
        self.read = self.encrypter.read
        if self.buffer:
            self.buffer = self.encrypter.decrypt(self.buffer)

    def read_crypto_header(self, s):
        """Take the peer's key ``s`` and send the next negotiation block.

        Also sets the search budget (``_max_search``) used while scanning
        the peer's reply for the VC pattern.
        """
        self.encrypter.received_key(s)
        self.encrypter.set_skey(self.downloadid)
        cryptmode = '\x00\x00\x00\x02'      # full stream encryption
        padc = self.encrypter.padding()
        self.write( self.encrypter.block3a
                  + self.encrypter.block3b
                  + self.encrypter.encrypt(
                        ('\x00'*8)            # VC
                      + cryptmode             # acceptable crypto modes
                      + self._tobinary16(len(padc))
                      + padc                  # PadC
                      + '\x00\x00' ) )        # no initial payload data
        self._max_search = 520
        return 1, self.read_crypto_block4a

    def _search_for_pattern(self, s, pat):
        """Look for ``pat`` in ``s``; return true when it was found.

        On a hit, everything after the pattern is pushed back into the
        buffer.  On a miss, the search budget is reduced by the number of
        positions examined; if the budget is exhausted the check is answered
        as failed, otherwise the last ``len(pat)-1`` characters are pushed
        back so a pattern split across reads can still be matched.
        """
        p = s.find(pat)
        if p < 0:
            if len(s) >= len(pat):
                self._max_search -= len(s)+1-len(pat)
            if self._max_search < 0:
                # Pattern not found within the allowed window: give up.
                self.answer(False)
                return False
            self._write_buffer(s[1-len(pat):])
            return False
        self._write_buffer(s[p+len(pat):])
        return True

    ### OUTGOING CONNECTION ###

    def read_crypto_block4a(self, s):
        """Scan the peer's reply for the encrypter's VC pattern.

        Returns ``-1`` as the length to wait for more data until the pattern
        is found (or the search fails).  With crypto disabled the check
        succeeds as soon as the pattern is seen.
        """
        if not self._search_for_pattern(s,self.encrypter.VC_pattern()):
            return -1, self.read_crypto_block4a     # wait for more data
        if self._dc:                        # can't or won't go any further
            self.answer(True)
            return None
        self._start_crypto()
        return 6, self.read_crypto_block4b

    def read_crypto_block4b(self, s):
        """Parse the selected crypto mode (4 characters) and the pad length
        (2 characters); only mode 2 and pads up to 512 are accepted."""
        self.cryptmode = self._toint(s[:4]) % 4
        if self.cryptmode != 2:
            return None                     # unknown encryption
        padlen = (ord(s[4])<<8)+ord(s[5])
        if padlen > 512:
            return None
        if padlen:
            return padlen, self.read_crypto_pad4
        return self.read_crypto_block4done()

    def read_crypto_pad4(self, s):
        """Discard the peer's padding and finish the negotiation."""
        # discard data
        return self.read_crypto_block4done()

    def read_crypto_block4done(self):
        """Send the regular handshake over the encrypted stream.

        Only reached with ``cryptmode == 2`` (read_crypto_block4b rejects
        every other mode).  The message has the same layout as the plaintext
        handshake sent from ``__init__``.
        """
        self.write(chr(len(protocol_name)) + protocol_name +
                   (chr(0) * 8) + self.downloadid)
        return 1+len(protocol_name), self.read_encrypted_header

    ### START PROTOCOL OVER ENCRYPTED CONNECTION ###

    def read_encrypted_header(self, s):
        """Check the protocol header received over the encrypted stream."""
        return self._read_header(s)

    ################################################

    def read_options(self, s):
        """Skip the 8 reserved characters; the download id follows."""
        return 20, self.read_download_id

    def read_download_id(self, s):
        """Abort unless the peer echoed our download id."""
        if s != self.downloadid:
            return None
        return 20, self.read_peer_id

    def read_peer_id(self, s):
        """Report success if the peer id matches, failure otherwise."""
        if s != self.peerid:
            return None
        self.answer(True)
        return None

    def _write(self, message):
        """Write ``message`` to the connection unless the check is closed."""
        if not self.closed:
            self.connection.write(message)

    def data_came_in(self, connection, s):
        """Connection callback: feed received data to the current reader."""
        self.read(s)

    def _write_buffer(self, s):
        """Push ``s`` back onto the front of the input buffer."""
        self.buffer = s+self.buffer

    def _read(self, s):
        """Append ``s`` to the buffer and run handlers while enough data is
        available.

        A handler that raises or returns ``None`` ends the check as failed
        (unless a result was already reported).
        """
        self.buffer += s
        while True:
            if self.closed:
                return
            # self.next_len = # of characters function expects
            # or 0 = all characters in the buffer
            # or -1 = wait for next read, then all characters in the buffer
            if self.next_len <= 0:
                m = self.buffer
                self.buffer = ''
            elif len(self.buffer) >= self.next_len:
                m = self.buffer[:self.next_len]
                self.buffer = self.buffer[self.next_len:]
            else:
                return
            try:
                x = self.next_func(m)
            except Exception:
                # Any handler failure counts as a failed check.
                x = None
            if x is None:
                if not self.closed:
                    self.answer(False)
                return
            self.next_len, self.next_func = x
            if self.next_len < 0:           # already checked buffer
                return                      # wait for additional data

    def connection_lost(self, connection):
        """Connection callback: report failure if no result was sent yet."""
        if not self.closed:
            self.closed = True
            self.resultfunc(False, self.downloadid, self.peerid, self.ip, self.port)

    def connection_flushed(self, connection):
        """Connection callback: nothing to do when the output is flushed."""
        pass
|