2019-01-14 08:14:06 +01:00

207 lines
6.9 KiB
Python
Executable File

import io
import socket
import struct
from collections import namedtuple
from collections import OrderedDict
from .compat import iteritems
from .exception import DeserializationException
class Transport(object):
    """Length-prefixed packet framing on top of a connected stream socket.

    Every packet on the wire is preceded by a 4-byte unsigned integer in
    network byte order (``HEADER_LENGTH`` bytes) holding the payload length.
    """

    HEADER_LENGTH = 4
    # Not referenced by any method of this class; the limit is not enforced
    # here when sending or receiving.
    MAX_SEGMENT = 512 * 1024

    def __init__(self, sock):
        """Wrap the given socket object.

        :param sock: object providing ``sendall``, ``recv``, ``shutdown``
            and ``close`` (e.g. a connected ``socket.socket``)
        """
        self.socket = sock

    def send(self, packet):
        """Send ``packet`` (bytes) preceded by its length header."""
        self.socket.sendall(struct.pack("!I", len(packet)) + packet)

    def receive(self):
        """Read one length-prefixed packet and return its payload as bytes.

        :raises socket.error: if the peer closes the connection before the
            complete packet has been read
        """
        raw_length = self._recvall(self.HEADER_LENGTH)
        length, = struct.unpack("!I", raw_length)
        return self._recvall(length)

    def close(self):
        """Shut down both directions of the socket and close it.

        The socket is closed even if ``shutdown`` raises (for instance
        because the connection is already gone); the exception raised by
        ``shutdown`` still propagates to the caller.
        """
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        finally:
            self.socket.close()

    def _recvall(self, count):
        """Read exactly ``count`` bytes from the socket.

        :raises socket.error: if the connection is closed before ``count``
            bytes were received
        """
        # Collect the chunks and join once instead of repeatedly
        # concatenating bytes objects, which copies the buffer every round.
        chunks = []
        remaining = count
        while remaining > 0:
            buf = self.socket.recv(remaining)
            if not buf:
                raise socket.error('Connection closed')
            chunks.append(buf)
            remaining -= len(buf)
        return b"".join(chunks)
class Packet(object):
    """Builders and a parser for the packet layer of the protocol.

    A packet starts with a single type byte. Named packets carry a
    one-byte name length followed by the name; an optional message
    follows the header.
    """

    CMD_REQUEST = 0       # Named request message
    CMD_RESPONSE = 1      # Unnamed response message for a request
    CMD_UNKNOWN = 2       # Unnamed response if requested command is unknown
    EVENT_REGISTER = 3    # Named event registration request
    EVENT_UNREGISTER = 4  # Named event de-registration request
    EVENT_CONFIRM = 5     # Unnamed confirmation for event (de-)registration
    EVENT_UNKNOWN = 6     # Unnamed response if event (de-)registration failed
    EVENT = 7             # Named event message

    # Result of parse() for every packet type except EVENT.
    ParsedPacket = namedtuple("ParsedPacket", ["response_type", "payload"])

    # Result of parse() for EVENT packets; additionally carries the name.
    ParsedEventPacket = namedtuple(
        "ParsedEventPacket", ["response_type", "event_type", "payload"]
    )

    @classmethod
    def _named_request(cls, request_type, request, message=None):
        """Build a named packet.

        :param request_type: packet type byte
        :param request: name (text), encoded as UTF-8 on the wire
        :param message: optional already-serialized message (bytes) that is
            appended after the header
        :return: the packet as bytes
        """
        encoded_name = request.encode("UTF-8")
        header = struct.pack("!BB", request_type, len(encoded_name))
        if message is None:
            return header + encoded_name
        return header + encoded_name + message

    @classmethod
    def request(cls, command, message=None):
        """Build a CMD_REQUEST packet for ``command``."""
        return cls._named_request(cls.CMD_REQUEST, command, message)

    @classmethod
    def register_event(cls, event_type):
        """Build an EVENT_REGISTER packet for ``event_type``."""
        return cls._named_request(cls.EVENT_REGISTER, event_type)

    @classmethod
    def unregister_event(cls, event_type):
        """Build an EVENT_UNREGISTER packet for ``event_type``."""
        return cls._named_request(cls.EVENT_UNREGISTER, event_type)

    @classmethod
    def parse(cls, packet):
        """Split a received packet into its type and remaining payload.

        :param packet: raw packet bytes
        :return: ``ParsedEventPacket`` for EVENT packets (the event name is
            returned as raw bytes), ``ParsedPacket`` otherwise; in both
            cases ``payload`` is the stream positioned after the header
        """
        stream = FiniteStream(packet)
        (response_type,) = struct.unpack("!B", stream.read(1))

        if response_type != cls.EVENT:
            return cls.ParsedPacket(response_type, stream)

        (name_length,) = struct.unpack("!B", stream.read(1))
        event_type = stream.read(name_length)
        return cls.ParsedEventPacket(response_type, event_type, stream)
