From 76b8ae44908b861e41e886744f6f7cbda2ab91e4 Mon Sep 17 00:00:00 2001 From: Ad Schellevis Date: Thu, 19 Aug 2021 15:45:40 +0200 Subject: [PATCH] Firewall / Aliases - improve resolve performance by implementing async dns lookups. ref https://github.com/opnsense/core/issues/5117 This will need a new version of py-dnspython (py-dnspython2 in ports) for dns.asyncresolver support. Some additional log messages have been added to gain more insights into the resolving process via the general log. Intermediate results aren't saved to disk anymore, which also simplifies the resolve() function in the Alias class. An address parser can queue hostname lookups for later retrieval (see _parse_address()) so we can batch process the list of hostnames to be collected. --- Makefile | 2 +- src/opnsense/scripts/filter/lib/__init__.py | 83 +++++++++++++++++++++ src/opnsense/scripts/filter/lib/alias.py | 39 +++------- 3 files changed, 96 insertions(+), 28 deletions(-) diff --git a/Makefile b/Makefile index 56c4ad564..0f73c2498 100644 --- a/Makefile +++ b/Makefile @@ -172,7 +172,7 @@ CORE_DEPENDS?= ca_root_nss \ php${CORE_PHP}-zlib \ pkg \ py${CORE_PYTHON}-Jinja2 \ - py${CORE_PYTHON}-dnspython \ + py${CORE_PYTHON}-dnspython2 \ py${CORE_PYTHON}-netaddr \ py${CORE_PYTHON}-requests \ py${CORE_PYTHON}-sqlite3 \ diff --git a/src/opnsense/scripts/filter/lib/__init__.py b/src/opnsense/scripts/filter/lib/__init__.py index 85e9d854b..70814fa1b 100755 --- a/src/opnsense/scripts/filter/lib/__init__.py +++ b/src/opnsense/scripts/filter/lib/__init__.py @@ -23,8 +23,14 @@ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ +import asyncio +import dns.resolver import ipaddress import itertools +import syslog +import time +from dns.rdatatype import RdataType +from dns.asyncresolver import Resolver def net_wildcard_iterator(network: str): @@ -62,3 +68,80 @@ def net_wildcard_iterator(network: str): yield ipaddress.IPv6Network((this_ip, wildcard.max_prefixlen - mask_length), strict=False) else: yield ipaddress.IPv4Network((this_ip, wildcard.max_prefixlen - mask_length), strict=False) + + +class AsyncDNSResolver: + """ Asynchronous DNS resolver, collect addresses for hostnames collected in request queue. + simple example usecase collecting addresses associated with two domains: + + asyncresolver = AsyncDNSResolver() + asyncresolver.add('www.example.com') + asyncresolver.add('mail.example.com') + asyncresolver.collect() + print(asyncresolver.addresses()) + """ + batch_size = 100 + report_size = 10000 + + def __init__(self, origin=""): + self._request_queue = list() + self._requested = set() + self._response = set() + self._origin = origin + self._domains_queued = 0 + + def add(self, hostname): + self._request_queue.append(hostname) + + async def request_ittr(self, loop): + dnsResolver = Resolver() + dnsResolver.timeout = 2 + collected_errors = set() + while len(self._request_queue) > 0: + tasks = [] + while len(tasks) < self.batch_size and len(self._request_queue) > 0: + hostname = self._request_queue.pop() + if hostname not in self._requested: + self._domains_queued += 1 + # make sure we only request a host once + for record_type in ['A', 'AAAA']: + tasks.append(dnsResolver.resolve(hostname, record_type)) + self._requested.add(hostname) + if len(tasks) > 0: + responses = await asyncio.gather(*tasks, return_exceptions=True) + for response in responses: + if type(response) is dns.resolver.Answer: + for item in response.response.answer: + if type(item) is dns.rrset.RRset: + for addr in item.items: + if addr.rdtype is RdataType.CNAME: + # query cname (recursion) + self._request_queue.append(addr.target) + else: + self._response.add(addr.address) + elif type(response) in [ + dns.resolver.NoAnswer, + dns.resolver.NXDOMAIN, + dns.exception.Timeout, + dns.resolver.NoNameservers, + dns.name.EmptyLabel]: + if str(response) not in collected_errors: + syslog.syslog(syslog.LOG_ERR, '%s [for %s]' % (response, self._origin)) + collected_errors.add(str(response)) + if self._domains_queued % self.report_size == 0: + syslog.syslog(syslog.LOG_NOTICE, 'requested %d hostnames for %s' % (self._domains_queued, self._origin)) + + def collect(self): + if len(self._request_queue) > 0: + start_time = time.time() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + asyncio.run(self.request_ittr(loop)) + loop.close() + syslog.syslog(syslog.LOG_NOTICE, 'resolving %d hostnames (%d addresses) for %s took %.2f seconds' % ( + self._domains_queued, len(self._response), self._origin, time.time() - start_time + )) + return self + + def addresses(self): + return self._response diff --git a/src/opnsense/scripts/filter/lib/alias.py b/src/opnsense/scripts/filter/lib/alias.py index eb2e4d8cb..70ffa6a56 100755 --- a/src/opnsense/scripts/filter/lib/alias.py +++ b/src/opnsense/scripts/filter/lib/alias.py @@ -35,7 +35,7 @@ import dns.resolver import syslog from hashlib import md5 from . import geoip -from . import net_wildcard_iterator +from . import net_wildcard_iterator, AsyncDNSResolver from .arpcache import ArpCache class Alias(object): @@ -49,8 +49,6 @@ class Alias(object): :return: None """ self._known_aliases = known_aliases - self._dnsResolver = dns.resolver.Resolver() - self._dnsResolver.timeout = 2 self._is_changed = None self._has_expired = None self._ttl = ttl @@ -85,6 +83,7 @@ class Alias(object): self._filename_alias_hash = '/var/db/aliastables/%s.md5.txt' % self._name # the generated alias contents, without dependencies self._filename_alias_content = '/var/db/aliastables/%s.self.txt' % self._name + self._dnsResolver = AsyncDNSResolver(self._name) def _parse_address(self, address): """ parse addresses and hostnames, yield only valid addresses and networks @@ -125,19 +124,8 @@ class Alias(object): except (ipaddress.AddressValueError, ValueError): pass - # try to resolve provided address - could_resolve = False - for record_type in ['A', 'AAAA']: - try: - for rdata in self._dnsResolver.query(address, record_type): - yield str(rdata) - could_resolve = True - except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN, dns.exception.Timeout, dns.resolver.NoNameservers, dns.name.EmptyLabel): - pass - - if not could_resolve: - # log when none could be found - syslog.syslog(syslog.LOG_ERR, 'unable to resolve %s for alias %s' % (address, self._name)) + # try to resolve provided address (queue for retrieval) + self._dnsResolver.add(address) def _fetch_url(self, url): """ return unparsed (raw) alias entries without dependencies @@ -244,18 +232,15 @@ class Alias(object): else: undo_content = "" try: + address_parser = self.get_parser() + for item in self.items(): + if address_parser: + for address in address_parser(item): + self._resolve_content.add(address) + # resolve hostnames (async) if there are any in the collected set + self._resolve_content = self._resolve_content.union(self._dnsResolver.collect().addresses()) with open(self._filename_alias_content, 'w') as f_out: - for item in self.items(): - address_parser = self.get_parser() - if address_parser: - for address in address_parser(item): - if address not in self._resolve_content: - # flush new alias content (without dependencies) to disk, so progress can easliy - # be followed, large lists of domain names can take quite some resolve time. - f_out.write('%s\n' % address) - f_out.flush() - # preserve addresses - self._resolve_content.add(address) + f_out.write('\n'.join(self._resolve_content)) except IOError: # parse issue, keep data as-is, flush previous content to disk with open(self._filename_alias_content, 'w') as f_out: