#!/usr/bin/env python

import os
import pwd
import sys
import asyncore
import socket
import errno
from struct import unpack

def bytestream (value):
	"""Return value as one unbroken run of upper-case, two-digit hex codes (one per character)."""
	pairs = []
	for char in value:
		pairs.append('%02X' % ord(char))
	return ''.join(pairs)

def dump (value):
	"""Return value as upper-case hex, with a space after every second character's code.

	For example five characters are rendered as 'XXXX XXXX XX'.
	"""
	digits = ''.join('%02X' % ord(char) for char in value)
	# two characters of input are four hex digits of output
	groups = [digits[start:start+4] for start in range(0,len(digits),4)]
	return ' '.join(groups)


class BGPHandler(asyncore.dispatcher_with_send):
	"""Handler for one accepted connection.

	The first readable event is treated as the peer's OPEN, which is echoed back
	with one byte changed. Every later message is decoded (when a decoder exists
	for its type) and compared with the announcements expected for the current
	step; an unexpected announcement terminates the program with status 1.
	"""

	# a complete KEEPALIVE message: 16 bytes of 0xFF, length 0x0013 (19), type 4
	keepalive = chr(0xFF)*16 + chr(0x0) + chr(0x13) + chr(0x4)

	# message type (byte 18 of the header) to printable name
	_name = {
		chr(1) : 'OPEN',
		chr(2) : 'UPDATE',
		chr(3) : 'NOTIFICATION',
		chr(4) : 'KEEPALIVE',
	}

	def kind (self,header):
		"""Return the message type byte (offset 18) of a 19 byte header."""
		return header[18]

	def isupdate (self,header):
		"""Tell whether the header belongs to an UPDATE (type 2)."""
		return header[18] == chr(2)

	def isnotification (self,header):
		"""Tell whether the header belongs to a NOTIFICATION (type 3)."""
		# type 4 is KEEPALIVE (see _name); NOTIFICATION is type 3
		return header[18] == chr(3)

	def name (self,header):
		"""Return the printable name of the message type found in header."""
		return self._name.get(header[18],'SOME WEIRD RFC PACKET')

	def _prefixes (self,label,octets,raw):
		"""Yield ('<label>:<a.b.c.d>/<mask>', raw) for the prefix encoded in octets.

		octets is a list of integers which is consumed. Anything longer than one
		mask byte plus four address bytes is not decoded: a single ('', raw) pair
		is yielded instead and decoding stops.
		"""
		while octets:
			if len(octets) > 5:
				yield '', raw
				break
			mask = octets.pop(0)
			address = [0,0,0,0]
			for index in range(4):
				# only as many address bytes as the mask requires are on the wire
				if index*8 >= mask: break
				address[index] = octets.pop(0)
			yield '%s:%s/%s' % (label,'.'.join(str(_) for _ in address),mask), raw

	def routes (self,body):
		"""Decode the body of an UPDATE, yielding (parsed, raw) string pairs.

		The body starts with a two byte length followed by the withdrawn prefixes,
		then a second two byte length followed by a section which is skipped
		(the path attributes of a BGP UPDATE), the rest being the announced prefixes.
		raw is always 'raw:' followed by the hex of the whole body.
		"""
		len_w = unpack('!H',body[0:2])[0]
		withdrawn = [ord(_) for _ in body[2:2+len_w]]
		len_a = unpack('!H',body[2+len_w:2+len_w+2])[0]
		announced = [ord(_) for _ in body[2+len_w + 2+len_a:]]

		raw = 'raw:%s' % bytestream(body)

		if not withdrawn and not announced:
			if len(body) == 11:
				yield 'eor:%d:%d' % (ord(body[-2]),ord(body[-1])),raw
			else:  # undecoded MP route
				yield 'mp:',raw

		for pair in self._prefixes('withdraw',withdrawn,raw):
			yield pair

		for pair in self._prefixes('announce',announced,raw):
			yield pair

	def notification (self,body):
		"""Decode the body of a NOTIFICATION from its first two bytes, yielding one (parsed, raw) pair."""
		yield 'notification:%d,%d' % (ord(body[0]),ord(body[1])), 'raw:%s' % bytestream(body)

	def announce (self,*args):
		"""Print the peer ip and port followed by the arguments (space separated)."""
		text = ' '.join(str(_) for _ in args) if len(args) > 1 else args[0]
		# a single parenthesised value prints identically with the print statement and function
		print('%s %s %s' % (self.ip,self.port,text))

	def setup (self,ip,port,letter,messages,exit):
		"""Configure the handler for one peer and return it.

		messages are rules of the form '<step>:<announcement>'; rules sharing a step
		may be matched in any order. exit tells whether the program should terminate
		successfully once every expected announcement was seen.
		"""
		self.ip = ip
		self.port = port
		self.exit = exit
		self.handle_read = self.handle_open
		self.sequence = {}
		self.raw = False
		self.letter = letter
		for rule in messages:
			sequence,announcement = rule.split(':',1)
			if announcement.startswith('raw:'):
				raw = announcement[4:].replace(':','')
				# drop the 19 byte header (38 hex digits) when the dump starts with the marker
				stream = raw[38:] if raw.startswith('F'*16) else raw
				announcement = 'raw:' + stream
				self.raw = True
			self.sequence.setdefault(sequence,[]).append(announcement)
		self.update_sequence()
		return self

	@staticmethod
	def _sequence_order (key):
		"""Sort key placing numeric steps first, in numeric (not lexical) order."""
		if key.isdigit():
			return (0,int(key),key)
		return (1,0,key)

	def update_sequence (self):
		"""Move on to the lowest remaining step.

		Sets self.step and self.messages and returns True, or returns False when
		no step is left.
		"""
		if not self.sequence:
			return False
		# steps are strings: a plain sort would place '10' before '2'
		key = min(self.sequence,key=self._sequence_order)
		self.step = key
		self.messages = self.sequence.pop(key)
		return True

	def _read_exactly (self,size):
		"""Read size bytes from the connection, or return None if it closes first."""
		data = b''
		while len(data) < size:
			try:
				chunk = self.recv(size-len(data))
			except socket.error as e:
				if e.args[0] in (errno.EWOULDBLOCK,errno.EAGAIN):
					continue
				raise
			if not chunk:
				# an empty read means the TCP session is gone.
				return None
			data += chunk
		return data

	def read_message (self):
		"""Read one message and return (header, body), or (None, None) if the connection closed.

		The header is 19 bytes long; bytes 16-17 hold the total message length.
		"""
		header = self._read_exactly(19)
		if header is None:
			return None,None

		# a length below 19 yields an empty body instead of waiting forever
		length = unpack('!H',header[16:18])[0] - 19

		body = self._read_exactly(length)
		if body is None:
			return None,None

		return header,body

	def handle_open (self):
		"""Answer the first message received and switch to normal message handling."""
		# reply with a IBGP response with the same capability (just changing routerID)
		header,body = self.read_message()

		if header is None:
			self.announce('connection closed')
			self.close()
			return

		routerid = chr((ord(body[8])+1) & 0xFF)
		o = header+body[:8]+routerid+body[9:]
		self.send(o)
		self.send(self.keepalive)
		self.handle_read = self.handle_keepalive

	def handle_keepalive (self):
		"""Read one message, check what it announces against the expectations and reply with a KEEPALIVE."""
		header,body = self.read_message()

		if header is None:
			self.announce('connection closed')
			self.close()
			return

		parser = self._decoder.get(self.kind(header),None)

		if parser:
			for parsed,raw in parser(self,body):
				announcement = raw if self.raw else parsed

				if announcement.startswith('eor:'):  # skip EOR
					self.announce('skip eor',announcement)
					continue

				if announcement.startswith('mp:'):  # skip unparsed MP
					self.announce('skip announcement for body :',dump(body))
					continue

				if announcement in self.messages:
					self.messages.remove(announcement)
					self.announce('expected announcement received (%s%s):' % (self.letter,self.step),announcement)
				else:
					# NOTE(review): raw always starts with 'raw:' so this test is always true
					if raw:
						self.announce('\n','unexpected announcement (%s%s):' % (self.letter,self.step),'%s %s' % (bytestream(header),bytestream(body)))
					else:
						self.announce('\n','unexpected announcement:',announcement)

					self.announce('valid options were:',', '.join(self.messages),'\n')
					sys.exit(1)

				if not self.messages:
					if not self.update_sequence():
						if self.exit:
							self.announce('successful')
							sys.exit(0)

		self.send(self.keepalive)

	# message type to decoder; stored as plain functions and called as parser(self,body)
	_decoder= {
		chr(2) : routes,
		chr(3) : notification,
	}

