Firewall / Aliases - improve resolve performance by implementing async dns lookups. ref https://github.com/opnsense/core/issues/5117

This needs a new version of py-dnspython (py-dnspython2 in ports) for dns.asyncresolver support. Some additional log messages have been added to gain more insight into the resolving process via the general log.

Intermediate results aren't saved to disk anymore, which also simplifies the resolve() function in the Alias class. The address parser can queue hostname lookups for later retrieval (see _parse_address()), so we can batch-process the list of hostnames to be collected.
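
For context, the core dnspython 2.x pattern the new resolver builds on looks roughly like the sketch below. This is illustrative only: resolve_batch and the example hostnames are not part of the commit; dns.asyncresolver and asyncio.gather are what the actual change uses.

    import asyncio
    import dns.resolver
    import dns.asyncresolver

    async def resolve_batch(hostnames):
        resolver = dns.asyncresolver.Resolver()
        resolver.timeout = 2
        # fire off all queries concurrently; return_exceptions=True turns
        # NXDOMAIN, timeouts etc. into results instead of aborting the batch
        tasks = [resolver.resolve(hostname, 'A') for hostname in hostnames]
        answers = await asyncio.gather(*tasks, return_exceptions=True)
        addresses = set()
        for answer in answers:
            if isinstance(answer, dns.resolver.Answer):
                addresses.update(str(rdata) for rdata in answer)
        return addresses

    print(asyncio.run(resolve_batch(['www.example.com', 'mail.example.com'])))

Because all queries in a batch are awaited together, total wall-clock time approaches that of the slowest lookup rather than the sum of all lookups, which is where the performance gain for large aliases comes from.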
Ad Schellevis 2021-08-19 15:45:40 +02:00
parent 2872298658
commit 76b8ae4490
3 changed files with 96 additions and 28 deletions


@@ -172,7 +172,7 @@ CORE_DEPENDS?= ca_root_nss \
 	php${CORE_PHP}-zlib \
 	pkg \
 	py${CORE_PYTHON}-Jinja2 \
-	py${CORE_PYTHON}-dnspython \
+	py${CORE_PYTHON}-dnspython2 \
 	py${CORE_PYTHON}-netaddr \
 	py${CORE_PYTHON}-requests \
 	py${CORE_PYTHON}-sqlite3 \


@@ -23,8 +23,14 @@
     ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
     POSSIBILITY OF SUCH DAMAGE.
 """
+import asyncio
+import dns.resolver
 import ipaddress
 import itertools
+import syslog
+import time
+from dns.rdatatype import RdataType
+from dns.asyncresolver import Resolver


 def net_wildcard_iterator(network: str):
@@ -62,3 +68,80 @@ def net_wildcard_iterator(network: str):
             yield ipaddress.IPv6Network((this_ip, wildcard.max_prefixlen - mask_length), strict=False)
         else:
             yield ipaddress.IPv4Network((this_ip, wildcard.max_prefixlen - mask_length), strict=False)
+
+
+class AsyncDNSResolver:
+    """ Asynchronous DNS resolver, collect addresses for hostnames collected in request queue.
+        simple example usecase collecting addresses associated with two domains:
+
+        asyncresolver = AsyncDNSResolver()
+        asyncresolver.add('www.example.com')
+        asyncresolver.add('mail.example.com')
+        asyncresolver.collect()
+        print(asyncresolver.addresses())
+    """
+    batch_size = 100
+    report_size = 10000
+
+    def __init__(self, origin="<unknown>"):
+        self._request_queue = list()
+        self._requested = set()
+        self._response = set()
+        self._origin = origin
+        self._domains_queued = 0
+
+    def add(self, hostname):
+        self._request_queue.append(hostname)
+
+    async def request_ittr(self, loop):
+        dnsResolver = Resolver()
+        dnsResolver.timeout = 2
+        collected_errors = set()
+        while len(self._request_queue) > 0:
+            tasks = []
+            while len(tasks) < self.batch_size and len(self._request_queue) > 0:
+                hostname = self._request_queue.pop()
+                if hostname not in self._requested:
+                    self._domains_queued += 1
+                    # make sure we only request a host once
+                    for record_type in ['A', 'AAAA']:
+                        tasks.append(dnsResolver.resolve(hostname, record_type))
+                    self._requested.add(hostname)
+            if len(tasks) > 0:
+                responses = await asyncio.gather(*tasks, return_exceptions=True)
+                for response in responses:
+                    if type(response) is dns.resolver.Answer:
+                        for item in response.response.answer:
+                            if type(item) is dns.rrset.RRset:
+                                for addr in item.items:
+                                    if addr.rdtype is RdataType.CNAME:
+                                        # query cname (recursion)
+                                        self._request_queue.append(addr.target)
+                                    else:
+                                        self._response.add(addr.address)
+                    elif type(response) in [
+                            dns.resolver.NoAnswer,
+                            dns.resolver.NXDOMAIN,
+                            dns.exception.Timeout,
+                            dns.resolver.NoNameservers,
+                            dns.name.EmptyLabel]:
+                        if str(response) not in collected_errors:
+                            syslog.syslog(syslog.LOG_ERR, '%s [for %s]' % (response, self._origin))
+                            collected_errors.add(str(response))
+            if self._domains_queued % self.report_size == 0:
+                syslog.syslog(syslog.LOG_NOTICE, 'requested %d hostnames for %s' % (self._domains_queued, self._origin))
+
+    def collect(self):
+        if len(self._request_queue) > 0:
+            start_time = time.time()
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            asyncio.run(self.request_ittr(loop))
+            loop.close()
+            syslog.syslog(syslog.LOG_NOTICE, 'resolving %d hostnames (%d addresses) for %s took %.2f seconds' % (
+                self._domains_queued, len(self._response), self._origin, time.time() - start_time
+            ))
+        return self
+
+    def addresses(self):
+        return self._response
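
The intended call pattern is queue-then-collect, as the docstring above shows. An illustrative driver (the hostnames and origin are made up):

    resolver = AsyncDNSResolver(origin='myAlias')
    for hostname in ['www.example.com', 'mail.example.com']:
        resolver.add(hostname)
    # collect() runs the batched lookups and returns self, so the resolved
    # addresses can be merged into an existing set in one expression
    addresses = set().union(resolver.collect().addresses())

Note that asyncio.gather(..., return_exceptions=True) together with the _requested set means a failing or duplicate hostname never aborts a batch, and CNAME targets are simply re-queued for the next pass of the outer while loop.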


@@ -35,7 +35,7 @@ import dns.resolver
 import syslog
 from hashlib import md5
 from . import geoip
-from . import net_wildcard_iterator
+from . import net_wildcard_iterator, AsyncDNSResolver
 from .arpcache import ArpCache

@@ -49,8 +49,6 @@ class Alias(object):
         :return: None
         """
         self._known_aliases = known_aliases
-        self._dnsResolver = dns.resolver.Resolver()
-        self._dnsResolver.timeout = 2
         self._is_changed = None
         self._has_expired = None
         self._ttl = ttl
@@ -85,6 +83,7 @@ class Alias(object):
         self._filename_alias_hash = '/var/db/aliastables/%s.md5.txt' % self._name
         # the generated alias contents, without dependencies
         self._filename_alias_content = '/var/db/aliastables/%s.self.txt' % self._name
+        self._dnsResolver = AsyncDNSResolver(self._name)

     def _parse_address(self, address):
         """ parse addresses and hostnames, yield only valid addresses and networks
@@ -125,19 +124,8 @@ class Alias(object):
             except (ipaddress.AddressValueError, ValueError):
                 pass
-            # try to resolve provided address
-            could_resolve = False
-            for record_type in ['A', 'AAAA']:
-                try:
-                    for rdata in self._dnsResolver.query(address, record_type):
-                        yield str(rdata)
-                        could_resolve = True
-                except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN, dns.exception.Timeout, dns.resolver.NoNameservers, dns.name.EmptyLabel):
-                    pass
-            if not could_resolve:
-                # log when none could be found
-                syslog.syslog(syslog.LOG_ERR, 'unable to resolve %s for alias %s' % (address, self._name))
+            # try to resolve provided address (queue for retrieval)
+            self._dnsResolver.add(address)

     def _fetch_url(self, url):
         """ return unparsed (raw) alias entries without dependencies
@@ -244,18 +232,15 @@ class Alias(object):
             else:
                 undo_content = ""
             try:
+                address_parser = self.get_parser()
+                for item in self.items():
+                    if address_parser:
+                        for address in address_parser(item):
+                            self._resolve_content.add(address)
+                # resolve hostnames (async) if there are any in the collected set
+                self._resolve_content = self._resolve_content.union(self._dnsResolver.collect().addresses())
                 with open(self._filename_alias_content, 'w') as f_out:
-                    for item in self.items():
-                        address_parser = self.get_parser()
-                        if address_parser:
-                            for address in address_parser(item):
-                                if address not in self._resolve_content:
-                                    # flush new alias content (without dependencies) to disk, so progress can easily
-                                    # be followed, large lists of domain names can take quite some resolve time.
-                                    f_out.write('%s\n' % address)
-                                    f_out.flush()
-                                # preserve addresses
-                                self._resolve_content.add(address)
+                    f_out.write('\n'.join(self._resolve_content))
             except IOError:
                 # parse issue, keep data as-is, flush previous content to disk
                 with open(self._filename_alias_content, 'w') as f_out: