diff --git a/Makefile b/Makefile index 56c4ad564..0f73c2498 100644 --- a/Makefile +++ b/Makefile @@ -172,7 +172,7 @@ CORE_DEPENDS?= ca_root_nss \ php${CORE_PHP}-zlib \ pkg \ py${CORE_PYTHON}-Jinja2 \ - py${CORE_PYTHON}-dnspython \ + py${CORE_PYTHON}-dnspython2 \ py${CORE_PYTHON}-netaddr \ py${CORE_PYTHON}-requests \ py${CORE_PYTHON}-sqlite3 \ diff --git a/src/opnsense/scripts/filter/lib/__init__.py b/src/opnsense/scripts/filter/lib/__init__.py index 85e9d854b..70814fa1b 100755 --- a/src/opnsense/scripts/filter/lib/__init__.py +++ b/src/opnsense/scripts/filter/lib/__init__.py @@ -23,8 +23,14 @@ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ +import asyncio +import dns.resolver import ipaddress import itertools +import syslog +import time +from dns.rdatatype import RdataType +from dns.asyncresolver import Resolver def net_wildcard_iterator(network: str): @@ -62,3 +68,80 @@ def net_wildcard_iterator(network: str): yield ipaddress.IPv6Network((this_ip, wildcard.max_prefixlen - mask_length), strict=False) else: yield ipaddress.IPv4Network((this_ip, wildcard.max_prefixlen - mask_length), strict=False) + + +class AsyncDNSResolver: + """ Asynchronous DNS resolver, collect addresses for hostnames collected in request queue. + simple example usecase collecting addresses associated with two domains: + + asyncresolver = AsyncDNSResolver() + asyncresolver.add('www.example.com') + asyncresolver.add('mail.example.com') + asyncresolver.collect() + print(asyncresolver.addresses()) + """ + batch_size = 100 + report_size = 10000 + + def __init__(self, origin=""): + self._request_queue = list() + self._requested = set() + self._response = set() + self._origin = origin + self._domains_queued = 0 + + def add(self, hostname): + self._request_queue.append(hostname) + + async def request_ittr(self, loop): + dnsResolver = Resolver() + dnsResolver.timeout = 2 + collected_errors = set() + while len(self._request_queue) > 0: + tasks = [] + while len(tasks) < self.batch_size and len(self._request_queue) > 0: + hostname = self._request_queue.pop() + if hostname not in self._requested: + self._domains_queued += 1 + # make sure we only request a host once + for record_type in ['A', 'AAAA']: + tasks.append(dnsResolver.resolve(hostname, record_type)) + self._requested.add(hostname) + if len(tasks) > 0: + responses = await asyncio.gather(*tasks, return_exceptions=True) + for response in responses: + if type(response) is dns.resolver.Answer: + for item in response.response.answer: + if type(item) is dns.rrset.RRset: + for addr in item.items: + if addr.rdtype is RdataType.CNAME: + # query cname (recursion) + self._request_queue.append(addr.target) + else: + self._response.add(addr.address) + elif type(response) in [ + dns.resolver.NoAnswer, + dns.resolver.NXDOMAIN, + dns.exception.Timeout, + dns.resolver.NoNameservers, + dns.name.EmptyLabel]: + if str(response) not in collected_errors: + syslog.syslog(syslog.LOG_ERR, '%s [for %s]' % (response, self._origin)) + collected_errors.add(str(response)) + if self._domains_queued % self.report_size == 0: + syslog.syslog(syslog.LOG_NOTICE, 'requested %d hostnames for %s' % (self._domains_queued, self._origin)) + + def collect(self): + if len(self._request_queue) > 0: + start_time = time.time() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + asyncio.run(self.request_ittr(loop)) + loop.close() + syslog.syslog(syslog.LOG_NOTICE, 'resolving %d hostnames (%d addresses) for %s took %.2f seconds' % ( + self._domains_queued, len(self._response), self._origin, time.time() - start_time + )) + return self + + def addresses(self): + return self._response diff --git a/src/opnsense/scripts/filter/lib/alias.py b/src/opnsense/scripts/filter/lib/alias.py index eb2e4d8cb..70ffa6a56 100755 --- a/src/opnsense/scripts/filter/lib/alias.py +++ b/src/opnsense/scripts/filter/lib/alias.py @@ -35,7 +35,7 @@ import dns.resolver import syslog from hashlib import md5 from . import geoip -from . import net_wildcard_iterator +from . import net_wildcard_iterator, AsyncDNSResolver from .arpcache import ArpCache class Alias(object): @@ -49,8 +49,6 @@ class Alias(object): :return: None """ self._known_aliases = known_aliases - self._dnsResolver = dns.resolver.Resolver() - self._dnsResolver.timeout = 2 self._is_changed = None self._has_expired = None self._ttl = ttl @@ -85,6 +83,7 @@ class Alias(object): self._filename_alias_hash = '/var/db/aliastables/%s.md5.txt' % self._name # the generated alias contents, without dependencies self._filename_alias_content = '/var/db/aliastables/%s.self.txt' % self._name + self._dnsResolver = AsyncDNSResolver(self._name) def _parse_address(self, address): """ parse addresses and hostnames, yield only valid addresses and networks @@ -125,19 +124,8 @@ class Alias(object): except (ipaddress.AddressValueError, ValueError): pass - # try to resolve provided address - could_resolve = False - for record_type in ['A', 'AAAA']: - try: - for rdata in self._dnsResolver.query(address, record_type): - yield str(rdata) - could_resolve = True - except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN, dns.exception.Timeout, dns.resolver.NoNameservers, dns.name.EmptyLabel): - pass - - if not could_resolve: - # log when none could be found - syslog.syslog(syslog.LOG_ERR, 'unable to resolve %s for alias %s' % (address, self._name)) + # try to resolve provided address (queue for retrieval) + self._dnsResolver.add(address) def _fetch_url(self, url): """ return unparsed (raw) alias entries without dependencies @@ -244,18 +232,15 @@ class Alias(object): else: undo_content = "" try: + address_parser = self.get_parser() + for item in self.items(): + if address_parser: + for address in address_parser(item): + self._resolve_content.add(address) + # resolve hostnames (async) if there are any in the collected set + self._resolve_content = self._resolve_content.union(self._dnsResolver.collect().addresses()) with open(self._filename_alias_content, 'w') as f_out: - for item in self.items(): - address_parser = self.get_parser() - if address_parser: - for address in address_parser(item): - if address not in self._resolve_content: - # flush new alias content (without dependencies) to disk, so progress can easliy - # be followed, large lists of domain names can take quite some resolve time. - f_out.write('%s\n' % address) - f_out.flush() - # preserve addresses - self._resolve_content.add(address) + f_out.write('\n'.join(self._resolve_content)) except IOError: # parse issue, keep data as-is, flush previous content to disk with open(self._filename_alias_content, 'w') as f_out: