# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
#
# SPDX-License-Identifier: MPL-2.0
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0.  If a copy of the MPL was not distributed with this
# file, you can obtain one at https://mozilla.org/MPL/2.0/.
#
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.

from collections.abc import Callable
from typing import Any

import os
import time

import dns.exception
import dns.flags
import dns.message
import dns.name
import dns.query
import dns.rcode
import dns.rdataclass
import dns.rdatatype

import isctest.log
import isctest.run

QUERY_TIMEOUT = 10


def generic_query(
    query_func: Callable[..., Any],
    message: dns.message.Message,
    ip: str,
    port: int | None = None,
    source: str | None = None,
    timeout: int = QUERY_TIMEOUT,
    attempts: int = 10,
    expected_rcode: dns.rcode.Rcode | None = None,
    verify: bool = False,
    log_query: bool = True,
    log_response: bool = True,
) -> Any:
    """
    Send 'message' to 'ip' using 'query_func', retrying on failure.

    'query_func' is the callable that performs the query; it is called with
    the keyword arguments q, where, timeout, port and source (plus verify
    when its __name__ is "tls").
    'message' is the DNS message to send.
    'ip' is the address of the server to query.
    'port' is the destination port; when None it is read from the TLSPORT
    environment variable if query_func is named "tls", otherwise from PORT.
    'source' is the source address handed to query_func.
    'timeout' is the per-attempt timeout handed to query_func.
    'attempts' is the maximum number of times query_func is called.
    'expected_rcode', when not None, makes responses with any other rcode
    count as a failed attempt.
    'verify' is passed to query_func only when it is named "tls".
    'log_query' / 'log_response' control debug logging of the query (logged
    at most once) and of each response.

    An attempt fails when query_func raises dns.exception.Timeout or
    ConnectionRefusedError, or when the response rcode is not the expected
    one. Any other exception propagates to the caller. There is a one
    second pause between attempts.

    Returns the first acceptable response. Raises dns.exception.Timeout
    when all attempts have failed.
    """
    func_name = query_func.__name__

    def log_querymsg(exception: Exception | None = None) -> None:
        """
        Helper for logging query message.

        Call this *after* query_func() has been called, as it may modify
        the message, e.g. with a TSIG.

        If an exception is provided, it will be logged as well.
        """
        nonlocal log_query
        if log_query:
            isctest.log.debug(
                f"isc.query.{func_name}(): query\n{message.to_text()}"
            )
            log_query = False  # only log query once
        if exception:
            isctest.log.debug(
                f"isc.query.{func_name}(): the '{exception}' exception raised"
            )

    if port is None:
        if func_name == "tls":
            port = int(os.environ["TLSPORT"])
        else:
            port = int(os.environ["PORT"])

    query_args = {
        "q": message,
        "where": ip,
        "timeout": timeout,
        "port": port,
        "source": source,
    }
    if func_name == "tls":
        query_args["verify"] = verify

    # Most recent response received on any attempt; only used to report the
    # last seen rcode once all attempts are exhausted.
    last_res = None

    for attempt in range(attempts):
        isctest.log.debug(
            f"isc.query.{func_name}(): ip={ip}, port={port}, source={source}, "
            f"timeout={timeout}, attempts left={attempts-attempt}"
        )

        # Start every attempt from a clean slate so that a response from an
        # earlier attempt is not logged and evaluated again when this
        # attempt fails with an exception.
        res = None
        exc = None
        try:
            res = query_func(**query_args)
        except (dns.exception.Timeout, ConnectionRefusedError) as e:
            exc = e
        finally:
            # Runs even when an unexpected exception propagates, so the
            # query is still logged in that case.
            log_querymsg(exc)

        if res is not None:
            last_res = res
            if log_response:
                isctest.log.debug(
                    f"isc.query.{func_name}(): response\n{res.to_text()}"
                )
            if expected_rcode is None or res.rcode() == expected_rcode:
                return res

        # Pause before retrying; there is nothing to wait for after the
        # final attempt.
        if attempt < attempts - 1:
            time.sleep(1)

    if expected_rcode is not None:
        last_rcode = (
            dns.rcode.to_text(last_res.rcode()) if last_res is not None else None
        )
        isctest.log.debug(
            f"isc.query.{func_name}(): "
            f"expected rcode={dns.rcode.to_text(expected_rcode)}, "
            f"last rcode={last_rcode}"
        )
    raise dns.exception.Timeout


def udp(*args, **kwargs) -> Any:
    """Run generic_query() with dns.query.udp as the query function."""
    return generic_query(dns.query.udp, *args, **kwargs)


def tcp(*args, **kwargs) -> Any:
    """Run generic_query() with dns.query.tcp as the query function."""
    return generic_query(dns.query.tcp, *args, **kwargs)


def tls(*args, **kwargs) -> Any:
    """
    Run generic_query() with dns.query.tls as the query function.

    A TypeError raised during the call is re-raised as RuntimeError stating
    that dnspython 2.5.0 or newer is required.
    """
    # NOTE(review): presumably older dnspython rejects the "verify" keyword
    # with a TypeError -- confirm. Any other TypeError is reported the same
    # way.
    try:
        return generic_query(dns.query.tls, *args, **kwargs)
    except TypeError as e:
        raise RuntimeError(
            "dnspython 2.5.0 or newer is required for isctest.query.tls()"
        ) from e


def create(
    qname,
    qtype,
    qclass=dns.rdataclass.IN,
    dnssec: bool = True,
    rd: bool = True,
    cd: bool = False,
    ad: bool = True,
) -> dns.message.Message:
    """Create DNS query with defaults suitable for our tests."""
    msg = dns.message.make_query(
        qname, qtype, qclass, use_edns=True, want_dnssec=dnssec
    )
    # Discard whatever header flags make_query() set and build them up from
    # the rd/ad/cd arguments only.
    msg.flags = 0
    if rd:
        msg.flags = dns.flags.RD
    if ad:
        msg.flags |= dns.flags.AD
    if cd:
        msg.flags |= dns.flags.CD
    return msg


def wait_for_serial(server_ip, zone, expected_serial, timeout=30):
    """Wait until the server has the expected SOA serial for the zone.

    Queries the server repeatedly until the SOA serial matches or the
    timeout expires.

    'server_ip' is the IP address to query (string).
    'zone' is the zone name (string, with or without trailing dot).
    'expected_serial' is the expected SOA serial number (int).
    'timeout' is the maximum time to wait in seconds (default 30).
    """
    query = create(zone, "SOA", dnssec=False)

    def check():
        # True only when the answer section holds exactly one SOA record
        # for the zone and it carries the expected serial.
        res = tcp(query, server_ip)
        soa = res.get_rrset(
            res.answer,
            dns.name.from_text(zone),
            dns.rdataclass.IN,
            dns.rdatatype.SOA,
        )
        return soa is not None and len(soa) == 1 and soa[0].serial == expected_serial

    isctest.run.retry_with_timeout(
        check,
        timeout=timeout,
        msg=f"timed out waiting for serial {expected_serial} at {server_ip} for {zone}",
    )