(ipsec) add status call using vici, related to https://github.com/opnsense/core/issues/139

First step is to switch the current status page away from legacy smp.
(work in progress)
This commit is contained in:
Ad Schellevis 2015-11-05 08:06:12 +00:00
parent 249c74c6d6
commit a87623867b
9 changed files with 811 additions and 0 deletions

View File

@ -0,0 +1,71 @@
#!/usr/local/bin/python2.7
"""
Copyright (c) 2015 Ad Schellevis
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES,
INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------------
list ipsec status, using vici interface
"""
import ujson
import vici
s = vici.Session()
result = dict()
# parse connections
for conns in s.list_conns():
for connection_id in conns:
result[connection_id] = dict()
result[connection_id]['version'] = conns[connection_id]['version']
result[connection_id]['local_addrs'] = ','.join(conns[connection_id]['local_addrs'])
result[connection_id]['local-id'] = ''
result[connection_id]['local-class'] = []
result[connection_id]['remote-id'] = ''
result[connection_id]['remote-class'] = []
result[connection_id]['children']= conns[connection_id]['children']
result[connection_id]['sas'] = []
# parse local-% and remote-% keys
for connKey in conns[connection_id].keys():
if connKey.find('local-') == 0:
if 'id' in conns[connection_id][connKey]:
result[connection_id]['local-id'] = conns[connection_id][connKey]['id']
result[connection_id]['local-class'].append(conns[connection_id][connKey]['class'])
elif connKey.find('remote-') == 0:
if 'id' in conns[connection_id][connKey]:
result[connection_id]['remote-id'] = conns[connection_id][connKey]['id']
result[connection_id]['remote-class'].append(conns[connection_id][connKey]['class'])
result[connection_id]['local-class'] = '+'.join(result[connection_id]['local-class'])
result[connection_id]['remote-class'] = '+'.join(result[connection_id]['remote-class'])
# attach Security Associations
for sas in s.list_sas():
for sa in sas:
if sa in result:
result[sa]['sas'].append(sas[sa])
print(ujson.dumps(result))

View File

@ -0,0 +1 @@
from .session import Session

View File

@ -0,0 +1,14 @@
# Help functions for compatibility between python version 2 and 3
# From http://legacy.python.org/dev/peps/pep-0469
try:
dict.iteritems
except AttributeError:
# python 3
def iteritems(d):
return iter(d.items())
else:
# python 2
def iteritems(d):
return d.iteritems()

View File

@ -0,0 +1,13 @@
"""Exception types that may be thrown by this library."""
class DeserializationException(Exception):
"""Encountered an unexpected byte sequence or missing element type."""
class SessionException(Exception):
"""Session request exception."""
class CommandException(Exception):
"""Command result exception."""
class EventUnknownException(Exception):
"""Event unknown exception."""

View File

@ -0,0 +1,196 @@
import io
import socket
import struct
from collections import namedtuple
from collections import OrderedDict
from .compat import iteritems
from .exception import DeserializationException
class Transport(object):
HEADER_LENGTH = 4
MAX_SEGMENT = 512 * 1024
def __init__(self, sock):
self.socket = sock
def send(self, packet):
self.socket.sendall(struct.pack("!I", len(packet)) + packet)
def receive(self):
raw_length = self.socket.recv(self.HEADER_LENGTH)
length, = struct.unpack("!I", raw_length)
payload = self.socket.recv(length)
return payload
def close(self):
self.socket.shutdown(socket.SHUT_RDWR)
self.socket.close()
class Packet(object):
CMD_REQUEST = 0 # Named request message
CMD_RESPONSE = 1 # Unnamed response message for a request
CMD_UNKNOWN = 2 # Unnamed response if requested command is unknown
EVENT_REGISTER = 3 # Named event registration request
EVENT_UNREGISTER = 4 # Named event de-registration request
EVENT_CONFIRM = 5 # Unnamed confirmation for event (de-)registration
EVENT_UNKNOWN = 6 # Unnamed response if event (de-)registration failed
EVENT = 7 # Named event message
ParsedPacket = namedtuple(
"ParsedPacket",
["response_type", "payload"]
)
ParsedEventPacket = namedtuple(
"ParsedEventPacket",
["response_type", "event_type", "payload"]
)
@classmethod
def _named_request(cls, request_type, request, message=None):
request = request.encode()
payload = struct.pack("!BB", request_type, len(request)) + request
if message is not None:
return payload + message
else:
return payload
@classmethod
def request(cls, command, message=None):
return cls._named_request(cls.CMD_REQUEST, command, message)
@classmethod
def register_event(cls, event_type):
return cls._named_request(cls.EVENT_REGISTER, event_type)
@classmethod
def unregister_event(cls, event_type):
return cls._named_request(cls.EVENT_UNREGISTER, event_type)
@classmethod
def parse(cls, packet):
stream = FiniteStream(packet)
response_type, = struct.unpack("!B", stream.read(1))
if response_type == cls.EVENT:
length, = struct.unpack("!B", stream.read(1))
event_type = stream.read(length)
return cls.ParsedEventPacket(response_type, event_type, stream)
else:
return cls.ParsedPacket(response_type, stream)
class Message(object):
SECTION_START = 1 # Begin a new section having a name
SECTION_END = 2 # End a previously started section
KEY_VALUE = 3 # Define a value for a named key in the section
LIST_START = 4 # Begin a named list for list items
LIST_ITEM = 5 # Define an unnamed item value in the current list
LIST_END = 6 # End a previously started list
@classmethod
def serialize(cls, message):
def encode_named_type(marker, name):
name = name.encode()
return struct.pack("!BB", marker, len(name)) + name
def encode_blob(value):
if not isinstance(value, bytes):
value = str(value).encode()
return struct.pack("!H", len(value)) + value
def serialize_list(lst):
segment = bytes()
for item in lst:
segment += struct.pack("!B", cls.LIST_ITEM) + encode_blob(item)
return segment
def serialize_dict(d):
segment = bytes()
for key, value in iteritems(d):
if isinstance(value, dict):
segment += (
encode_named_type(cls.SECTION_START, key)
+ serialize_dict(value)
+ struct.pack("!B", cls.SECTION_END)
)
elif isinstance(value, list):
segment += (
encode_named_type(cls.LIST_START, key)
+ serialize_list(value)
+ struct.pack("!B", cls.LIST_END)
)
else:
segment += (
encode_named_type(cls.KEY_VALUE, key)
+ encode_blob(value)
)
return segment
return serialize_dict(message)
@classmethod
def deserialize(cls, stream):
def decode_named_type(stream):
length, = struct.unpack("!B", stream.read(1))
return stream.read(length).decode()
def decode_blob(stream):
length, = struct.unpack("!H", stream.read(2))
return stream.read(length)
def decode_list_item(stream):
marker, = struct.unpack("!B", stream.read(1))
while marker == cls.LIST_ITEM:
yield decode_blob(stream)
marker, = struct.unpack("!B", stream.read(1))
if marker != cls.LIST_END:
raise DeserializationException(
"Expected end of list at {pos}".format(pos=stream.tell())
)
section = OrderedDict()
section_stack = []
while stream.has_more():
element_type, = struct.unpack("!B", stream.read(1))
if element_type == cls.SECTION_START:
section_name = decode_named_type(stream)
new_section = OrderedDict()
section[section_name] = new_section
section_stack.append(section)
section = new_section
elif element_type == cls.LIST_START:
list_name = decode_named_type(stream)
section[list_name] = [item for item in decode_list_item(stream)]
elif element_type == cls.KEY_VALUE:
key = decode_named_type(stream)
section[key] = decode_blob(stream)
elif element_type == cls.SECTION_END:
if len(section_stack):
section = section_stack.pop()
else:
raise DeserializationException(
"Unexpected end of section at {pos}".format(
pos=stream.tell()
)
)
if len(section_stack):
raise DeserializationException("Expected end of section")
return section
class FiniteStream(io.BytesIO):
def __len__(self):
return len(self.getvalue())
def has_more(self):
return self.tell() < len(self)

View File

@ -0,0 +1,367 @@
import collections
import socket
from .exception import SessionException, CommandException, EventUnknownException
from .protocol import Transport, Packet, Message
class Session(object):
def __init__(self, sock=None):
if sock is None:
sock = socket.socket(socket.AF_UNIX)
sock.connect("/var/run/charon.vici")
self.handler = SessionHandler(Transport(sock))
def version(self):
"""Retrieve daemon and system specific version information.
:return: daemon and system specific version information
:rtype: dict
"""
return self.handler.request("version")
def stats(self):
"""Retrieve IKE daemon statistics and load information.
:return: IKE daemon statistics and load information
:rtype: dict
"""
return self.handler.request("stats")
def reload_settings(self):
"""Reload strongswan.conf settings and any plugins supporting reload.
"""
self.handler.request("reload-settings")
def initiate(self, sa):
"""Initiate an SA.
:param sa: the SA to initiate
:type sa: dict
:return: generator for logs emitted as dict
:rtype: generator
"""
return self.handler.streamed_request("initiate", "control-log", sa)
def terminate(self, sa):
"""Terminate an SA.
:param sa: the SA to terminate
:type sa: dict
:return: generator for logs emitted as dict
:rtype: generator
"""
return self.handler.streamed_request("terminate", "control-log", sa)
def install(self, policy):
"""Install a trap, drop or bypass policy defined by a CHILD_SA config.
:param policy: policy to install
:type policy: dict
"""
self.handler.request("install", policy)
def uninstall(self, policy):
"""Uninstall a trap, drop or bypass policy defined by a CHILD_SA config.
:param policy: policy to uninstall
:type policy: dict
"""
self.handler.request("uninstall", policy)
def list_sas(self, filters=None):
"""Retrieve active IKE_SAs and associated CHILD_SAs.
:param filters: retrieve only matching IKE_SAs (optional)
:type filters: dict
:return: generator for active IKE_SAs and associated CHILD_SAs as dict
:rtype: generator
"""
return self.handler.streamed_request("list-sas", "list-sa", filters)
def list_policies(self, filters=None):
"""Retrieve installed trap, drop and bypass policies.
:param filters: retrieve only matching policies (optional)
:type filters: dict
:return: generator for installed trap, drop and bypass policies as dict
:rtype: generator
"""
return self.handler.streamed_request("list-policies", "list-policy",
filters)
def list_conns(self, filters=None):
"""Retrieve loaded connections.
:param filters: retrieve only matching configuration names (optional)
:type filters: dict
:return: generator for loaded connections as dict
:rtype: generator
"""
return self.handler.streamed_request("list-conns", "list-conn",
filters)
def get_conns(self):
"""Retrieve connection names loaded exclusively over vici.
:return: connection names
:rtype: dict
"""
return self.handler.request("get-conns")
def list_certs(self, filters=None):
"""Retrieve loaded certificates.
:param filters: retrieve only matching certificates (optional)
:type filters: dict
:return: generator for loaded certificates as dict
:rtype: generator
"""
return self.handler.streamed_request("list-certs", "list-cert", filters)
def load_conn(self, connection):
"""Load a connection definition into the daemon.
:param connection: connection definition
:type connection: dict
"""
self.handler.request("load-conn", connection)
def unload_conn(self, name):
"""Unload a connection definition.
:param name: connection definition name
:type name: dict
"""
self.handler.request("unload-conn", name)
def load_cert(self, certificate):
"""Load a certificate into the daemon.
:param certificate: PEM or DER encoded certificate
:type certificate: dict
"""
self.handler.request("load-cert", certificate)
def load_key(self, private_key):
"""Load a private key into the daemon.
:param private_key: PEM or DER encoded key
"""
self.handler.request("load-key", private_key)
def load_shared(self, secret):
"""Load a shared IKE PSK, EAP or XAuth secret into the daemon.
:param secret: shared IKE PSK, EAP or XAuth secret
:type secret: dict
"""
self.handler.request("load-shared", secret)
def clear_creds(self):
"""Clear credentials loaded over vici.
Clear all loaded certificate, private key and shared key credentials.
This affects only credentials loaded over vici, but additionally
flushes the credential cache.
"""
self.handler.request("clear-creds")
def load_pool(self, pool):
"""Load a virtual IP pool.
Load an in-memory virtual IP and configuration attribute pool.
Existing pools with the same name get updated, if possible.
:param pool: virtual IP and configuration attribute pool
:type pool: dict
"""
return self.handler.request("load-pool", pool)
def unload_pool(self, pool_name):
"""Unload a virtual IP pool.
Unload a previously loaded virtual IP and configuration attribute pool.
Unloading fails for pools with leases currently online.
:param pool_name: pool by name
:type pool_name: dict
"""
self.handler.request("unload-pool", pool_name)
def get_pools(self):
"""Retrieve loaded pools.
:return: loaded pools
:rtype: dict
"""
return self.handler.request("get-pools")
def listen(self, event_types):
"""Register and listen for the given events.
:param event_types: event types to register
:type event_types: list
:return: generator for streamed event responses as (event_type, dict)
:rtype: generator
"""
return self.handler.listen(event_types)
class SessionHandler(object):
"""Handles client command execution requests over vici."""
def __init__(self, transport):
self.transport = transport
def _communicate(self, packet):
"""Send packet over transport and parse response.
:param packet: packet to send
:type packet: :py:class:`vici.protocol.Packet`
:return: parsed packet in a tuple with message type and payload
:rtype: :py:class:`collections.namedtuple`
"""
self.transport.send(packet)
return Packet.parse(self.transport.receive())
def _register_unregister(self, event_type, register):
"""Register or unregister for the given event.
:param event_type: event to register
:type event_type: str
:param register: whether to register or unregister
:type register: bool
"""
if register:
packet = Packet.register_event(event_type)
else:
packet = Packet.unregister_event(event_type)
response = self._communicate(packet)
if response.response_type == Packet.EVENT_UNKNOWN:
raise EventUnknownException(
"Unknown event type '{event}'".format(event=event_type)
)
elif response.response_type != Packet.EVENT_CONFIRM:
raise SessionException(
"Unexpected response type {type}, "
"expected '{confirm}' (EVENT_CONFIRM)".format(
type=response.response_type,
confirm=Packet.EVENT_CONFIRM,
)
)
def request(self, command, message=None):
"""Send request with an optional message.
:param command: command to send
:type command: str
:param message: message (optional)
:type message: str
:return: command result
:rtype: dict
"""
if message is not None:
message = Message.serialize(message)
packet = Packet.request(command, message)
response = self._communicate(packet)
if response.response_type != Packet.CMD_RESPONSE:
raise SessionException(
"Unexpected response type {type}, "
"expected '{response}' (CMD_RESPONSE)".format(
type=response.response_type,
response=Packet.CMD_RESPONSE
)
)
command_response = Message.deserialize(response.payload)
if "success" in command_response:
if command_response["success"] != b"yes":
raise CommandException(
"Command failed: {errmsg}".format(
errmsg=command_response["errmsg"]
)
)
return command_response
def streamed_request(self, command, event_stream_type, message=None):
"""Send command request and collect and return all emitted events.
:param command: command to send
:type command: str
:param event_stream_type: event type emitted on command execution
:type event_stream_type: str
:param message: message (optional)
:type message: str
:return: generator for streamed event responses as dict
:rtype: generator
"""
if message is not None:
message = Message.serialize(message)
self._register_unregister(event_stream_type, True);
try:
packet = Packet.request(command, message)
self.transport.send(packet)
exited = False
while True:
response = Packet.parse(self.transport.receive())
if response.response_type == Packet.EVENT:
if not exited:
try:
yield Message.deserialize(response.payload)
except GeneratorExit:
exited = True
pass
else:
break
if response.response_type == Packet.CMD_RESPONSE:
command_response = Message.deserialize(response.payload)
else:
raise SessionException(
"Unexpected response type {type}, "
"expected '{response}' (CMD_RESPONSE)".format(
type=response.response_type,
response=Packet.CMD_RESPONSE
)
)
finally:
self._register_unregister(event_stream_type, False);
# evaluate command result, if any
if "success" in command_response:
if command_response["success"] != b"yes":
raise CommandException(
"Command failed: {errmsg}".format(
errmsg=command_response["errmsg"]
)
)
def listen(self, event_types):
"""Register and listen for the given events.
:param event_types: event types to register
:type event_types: list
:return: generator for streamed event responses as (event_type, dict)
:rtype: generator
"""
for event_type in event_types:
self._register_unregister(event_type, True)
try:
while True:
response = Packet.parse(self.transport.receive())
if response.response_type == Packet.EVENT:
try:
yield response.event_type, Message.deserialize(response.payload)
except GeneratorExit:
break
finally:
for event_type in event_types:
self._register_unregister(event_type, False)

View File

@ -0,0 +1,144 @@
import pytest
from ..protocol import Packet, Message, FiniteStream
from ..exception import DeserializationException
class TestPacket(object):
# test data definitions for outgoing packet types
cmd_request = b"\x00\x0c" b"command_type"
cmd_request_msg = b"\x00\x07" b"command" b"payload"
event_register = b"\x03\x0a" b"event_type"
event_unregister = b"\x04\x0a" b"event_type"
# test data definitions for incoming packet types
cmd_response = b"\x01" b"reply"
cmd_unknown = b"\x02"
event_confirm = b"\x05"
event_unknown = b"\x06"
event = b"\x07\x03" b"log" b"message"
def test_request(self):
assert Packet.request("command_type") == self.cmd_request
assert Packet.request("command", b"payload") == self.cmd_request_msg
def test_register_event(self):
assert Packet.register_event("event_type") == self.event_register
def test_unregister_event(self):
assert Packet.unregister_event("event_type") == self.event_unregister
def test_parse(self):
parsed_cmd_response = Packet.parse(self.cmd_response)
assert parsed_cmd_response.response_type == Packet.CMD_RESPONSE
assert parsed_cmd_response.payload.getvalue() == self.cmd_response
parsed_cmd_unknown = Packet.parse(self.cmd_unknown)
assert parsed_cmd_unknown.response_type == Packet.CMD_UNKNOWN
assert parsed_cmd_unknown.payload.getvalue() == self.cmd_unknown
parsed_event_confirm = Packet.parse(self.event_confirm)
assert parsed_event_confirm.response_type == Packet.EVENT_CONFIRM
assert parsed_event_confirm.payload.getvalue() == self.event_confirm
parsed_event_unknown = Packet.parse(self.event_unknown)
assert parsed_event_unknown.response_type == Packet.EVENT_UNKNOWN
assert parsed_event_unknown.payload.getvalue() == self.event_unknown
parsed_event = Packet.parse(self.event)
assert parsed_event.response_type == Packet.EVENT
assert parsed_event.payload.getvalue() == self.event
class TestMessage(object):
"""Message (de)serialization test."""
# data definitions for test of de(serialization)
# serialized messages holding a section
ser_sec_unclosed = b"\x01\x08unclosed"
ser_sec_single = b"\x01\x07section\x02"
ser_sec_nested = b"\x01\x05outer\x01\x0asubsection\x02\x02"
# serialized messages holding a list
ser_list_invalid = b"\x04\x07invalid\x05\x00\x02e1\x02\x03sec\x06"
ser_list_0_item = b"\x04\x05empty\x06"
ser_list_1_item = b"\x04\x01l\x05\x00\x02e1\x06"
ser_list_2_item = b"\x04\x01l\x05\x00\x02e1\x05\x00\x02e2\x06"
# serialized messages with key value pairs
ser_kv_pair = b"\x03\x03key\x00\x05value"
ser_kv_zero = b"\x03\x0azerolength\x00\x00"
# deserialized messages holding a section
des_sec_single = { "section": {} }
des_sec_nested = { "outer": { "subsection": {} } }
# deserialized messages holding a list
des_list_0_item = { "empty": [] }
des_list_1_item = { "l": [ b"e1" ] }
des_list_2_item = { "l": [ b"e1", b"e2" ] }
# deserialized messages with key value pairs
des_kv_pair = { "key": b"value" }
des_kv_zero = { "zerolength": b"" }
def test_section_serialization(self):
assert Message.serialize(self.des_sec_single) == self.ser_sec_single
assert Message.serialize(self.des_sec_nested) == self.ser_sec_nested
def test_list_serialization(self):
assert Message.serialize(self.des_list_0_item) == self.ser_list_0_item
assert Message.serialize(self.des_list_1_item) == self.ser_list_1_item
assert Message.serialize(self.des_list_2_item) == self.ser_list_2_item
def test_key_serialization(self):
assert Message.serialize(self.des_kv_pair) == self.ser_kv_pair
assert Message.serialize(self.des_kv_zero) == self.ser_kv_zero
def test_section_deserialization(self):
single = Message.deserialize(FiniteStream(self.ser_sec_single))
nested = Message.deserialize(FiniteStream(self.ser_sec_nested))
assert single == self.des_sec_single
assert nested == self.des_sec_nested
with pytest.raises(DeserializationException):
Message.deserialize(FiniteStream(self.ser_sec_unclosed))
def test_list_deserialization(self):
l0 = Message.deserialize(FiniteStream(self.ser_list_0_item))
l1 = Message.deserialize(FiniteStream(self.ser_list_1_item))
l2 = Message.deserialize(FiniteStream(self.ser_list_2_item))
assert l0 == self.des_list_0_item
assert l1 == self.des_list_1_item
assert l2 == self.des_list_2_item
with pytest.raises(DeserializationException):
Message.deserialize(FiniteStream(self.ser_list_invalid))
def test_key_deserialization(self):
pair = Message.deserialize(FiniteStream(self.ser_kv_pair))
zerolength = Message.deserialize(FiniteStream(self.ser_kv_zero))
assert pair == self.des_kv_pair
assert zerolength == self.des_kv_zero
def test_roundtrip(self):
message = {
"key1": "value1",
"section1": {
"sub-section": {
"key2": b"value2",
},
"list1": [ "item1", "item2" ],
},
}
serialized_message = FiniteStream(Message.serialize(message))
deserialized_message = Message.deserialize(serialized_message)
# ensure that list items and key values remain as undecoded bytes
deserialized_section = deserialized_message["section1"]
assert deserialized_message["key1"] == b"value1"
assert deserialized_section["sub-section"]["key2"] == b"value2"
assert deserialized_section["list1"] == [ b"item1", b"item2" ]

View File

@ -0,0 +1,5 @@
[list_status]
command:/usr/local/opnsense/scripts/ipsec/list_status.py
parameters:
type:script_output
message:IPsec list status