class BGPServer (asyncore.dispatcher):
	"""Listening socket handing each new connection its own group of expected messages.

	A message starting with a letter belongs to the group named by that letter
	(upper-cased); any other message belongs to group 'A'. Groups are given to
	connections in alphabetical order, one group per connection.
	"""

	def __init__ (self,host,port,messages):
		"""Listen on (host,port) and sort messages into their groups."""
		asyncore.dispatcher.__init__(self)
		self.create_socket(socket.AF_INET,socket.SOCK_STREAM)
		self.set_reuse_addr()
		self.bind((host,port))
		self.listen(5)

		self.messages = {}

		for message in messages:
			# slicing rather than indexing so that an empty message can not raise
			if message[:1].isalpha():
				index,content = message[:1].upper(), message[1:]
			else:
				index,content = 'A',message
			self.messages.setdefault(index,[]).append(content)

	def handle_accept (self):
		"""Accept a connection and attach a BGPHandler using the next unused group.

		Exits with status 1 when a connection arrives but no group is left.
		"""
		# accept first: when no connection is available no test data must be consumed
		pair = self.accept()
		if pair is None:
			return
		sock,addr = pair

		letter = None
		messages = None
		for number in range(ord('A'),ord('Z')+1):
			candidate = chr(number)
			if candidate in self.messages:
				letter = candidate
				messages = self.messages.pop(candidate)
				break

		if not messages:
			print('we used all the test data available, can not handle this new connection')
			sys.exit(1)

		print('using :\n%s \n' % '\n'.join(messages))

		# the handler serving the last group is the one allowed to end the program
		last = not self.messages

		BGPHandler(sock).setup(*addr,letter=letter,messages=messages,exit=last)