class Message(object):
    """Serializer/deserializer for the message layer of the protocol.

    A message is a flat sequence of typed elements describing nested
    sections, key/value pairs and lists. Element names are prefixed with a
    one-byte length, values with a two-byte length in network byte order.
    """

    SECTION_START = 1   # Begin a new section having a name
    SECTION_END = 2     # End a previously started section
    KEY_VALUE = 3       # Define a value for a named key in the section
    LIST_START = 4      # Begin a named list for list items
    LIST_ITEM = 5       # Define an unnamed item value in the current list
    LIST_END = 6        # End a previously started list

    @classmethod
    def serialize(cls, message):
        """Encode a (possibly nested) dict into its wire representation.

        ``dict`` values become sections, ``list`` values become lists and
        every other value becomes a key/value element. Values that are not
        ``bytes`` are converted with ``str(value).encode("UTF-8")``.

        :param message: dict to serialize
        :return: the encoded message as bytes
        """
        def encode_named_type(marker, name):
            name = name.encode("UTF-8")
            return struct.pack("!BB", marker, len(name)) + name

        def encode_blob(value):
            if not isinstance(value, bytes):
                value = str(value).encode("UTF-8")
            return struct.pack("!H", len(value)) + value

        def serialize_list(lst):
            item_marker = struct.pack("!B", cls.LIST_ITEM)
            return b"".join(item_marker + encode_blob(item) for item in lst)

        def serialize_dict(d):
            # Gather the encoded pieces and join once; repeated bytes
            # concatenation copies the growing buffer on every element.
            parts = []
            for key, value in iteritems(d):
                if isinstance(value, dict):
                    parts.append(encode_named_type(cls.SECTION_START, key))
                    parts.append(serialize_dict(value))
                    parts.append(struct.pack("!B", cls.SECTION_END))
                elif isinstance(value, list):
                    parts.append(encode_named_type(cls.LIST_START, key))
                    parts.append(serialize_list(value))
                    parts.append(struct.pack("!B", cls.LIST_END))
                else:
                    parts.append(encode_named_type(cls.KEY_VALUE, key))
                    parts.append(encode_blob(value))
            return b"".join(parts)

        return serialize_dict(message)

    @classmethod
    def deserialize(cls, stream):
        """Decode a message from ``stream`` into an ``OrderedDict``.

        Element names are decoded as UTF-8 text; key/value and list item
        values are returned as raw bytes.

        :param stream: readable binary stream providing ``read``, ``tell``
            and ``has_more``
        :return: ``OrderedDict`` mirroring the section structure
        :raises DeserializationException: on an unterminated list, an
            unbalanced section start/end, or an element type that is not
            valid at the current position
        """
        def decode_named_type(stream):
            length, = struct.unpack("!B", stream.read(1))
            return stream.read(length).decode("UTF-8")

        def decode_blob(stream):
            length, = struct.unpack("!H", stream.read(2))
            return stream.read(length)

        def decode_list_item(stream):
            marker, = struct.unpack("!B", stream.read(1))
            while marker == cls.LIST_ITEM:
                yield decode_blob(stream)
                marker, = struct.unpack("!B", stream.read(1))

            if marker != cls.LIST_END:
                raise DeserializationException(
                    "Expected end of list at {pos}".format(pos=stream.tell())
                )

        section = OrderedDict()
        section_stack = []
        while stream.has_more():
            element_type, = struct.unpack("!B", stream.read(1))
            if element_type == cls.SECTION_START:
                section_name = decode_named_type(stream)
                new_section = OrderedDict()
                section[section_name] = new_section
                section_stack.append(section)
                section = new_section

            elif element_type == cls.LIST_START:
                list_name = decode_named_type(stream)
                section[list_name] = list(decode_list_item(stream))

            elif element_type == cls.KEY_VALUE:
                key = decode_named_type(stream)
                section[key] = decode_blob(stream)

            elif element_type == cls.SECTION_END:
                if not section_stack:
                    raise DeserializationException(
                        "Unexpected end of section at {pos}".format(
                            pos=stream.tell()
                        )
                    )
                section = section_stack.pop()

            else:
                # Unknown types and list markers outside a list used to be
                # skipped silently, after which the following bytes were
                # interpreted out of alignment. Fail loudly instead.
                raise DeserializationException(
                    "Invalid element type {type} at {pos}".format(
                        type=element_type, pos=stream.tell()
                    )
                )

        if section_stack:
            raise DeserializationException("Expected end of section")
        return section
class FiniteStream(io.BytesIO):
    """In-memory byte stream that knows its total size."""

    def __len__(self):
        """Return the total number of bytes held by the stream."""
        contents = self.getvalue()
        return len(contents)

    def has_more(self):
        """Return True while the read position is before the end."""
        position = self.tell()
        return position < len(self)