mirror of
https://github.com/lucaspalomodevelop/opnsense-core.git
synced 2026-03-13 00:07:27 +00:00
VPN: IPsec: Status Overview - cleanup, remove vici library in favour of port package
This commit is contained in:
parent
058aedc61e
commit
becf4e9342
1
Makefile
1
Makefile
@ -184,6 +184,7 @@ CORE_DEPENDS?= ca_root_nss \
|
||||
py${CORE_PYTHON}-requests \
|
||||
py${CORE_PYTHON}-sqlite3 \
|
||||
py${CORE_PYTHON}-ujson \
|
||||
py${CORE_PYTHON}-vici \
|
||||
radvd \
|
||||
rrdtool \
|
||||
samplicator \
|
||||
|
||||
7
plist
7
plist
@ -801,13 +801,6 @@
|
||||
/usr/local/opnsense/scripts/ipsec/disconnect.py
|
||||
/usr/local/opnsense/scripts/ipsec/list_leases.py
|
||||
/usr/local/opnsense/scripts/ipsec/list_status.py
|
||||
/usr/local/opnsense/scripts/ipsec/vici/__init__.py
|
||||
/usr/local/opnsense/scripts/ipsec/vici/compat.py
|
||||
/usr/local/opnsense/scripts/ipsec/vici/exception.py
|
||||
/usr/local/opnsense/scripts/ipsec/vici/protocol.py
|
||||
/usr/local/opnsense/scripts/ipsec/vici/session.py
|
||||
/usr/local/opnsense/scripts/ipsec/vici/test/__init__.py
|
||||
/usr/local/opnsense/scripts/ipsec/vici/test/test_protocol.py
|
||||
/usr/local/opnsense/scripts/netflow/dump_log.py
|
||||
/usr/local/opnsense/scripts/netflow/export_details.py
|
||||
/usr/local/opnsense/scripts/netflow/flowctl_stats.py
|
||||
|
||||
@ -1 +0,0 @@
|
||||
from .session import Session
|
||||
@ -1,14 +0,0 @@
|
||||
# Help functions for compatibility between python version 2 and 3
|
||||
|
||||
|
||||
# From https://legacy.python.org/dev/peps/pep-0469
|
||||
try:
|
||||
dict.iteritems
|
||||
except AttributeError:
|
||||
# python 3
|
||||
def iteritems(d):
|
||||
return iter(d.items())
|
||||
else:
|
||||
# python 2
|
||||
def iteritems(d):
|
||||
return d.iteritems()
|
||||
@ -1,13 +0,0 @@
|
||||
"""Exception types that may be thrown by this library."""
|
||||
|
||||
class DeserializationException(Exception):
|
||||
"""Encountered an unexpected byte sequence or missing element type."""
|
||||
|
||||
class SessionException(Exception):
|
||||
"""Session request exception."""
|
||||
|
||||
class CommandException(Exception):
|
||||
"""Command result exception."""
|
||||
|
||||
class EventUnknownException(Exception):
|
||||
"""Event unknown exception."""
|
||||
@ -1,206 +0,0 @@
|
||||
import io
|
||||
import socket
|
||||
import struct
|
||||
|
||||
from collections import namedtuple
|
||||
from collections import OrderedDict
|
||||
|
||||
from .compat import iteritems
|
||||
from .exception import DeserializationException
|
||||
|
||||
|
||||
class Transport(object):
|
||||
HEADER_LENGTH = 4
|
||||
MAX_SEGMENT = 512 * 1024
|
||||
|
||||
def __init__(self, sock):
|
||||
self.socket = sock
|
||||
|
||||
def send(self, packet):
|
||||
self.socket.sendall(struct.pack("!I", len(packet)) + packet)
|
||||
|
||||
def receive(self):
|
||||
raw_length = self._recvall(self.HEADER_LENGTH)
|
||||
length, = struct.unpack("!I", raw_length)
|
||||
payload = self._recvall(length)
|
||||
return payload
|
||||
|
||||
def close(self):
|
||||
self.socket.shutdown(socket.SHUT_RDWR)
|
||||
self.socket.close()
|
||||
|
||||
def _recvall(self, count):
|
||||
"""Ensure to read count bytes from the socket"""
|
||||
data = b""
|
||||
while len(data) < count:
|
||||
buf = self.socket.recv(count - len(data))
|
||||
if not buf:
|
||||
raise socket.error('Connection closed')
|
||||
data += buf
|
||||
return data
|
||||
|
||||
|
||||
class Packet(object):
|
||||
CMD_REQUEST = 0 # Named request message
|
||||
CMD_RESPONSE = 1 # Unnamed response message for a request
|
||||
CMD_UNKNOWN = 2 # Unnamed response if requested command is unknown
|
||||
EVENT_REGISTER = 3 # Named event registration request
|
||||
EVENT_UNREGISTER = 4 # Named event de-registration request
|
||||
EVENT_CONFIRM = 5 # Unnamed confirmation for event (de-)registration
|
||||
EVENT_UNKNOWN = 6 # Unnamed response if event (de-)registration failed
|
||||
EVENT = 7 # Named event message
|
||||
|
||||
ParsedPacket = namedtuple(
|
||||
"ParsedPacket",
|
||||
["response_type", "payload"]
|
||||
)
|
||||
|
||||
ParsedEventPacket = namedtuple(
|
||||
"ParsedEventPacket",
|
||||
["response_type", "event_type", "payload"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _named_request(cls, request_type, request, message=None):
|
||||
request = request.encode("UTF-8")
|
||||
payload = struct.pack("!BB", request_type, len(request)) + request
|
||||
if message is not None:
|
||||
return payload + message
|
||||
else:
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def request(cls, command, message=None):
|
||||
return cls._named_request(cls.CMD_REQUEST, command, message)
|
||||
|
||||
@classmethod
|
||||
def register_event(cls, event_type):
|
||||
return cls._named_request(cls.EVENT_REGISTER, event_type)
|
||||
|
||||
@classmethod
|
||||
def unregister_event(cls, event_type):
|
||||
return cls._named_request(cls.EVENT_UNREGISTER, event_type)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, packet):
|
||||
stream = FiniteStream(packet)
|
||||
response_type, = struct.unpack("!B", stream.read(1))
|
||||
|
||||
if response_type == cls.EVENT:
|
||||
length, = struct.unpack("!B", stream.read(1))
|
||||
event_type = stream.read(length)
|
||||
return cls.ParsedEventPacket(response_type, event_type, stream)
|
||||
else:
|
||||
return cls.ParsedPacket(response_type, stream)
|
||||
|
||||
|
||||
class Message(object):
|
||||
SECTION_START = 1 # Begin a new section having a name
|
||||
SECTION_END = 2 # End a previously started section
|
||||
KEY_VALUE = 3 # Define a value for a named key in the section
|
||||
LIST_START = 4 # Begin a named list for list items
|
||||
LIST_ITEM = 5 # Define an unnamed item value in the current list
|
||||
LIST_END = 6 # End a previously started list
|
||||
|
||||
@classmethod
|
||||
def serialize(cls, message):
|
||||
def encode_named_type(marker, name):
|
||||
name = name.encode("UTF-8")
|
||||
return struct.pack("!BB", marker, len(name)) + name
|
||||
|
||||
def encode_blob(value):
|
||||
if not isinstance(value, bytes):
|
||||
value = str(value).encode("UTF-8")
|
||||
return struct.pack("!H", len(value)) + value
|
||||
|
||||
def serialize_list(lst):
|
||||
segment = bytes()
|
||||
for item in lst:
|
||||
segment += struct.pack("!B", cls.LIST_ITEM) + encode_blob(item)
|
||||
return segment
|
||||
|
||||
def serialize_dict(d):
|
||||
segment = bytes()
|
||||
for key, value in iteritems(d):
|
||||
if isinstance(value, dict):
|
||||
segment += (
|
||||
encode_named_type(cls.SECTION_START, key)
|
||||
+ serialize_dict(value)
|
||||
+ struct.pack("!B", cls.SECTION_END)
|
||||
)
|
||||
elif isinstance(value, list):
|
||||
segment += (
|
||||
encode_named_type(cls.LIST_START, key)
|
||||
+ serialize_list(value)
|
||||
+ struct.pack("!B", cls.LIST_END)
|
||||
)
|
||||
else:
|
||||
segment += (
|
||||
encode_named_type(cls.KEY_VALUE, key)
|
||||
+ encode_blob(value)
|
||||
)
|
||||
return segment
|
||||
|
||||
return serialize_dict(message)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stream):
|
||||
def decode_named_type(stream):
|
||||
length, = struct.unpack("!B", stream.read(1))
|
||||
return stream.read(length).decode("UTF-8")
|
||||
|
||||
def decode_blob(stream):
|
||||
length, = struct.unpack("!H", stream.read(2))
|
||||
return stream.read(length)
|
||||
|
||||
def decode_list_item(stream):
|
||||
marker, = struct.unpack("!B", stream.read(1))
|
||||
while marker == cls.LIST_ITEM:
|
||||
yield decode_blob(stream)
|
||||
marker, = struct.unpack("!B", stream.read(1))
|
||||
|
||||
if marker != cls.LIST_END:
|
||||
raise DeserializationException(
|
||||
"Expected end of list at {pos}".format(pos=stream.tell())
|
||||
)
|
||||
|
||||
section = OrderedDict()
|
||||
section_stack = []
|
||||
while stream.has_more():
|
||||
element_type, = struct.unpack("!B", stream.read(1))
|
||||
if element_type == cls.SECTION_START:
|
||||
section_name = decode_named_type(stream)
|
||||
new_section = OrderedDict()
|
||||
section[section_name] = new_section
|
||||
section_stack.append(section)
|
||||
section = new_section
|
||||
|
||||
elif element_type == cls.LIST_START:
|
||||
list_name = decode_named_type(stream)
|
||||
section[list_name] = [item for item in decode_list_item(stream)]
|
||||
|
||||
elif element_type == cls.KEY_VALUE:
|
||||
key = decode_named_type(stream)
|
||||
section[key] = decode_blob(stream)
|
||||
|
||||
elif element_type == cls.SECTION_END:
|
||||
if len(section_stack):
|
||||
section = section_stack.pop()
|
||||
else:
|
||||
raise DeserializationException(
|
||||
"Unexpected end of section at {pos}".format(
|
||||
pos=stream.tell()
|
||||
)
|
||||
)
|
||||
|
||||
if len(section_stack):
|
||||
raise DeserializationException("Expected end of section")
|
||||
return section
|
||||
|
||||
|
||||
class FiniteStream(io.BytesIO):
|
||||
def __len__(self):
|
||||
return len(self.getvalue())
|
||||
|
||||
def has_more(self):
|
||||
return self.tell() < len(self)
|
||||
@ -1,388 +0,0 @@
|
||||
import collections
|
||||
import socket
|
||||
|
||||
from .exception import SessionException, CommandException, EventUnknownException
|
||||
from .protocol import Transport, Packet, Message
|
||||
|
||||
|
||||
class Session(object):
|
||||
def __init__(self, sock=None):
|
||||
if sock is None:
|
||||
sock = socket.socket(socket.AF_UNIX)
|
||||
sock.connect("/var/run/charon.vici")
|
||||
self.handler = SessionHandler(Transport(sock))
|
||||
|
||||
def version(self):
|
||||
"""Retrieve daemon and system specific version information.
|
||||
|
||||
:return: daemon and system specific version information
|
||||
:rtype: dict
|
||||
"""
|
||||
return self.handler.request("version")
|
||||
|
||||
def stats(self):
|
||||
"""Retrieve IKE daemon statistics and load information.
|
||||
|
||||
:return: IKE daemon statistics and load information
|
||||
:rtype: dict
|
||||
"""
|
||||
return self.handler.request("stats")
|
||||
|
||||
def reload_settings(self):
|
||||
"""Reload strongswan.conf settings and any plugins supporting reload.
|
||||
"""
|
||||
self.handler.request("reload-settings")
|
||||
|
||||
def initiate(self, sa):
|
||||
"""Initiate an SA.
|
||||
|
||||
:param sa: the SA to initiate
|
||||
:type sa: dict
|
||||
:return: generator for logs emitted as dict
|
||||
:rtype: generator
|
||||
"""
|
||||
return self.handler.streamed_request("initiate", "control-log", sa)
|
||||
|
||||
def terminate(self, sa):
|
||||
"""Terminate an SA.
|
||||
|
||||
:param sa: the SA to terminate
|
||||
:type sa: dict
|
||||
:return: generator for logs emitted as dict
|
||||
:rtype: generator
|
||||
"""
|
||||
return self.handler.streamed_request("terminate", "control-log", sa)
|
||||
|
||||
def redirect(self, sa):
|
||||
"""Redirect an IKE_SA.
|
||||
|
||||
:param sa: the SA to redirect
|
||||
:type sa: dict
|
||||
"""
|
||||
self.handler.request("redirect", sa)
|
||||
|
||||
def install(self, policy):
|
||||
"""Install a trap, drop or bypass policy defined by a CHILD_SA config.
|
||||
|
||||
:param policy: policy to install
|
||||
:type policy: dict
|
||||
"""
|
||||
self.handler.request("install", policy)
|
||||
|
||||
def uninstall(self, policy):
|
||||
"""Uninstall a trap, drop or bypass policy defined by a CHILD_SA config.
|
||||
|
||||
:param policy: policy to uninstall
|
||||
:type policy: dict
|
||||
"""
|
||||
self.handler.request("uninstall", policy)
|
||||
|
||||
def list_sas(self, filters=None):
|
||||
"""Retrieve active IKE_SAs and associated CHILD_SAs.
|
||||
|
||||
:param filters: retrieve only matching IKE_SAs (optional)
|
||||
:type filters: dict
|
||||
:return: generator for active IKE_SAs and associated CHILD_SAs as dict
|
||||
:rtype: generator
|
||||
"""
|
||||
return self.handler.streamed_request("list-sas", "list-sa", filters)
|
||||
|
||||
def list_policies(self, filters=None):
|
||||
"""Retrieve installed trap, drop and bypass policies.
|
||||
|
||||
:param filters: retrieve only matching policies (optional)
|
||||
:type filters: dict
|
||||
:return: generator for installed trap, drop and bypass policies as dict
|
||||
:rtype: generator
|
||||
"""
|
||||
return self.handler.streamed_request("list-policies", "list-policy",
|
||||
filters)
|
||||
|
||||
def list_conns(self, filters=None):
|
||||
"""Retrieve loaded connections.
|
||||
|
||||
:param filters: retrieve only matching configuration names (optional)
|
||||
:type filters: dict
|
||||
:return: generator for loaded connections as dict
|
||||
:rtype: generator
|
||||
"""
|
||||
return self.handler.streamed_request("list-conns", "list-conn",
|
||||
filters)
|
||||
|
||||
def get_conns(self):
|
||||
"""Retrieve connection names loaded exclusively over vici.
|
||||
|
||||
:return: connection names
|
||||
:rtype: dict
|
||||
"""
|
||||
return self.handler.request("get-conns")
|
||||
|
||||
def list_certs(self, filters=None):
|
||||
"""Retrieve loaded certificates.
|
||||
|
||||
:param filters: retrieve only matching certificates (optional)
|
||||
:type filters: dict
|
||||
:return: generator for loaded certificates as dict
|
||||
:rtype: generator
|
||||
"""
|
||||
return self.handler.streamed_request("list-certs", "list-cert", filters)
|
||||
|
||||
def load_conn(self, connection):
|
||||
"""Load a connection definition into the daemon.
|
||||
|
||||
:param connection: connection definition
|
||||
:type connection: dict
|
||||
"""
|
||||
self.handler.request("load-conn", connection)
|
||||
|
||||
def unload_conn(self, name):
|
||||
"""Unload a connection definition.
|
||||
|
||||
:param name: connection definition name
|
||||
:type name: dict
|
||||
"""
|
||||
self.handler.request("unload-conn", name)
|
||||
|
||||
def load_cert(self, certificate):
|
||||
"""Load a certificate into the daemon.
|
||||
|
||||
:param certificate: PEM or DER encoded certificate
|
||||
:type certificate: dict
|
||||
"""
|
||||
self.handler.request("load-cert", certificate)
|
||||
|
||||
def load_key(self, private_key):
|
||||
"""Load a private key into the daemon.
|
||||
|
||||
:param private_key: PEM or DER encoded key
|
||||
"""
|
||||
self.handler.request("load-key", private_key)
|
||||
|
||||
def load_shared(self, secret):
|
||||
"""Load a shared IKE PSK, EAP or XAuth secret into the daemon.
|
||||
|
||||
:param secret: shared IKE PSK, EAP or XAuth secret
|
||||
:type secret: dict
|
||||
"""
|
||||
self.handler.request("load-shared", secret)
|
||||
|
||||
def flush_certs(self, filter=None):
|
||||
"""Flush the volatile certificate cache.
|
||||
|
||||
Flush the certificate stored temporarily in the cache. The filter
|
||||
allows to flush only a certain type of certificates, e.g. CRLs.
|
||||
|
||||
:param filter: flush only certificates of a given type (optional)
|
||||
:type filter: dict
|
||||
"""
|
||||
self.handler.request("flush-certs", filter)
|
||||
|
||||
def clear_creds(self):
|
||||
"""Clear credentials loaded over vici.
|
||||
|
||||
Clear all loaded certificate, private key and shared key credentials.
|
||||
This affects only credentials loaded over vici, but additionally
|
||||
flushes the credential cache.
|
||||
"""
|
||||
self.handler.request("clear-creds")
|
||||
|
||||
def load_pool(self, pool):
|
||||
"""Load a virtual IP pool.
|
||||
|
||||
Load an in-memory virtual IP and configuration attribute pool.
|
||||
Existing pools with the same name get updated, if possible.
|
||||
|
||||
:param pool: virtual IP and configuration attribute pool
|
||||
:type pool: dict
|
||||
"""
|
||||
return self.handler.request("load-pool", pool)
|
||||
|
||||
def unload_pool(self, pool_name):
|
||||
"""Unload a virtual IP pool.
|
||||
|
||||
Unload a previously loaded virtual IP and configuration attribute pool.
|
||||
Unloading fails for pools with leases currently online.
|
||||
|
||||
:param pool_name: pool by name
|
||||
:type pool_name: dict
|
||||
"""
|
||||
self.handler.request("unload-pool", pool_name)
|
||||
|
||||
def get_pools(self, options):
|
||||
"""Retrieve loaded pools.
|
||||
|
||||
:param options: filter by name and/or retrieve leases (optional)
|
||||
:type options: dict
|
||||
:return: loaded pools
|
||||
:rtype: dict
|
||||
"""
|
||||
return self.handler.request("get-pools", options)
|
||||
|
||||
def listen(self, event_types):
|
||||
"""Register and listen for the given events.
|
||||
|
||||
:param event_types: event types to register
|
||||
:type event_types: list
|
||||
:return: generator for streamed event responses as (event_type, dict)
|
||||
:rtype: generator
|
||||
"""
|
||||
return self.handler.listen(event_types)
|
||||
|
||||
|
||||
class SessionHandler(object):
|
||||
"""Handles client command execution requests over vici."""
|
||||
|
||||
def __init__(self, transport):
|
||||
self.transport = transport
|
||||
|
||||
def _communicate(self, packet):
|
||||
"""Send packet over transport and parse response.
|
||||
|
||||
:param packet: packet to send
|
||||
:type packet: :py:class:`vici.protocol.Packet`
|
||||
:return: parsed packet in a tuple with message type and payload
|
||||
:rtype: :py:class:`collections.namedtuple`
|
||||
"""
|
||||
self.transport.send(packet)
|
||||
return Packet.parse(self.transport.receive())
|
||||
|
||||
def _register_unregister(self, event_type, register):
|
||||
"""Register or unregister for the given event.
|
||||
|
||||
:param event_type: event to register
|
||||
:type event_type: str
|
||||
:param register: whether to register or unregister
|
||||
:type register: bool
|
||||
"""
|
||||
if register:
|
||||
packet = Packet.register_event(event_type)
|
||||
else:
|
||||
packet = Packet.unregister_event(event_type)
|
||||
response = self._communicate(packet)
|
||||
if response.response_type == Packet.EVENT_UNKNOWN:
|
||||
raise EventUnknownException(
|
||||
"Unknown event type '{event}'".format(event=event_type)
|
||||
)
|
||||
elif response.response_type != Packet.EVENT_CONFIRM:
|
||||
raise SessionException(
|
||||
"Unexpected response type {type}, "
|
||||
"expected '{confirm}' (EVENT_CONFIRM)".format(
|
||||
type=response.response_type,
|
||||
confirm=Packet.EVENT_CONFIRM,
|
||||
)
|
||||
)
|
||||
|
||||
def request(self, command, message=None):
|
||||
"""Send request with an optional message.
|
||||
|
||||
:param command: command to send
|
||||
:type command: str
|
||||
:param message: message (optional)
|
||||
:type message: str
|
||||
:return: command result
|
||||
:rtype: dict
|
||||
"""
|
||||
if message is not None:
|
||||
message = Message.serialize(message)
|
||||
packet = Packet.request(command, message)
|
||||
response = self._communicate(packet)
|
||||
|
||||
if response.response_type != Packet.CMD_RESPONSE:
|
||||
raise SessionException(
|
||||
"Unexpected response type {type}, "
|
||||
"expected '{response}' (CMD_RESPONSE)".format(
|
||||
type=response.response_type,
|
||||
response=Packet.CMD_RESPONSE
|
||||
)
|
||||
)
|
||||
|
||||
command_response = Message.deserialize(response.payload)
|
||||
if "success" in command_response:
|
||||
if command_response["success"] != b"yes":
|
||||
raise CommandException(
|
||||
"Command failed: {errmsg}".format(
|
||||
errmsg=command_response["errmsg"]
|
||||
)
|
||||
)
|
||||
|
||||
return command_response
|
||||
|
||||
def streamed_request(self, command, event_stream_type, message=None):
|
||||
"""Send command request and collect and return all emitted events.
|
||||
|
||||
:param command: command to send
|
||||
:type command: str
|
||||
:param event_stream_type: event type emitted on command execution
|
||||
:type event_stream_type: str
|
||||
:param message: message (optional)
|
||||
:type message: str
|
||||
:return: generator for streamed event responses as dict
|
||||
:rtype: generator
|
||||
"""
|
||||
if message is not None:
|
||||
message = Message.serialize(message)
|
||||
|
||||
self._register_unregister(event_stream_type, True);
|
||||
|
||||
try:
|
||||
packet = Packet.request(command, message)
|
||||
self.transport.send(packet)
|
||||
exited = False
|
||||
while True:
|
||||
response = Packet.parse(self.transport.receive())
|
||||
if response.response_type == Packet.EVENT:
|
||||
if not exited:
|
||||
try:
|
||||
yield Message.deserialize(response.payload)
|
||||
except GeneratorExit:
|
||||
exited = True
|
||||
pass
|
||||
else:
|
||||
break
|
||||
|
||||
if response.response_type == Packet.CMD_RESPONSE:
|
||||
command_response = Message.deserialize(response.payload)
|
||||
else:
|
||||
raise SessionException(
|
||||
"Unexpected response type {type}, "
|
||||
"expected '{response}' (CMD_RESPONSE)".format(
|
||||
type=response.response_type,
|
||||
response=Packet.CMD_RESPONSE
|
||||
)
|
||||
)
|
||||
|
||||
finally:
|
||||
self._register_unregister(event_stream_type, False);
|
||||
|
||||
# evaluate command result, if any
|
||||
if "success" in command_response:
|
||||
if command_response["success"] != b"yes":
|
||||
raise CommandException(
|
||||
"Command failed: {errmsg}".format(
|
||||
errmsg=command_response["errmsg"]
|
||||
)
|
||||
)
|
||||
|
||||
def listen(self, event_types):
|
||||
"""Register and listen for the given events.
|
||||
|
||||
:param event_types: event types to register
|
||||
:type event_types: list
|
||||
:return: generator for streamed event responses as (event_type, dict)
|
||||
:rtype: generator
|
||||
"""
|
||||
for event_type in event_types:
|
||||
self._register_unregister(event_type, True)
|
||||
|
||||
try:
|
||||
while True:
|
||||
response = Packet.parse(self.transport.receive())
|
||||
if response.response_type == Packet.EVENT:
|
||||
try:
|
||||
yield response.event_type, Message.deserialize(response.payload)
|
||||
except GeneratorExit:
|
||||
break
|
||||
|
||||
finally:
|
||||
for event_type in event_types:
|
||||
self._register_unregister(event_type, False)
|
||||
@ -1,144 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from ..protocol import Packet, Message, FiniteStream
|
||||
from ..exception import DeserializationException
|
||||
|
||||
|
||||
class TestPacket(object):
|
||||
# test data definitions for outgoing packet types
|
||||
cmd_request = b"\x00\x0c" b"command_type"
|
||||
cmd_request_msg = b"\x00\x07" b"command" b"payload"
|
||||
event_register = b"\x03\x0a" b"event_type"
|
||||
event_unregister = b"\x04\x0a" b"event_type"
|
||||
|
||||
# test data definitions for incoming packet types
|
||||
cmd_response = b"\x01" b"reply"
|
||||
cmd_unknown = b"\x02"
|
||||
event_confirm = b"\x05"
|
||||
event_unknown = b"\x06"
|
||||
event = b"\x07\x03" b"log" b"message"
|
||||
|
||||
def test_request(self):
|
||||
assert Packet.request("command_type") == self.cmd_request
|
||||
assert Packet.request("command", b"payload") == self.cmd_request_msg
|
||||
|
||||
def test_register_event(self):
|
||||
assert Packet.register_event("event_type") == self.event_register
|
||||
|
||||
def test_unregister_event(self):
|
||||
assert Packet.unregister_event("event_type") == self.event_unregister
|
||||
|
||||
def test_parse(self):
|
||||
parsed_cmd_response = Packet.parse(self.cmd_response)
|
||||
assert parsed_cmd_response.response_type == Packet.CMD_RESPONSE
|
||||
assert parsed_cmd_response.payload.getvalue() == self.cmd_response
|
||||
|
||||
parsed_cmd_unknown = Packet.parse(self.cmd_unknown)
|
||||
assert parsed_cmd_unknown.response_type == Packet.CMD_UNKNOWN
|
||||
assert parsed_cmd_unknown.payload.getvalue() == self.cmd_unknown
|
||||
|
||||
parsed_event_confirm = Packet.parse(self.event_confirm)
|
||||
assert parsed_event_confirm.response_type == Packet.EVENT_CONFIRM
|
||||
assert parsed_event_confirm.payload.getvalue() == self.event_confirm
|
||||
|
||||
parsed_event_unknown = Packet.parse(self.event_unknown)
|
||||
assert parsed_event_unknown.response_type == Packet.EVENT_UNKNOWN
|
||||
assert parsed_event_unknown.payload.getvalue() == self.event_unknown
|
||||
|
||||
parsed_event = Packet.parse(self.event)
|
||||
assert parsed_event.response_type == Packet.EVENT
|
||||
assert parsed_event.payload.getvalue() == self.event
|
||||
|
||||
|
||||
class TestMessage(object):
|
||||
"""Message (de)serialization test."""
|
||||
|
||||
# data definitions for test of de(serialization)
|
||||
# serialized messages holding a section
|
||||
ser_sec_unclosed = b"\x01\x08unclosed"
|
||||
ser_sec_single = b"\x01\x07section\x02"
|
||||
ser_sec_nested = b"\x01\x05outer\x01\x0asubsection\x02\x02"
|
||||
|
||||
# serialized messages holding a list
|
||||
ser_list_invalid = b"\x04\x07invalid\x05\x00\x02e1\x02\x03sec\x06"
|
||||
ser_list_0_item = b"\x04\x05empty\x06"
|
||||
ser_list_1_item = b"\x04\x01l\x05\x00\x02e1\x06"
|
||||
ser_list_2_item = b"\x04\x01l\x05\x00\x02e1\x05\x00\x02e2\x06"
|
||||
|
||||
# serialized messages with key value pairs
|
||||
ser_kv_pair = b"\x03\x03key\x00\x05value"
|
||||
ser_kv_zero = b"\x03\x0azerolength\x00\x00"
|
||||
|
||||
# deserialized messages holding a section
|
||||
des_sec_single = { "section": {} }
|
||||
des_sec_nested = { "outer": { "subsection": {} } }
|
||||
|
||||
# deserialized messages holding a list
|
||||
des_list_0_item = { "empty": [] }
|
||||
des_list_1_item = { "l": [ b"e1" ] }
|
||||
des_list_2_item = { "l": [ b"e1", b"e2" ] }
|
||||
|
||||
# deserialized messages with key value pairs
|
||||
des_kv_pair = { "key": b"value" }
|
||||
des_kv_zero = { "zerolength": b"" }
|
||||
|
||||
def test_section_serialization(self):
|
||||
assert Message.serialize(self.des_sec_single) == self.ser_sec_single
|
||||
assert Message.serialize(self.des_sec_nested) == self.ser_sec_nested
|
||||
|
||||
def test_list_serialization(self):
|
||||
assert Message.serialize(self.des_list_0_item) == self.ser_list_0_item
|
||||
assert Message.serialize(self.des_list_1_item) == self.ser_list_1_item
|
||||
assert Message.serialize(self.des_list_2_item) == self.ser_list_2_item
|
||||
|
||||
def test_key_serialization(self):
|
||||
assert Message.serialize(self.des_kv_pair) == self.ser_kv_pair
|
||||
assert Message.serialize(self.des_kv_zero) == self.ser_kv_zero
|
||||
|
||||
def test_section_deserialization(self):
|
||||
single = Message.deserialize(FiniteStream(self.ser_sec_single))
|
||||
nested = Message.deserialize(FiniteStream(self.ser_sec_nested))
|
||||
|
||||
assert single == self.des_sec_single
|
||||
assert nested == self.des_sec_nested
|
||||
|
||||
with pytest.raises(DeserializationException):
|
||||
Message.deserialize(FiniteStream(self.ser_sec_unclosed))
|
||||
|
||||
def test_list_deserialization(self):
|
||||
l0 = Message.deserialize(FiniteStream(self.ser_list_0_item))
|
||||
l1 = Message.deserialize(FiniteStream(self.ser_list_1_item))
|
||||
l2 = Message.deserialize(FiniteStream(self.ser_list_2_item))
|
||||
|
||||
assert l0 == self.des_list_0_item
|
||||
assert l1 == self.des_list_1_item
|
||||
assert l2 == self.des_list_2_item
|
||||
|
||||
with pytest.raises(DeserializationException):
|
||||
Message.deserialize(FiniteStream(self.ser_list_invalid))
|
||||
|
||||
def test_key_deserialization(self):
|
||||
pair = Message.deserialize(FiniteStream(self.ser_kv_pair))
|
||||
zerolength = Message.deserialize(FiniteStream(self.ser_kv_zero))
|
||||
|
||||
assert pair == self.des_kv_pair
|
||||
assert zerolength == self.des_kv_zero
|
||||
|
||||
def test_roundtrip(self):
|
||||
message = {
|
||||
"key1": "value1",
|
||||
"section1": {
|
||||
"sub-section": {
|
||||
"key2": b"value2",
|
||||
},
|
||||
"list1": [ "item1", "item2" ],
|
||||
},
|
||||
}
|
||||
serialized_message = FiniteStream(Message.serialize(message))
|
||||
deserialized_message = Message.deserialize(serialized_message)
|
||||
|
||||
# ensure that list items and key values remain as undecoded bytes
|
||||
deserialized_section = deserialized_message["section1"]
|
||||
assert deserialized_message["key1"] == b"value1"
|
||||
assert deserialized_section["sub-section"]["key2"] == b"value2"
|
||||
assert deserialized_section["list1"] == [ b"item1", b"item2" ]
|
||||
Loading…
x
Reference in New Issue
Block a user