def drop ():
	"""Give up root privileges by switching to the 'nobody' user.

	Nothing is done when neither the user id nor the group id is 0. Only the
	ids which are 0 are changed, the group first as changing the user id first
	would remove the right to change the group.
	"""
	uid = os.getuid()
	gid = os.getgid()

	if uid and gid:
		return

	for name in ['nobody',]:
		try:
			user = pwd.getpwnam(name)
		except KeyError:
			continue
		nuid = int(user.pw_uid)
		# the group id comes from pw_gid, not pw_uid
		ngid = int(user.pw_gid)
		break
	else:
		# no candidate user exists: report it rather than fail on unset ids
		sys.stderr.write('could not find an unprivileged user, privileges not dropped\n')
		return

	if not gid:
		os.setgid(ngid)
	if not uid:
		os.setuid(nuid)

def main ():
	"""Read the expected messages from the file named on the command line and serve them.

	The listening port is taken from the 'exabgp.tcp.port' environment variable
	(default 179). Exits with status 1 when no file is given, when the file can
	not be read or when the server fails with a socket error.
	"""
	if len(sys.argv) <= 1:
		print('a list of expected route announcement/withdrawl in the format <number>:announce:<ipv4-route> <number>:withdraw:<ipv4-route> <number>:raw:<exabgp hex dump : separated>')
		print('for example: %s 1:announce:10.0.0.0/8 1:announce:192.0.2.0/24 2:withdraw:10.0.0.0/8 ' % sys.argv[0])
		print('routes with the same <number> can arrive in any order')
		sys.exit(1)

	try:
		with open(sys.argv[1]) as content:
			# blank lines and any line containing a '#' are ignored
			messages = [_.strip() for _ in content.readlines() if _.strip() and '#' not in _]
	except IOError:
		print('coud not open file %s' % sys.argv[1])
		# without the file there is nothing to check (and messages is not set)
		sys.exit(1)

	port = os.environ.get('exabgp.tcp.port','179')

	try:
		BGPServer('localhost',int(port),messages)
		drop()
		asyncore.loop()
	except socket.error as e:
		if e.errno == errno.EACCES:
			print("failure: could not bind to port %s - most likely not run as root" % port)
		elif e.errno == errno.EADDRINUSE:
			print("failure: could not bind to port %s - port already in use" % port)
		else:
			print("failure %s" % str(e))
		sys.exit(1)

if __name__ == '__main__':
	main()
