| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- # Written by Bram Cohen
- # see LICENSE.txt for license information
- from cStringIO import StringIO
- from binascii import b2a_hex
- from socket import error as socketerror
- from urllib import quote
- from traceback import print_exc
- import Connecter
- try:
- True
- except:
- True = 1
- False = 0
# When true, StreamCheck also prints the download and peer IDs it sees.
DEBUG = False
# Protocol name that StreamCheck expects at the start of each stream.
protocol_name = 'BitTorrent protocol'
# Eight NUL bytes; presumably the default reserved/options field of the
# handshake (not referenced in this part of the module -- confirm).
option_pattern = '\x00' * 8
- def toint(s):
- return long(b2a_hex(s), 16)
def tobinary(i):
    """Encode *i* as a 4-byte big-endian string.

    This is the inverse of toint() for 4-byte values. Only the lower three
    bytes are masked. The top byte goes to chr() unmasked, so in Python 2 a
    value of 2**32 or more (or a negative value) makes chr() raise ValueError.
    """
    octets = [chr(i >> 24)]
    for shift in (16, 8, 0):
        octets.append(chr((i >> shift) & 0xFF))
    return ''.join(octets)
# Upper-case hex digits, and a 256-entry lookup table mapping each byte
# value to its two-character hex form ('00' .. 'FF'), used by tohex().
hexchars = '0123456789ABCDEF'
# Shift/mask instead of "/ 16": the old form depended on Python 2 classic
# integer division (and xrange) and breaks under true division.
hexmap = [hexchars[i >> 4] + hexchars[i & 0x0F] for i in range(256)]
def tohex(s):
    """Return *s* as upper-case hex, two characters per byte (via hexmap)."""
    return ''.join([hexmap[ord(ch)] for ch in s])
def make_readable(s):
    """Format identifier *s* for log output.

    Behaviour depends on *s*:
    - empty or false: returns ''.
    - urllib.quote() would percent-escape any character: returns the hex
      form from tohex().
    - otherwise: returns *s* wrapped in double quotes.
    """
    if s:
        if '%' not in quote(s):
            return '"' + s + '"'
        return tohex(s)
    return ''
-
# Stream layout checked by StreamCheck:
#   header, reserved, download id, my id, then repeated [length, message]
# Module-wide counter handing out sequential StreamCheck instance numbers.
streamno = 0
- class StreamCheck:
- def __init__(self):
- global streamno
- self.no = streamno
- streamno += 1
- self.buffer = StringIO()
- self.next_len, self.next_func = 1, self.read_header_len
- def read_header_len(self, s):
- if ord(s) != len(protocol_name):
- print self.no, 'BAD HEADER LENGTH'
- return len(protocol_name), self.read_header
- def read_header(self, s):
- if s != protocol_name:
- print self.no, 'BAD HEADER'
- return 8, self.read_reserved
- def read_reserved(self, s):
- return 20, self.read_download_id
- def read_download_id(self, s):
- if DEBUG:
- print self.no, 'download ID ' + tohex(s)
- return 20, self.read_peer_id
- def read_peer_id(self, s):
- if DEBUG:
- print self.no, 'peer ID' + make_readable(s)
- return 4, self.read_len
- def read_len(self, s):
- l = toint(s)
- if l > 2 ** 23:
- print self.no, 'BAD LENGTH: '+str(l)+' ('+s+')'
- return l, self.read_message
- def read_message(self, s):
- if not s:
- return 4, self.read_len
- m = s[0]
- if ord(m) > 8:
- print self.no, 'BAD MESSAGE: '+str(ord(m))
- if m == Connecter.REQUEST:
- if len(s) != 13:
- print self.no, 'BAD REQUEST SIZE: '+str(len(s))
- return 4, self.read_len
- index = toint(s[1:5])
- begin = toint(s[5:9])
- length = toint(s[9:])
- print self.no, 'Request: '+str(index)+': '+str(begin)+'-'+str(begin)+'+'+str(length)
- elif m == Connecter.CANCEL:
- if len(s) != 13:
- print self.no, 'BAD CANCEL SIZE: '+str(len(s))
- return 4, self.read_len
- index = toint(s[1:5])
- begin = toint(s[5:9])
- length = toint(s[9:])
- print self.no, 'Cancel: '+str(index)+': '+str(begin)+'-'+str(begin)+'+'+str(length)
- elif m == Connecter.PIECE:
- index = toint(s[1:5])
- begin = toint(s[5:9])
- length = len(s)-9
- print self.no, 'Piece: '+str(index)+': '+str(begin)+'-'+str(begin)+'+'+str(length)
- else:
- print self.no, 'Message '+str(ord(m))+' (length '+str(len(s))+')'
- return 4, self.read_len
- def write(self, s):
- while True:
- i = self.next_len - self.buffer.tell()
- if i > len(s):
- self.buffer.write(s)
- return
- self.buffer.write(s[:i])
- s = s[i:]
- m = self.buffer.getvalue()
- self.buffer.reset()
- self.buffer.truncate()
- x = self.next_func(m)
- self.next_len, self.next_func = x
|