%PDF- %PDF-
Mini Shell

Mini Shell

Direktori : /lib/python3/dist-packages/twisted/internet/test/
Upload File :
Create Path :
Current File : //lib/python3/dist-packages/twisted/internet/test/test_unix.py

# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Tests for implementations of L{IReactorUNIX}.
"""


from hashlib import md5
from os import close, fstat, stat, unlink, urandom
from pprint import pformat
from socket import AF_INET, SOCK_STREAM, SOL_SOCKET, socket
from stat import S_IMODE
from struct import pack
from tempfile import mkstemp, mktemp
from typing import Optional, Sequence, Type
from unittest import skipIf

try:
    from socket import AF_UNIX as _AF_UNIX
except ImportError:
    AF_UNIX = None
else:
    AF_UNIX = _AF_UNIX

from zope.interface import Interface, implementer

from twisted.internet import base, interfaces
from twisted.internet.address import UNIXAddress
from twisted.internet.defer import Deferred, fail, gatherResults
from twisted.internet.endpoints import UNIXClientEndpoint, UNIXServerEndpoint
from twisted.internet.error import (
    CannotListenError,
    ConnectionClosed,
    FileDescriptorOverrun,
)
from twisted.internet.interfaces import (
    IFileDescriptorReceiver,
    IReactorFDSet,
    IReactorSocket,
    IReactorUNIX,
)
from twisted.internet.protocol import ClientFactory, DatagramProtocol, ServerFactory
from twisted.internet.task import LoopingCall
from twisted.internet.test.connectionmixins import (
    ConnectableProtocol,
    ConnectionTestsMixin,
    EndpointCreator,
    StreamClientTestsMixin,
    runProtocolsWithReactor,
)
from twisted.internet.test.reactormixins import ReactorBuilder
from twisted.internet.test.test_tcp import (
    MyClientFactory,
    MyServerFactory,
    StreamTransportTestsMixin,
    WriteSequenceTestsMixin,
)
from twisted.python.compat import nativeString
from twisted.python.failure import Failure
from twisted.python.filepath import _coerceToFilesystemEncoding
from twisted.python.log import addObserver, err, removeObserver
from twisted.python.reflect import requireModule
from twisted.python.runtime import platform

# The optional sendmsg extension; None when it cannot be imported.
sendmsg = requireModule("twisted.python.sendmsg")
# Reason reported by the @skipIf(not sendmsg, ...) decorators below.  It is
# only meaningful (and only set) when the extension is unavailable.
sendmsgSkipReason = ""
if sendmsg is None:
    sendmsgSkipReason = (
        "sendmsg extension unavailable, " "extended UNIX features disabled"
    )


class UNIXFamilyMixin:
    """
    Shared helpers for tests exercising C{AF_UNIX} sockets.
    """

    def _modeTest(self, methodName, path, factory):
        """
        Listen on C{path} through the reactor method named C{methodName},
        passing an explicit C{mode}, and check that the resulting socket file
        carries exactly that permission mode.

        @param methodName: Name of the reactor method to call, e.g.
            C{"listenUNIX"}.
        @param path: Filesystem path to listen on.
        @param factory: Factory (or protocol) handed to the reactor method.
        """
        expectedMode = 0o600
        listen = getattr(self.buildReactor(), methodName)
        listen(path, factory, mode=expectedMode).stopListening()
        actualMode = S_IMODE(stat(path).st_mode)
        self.assertEqual(actualMode, expectedMode)


def _abstractPath(case):
    """
    Return a new, unique abstract namespace path to be listened on.
    """
    return md5(urandom(100)).hexdigest()


class UNIXCreator(EndpointCreator):
    """
    Endpoint creator producing UNIX-socket server and client endpoints.
    """

    requiredInterfaces: Optional[Sequence[Type[Interface]]] = (interfaces.IReactorUNIX,)

    def server(self, reactor):
        """
        Build a L{UNIXServerEndpoint} bound to a fresh socket path in the
        current working directory.
        """
        # A short relative path is used on purpose: self.mktemp() frequently
        # produces one longer than the platform's UNIX socket path limit.
        return UNIXServerEndpoint(reactor, mktemp(suffix=".sock", dir="."))

    def client(self, reactor, serverAddress):
        """
        Build a L{UNIXClientEndpoint} targeting the path named by
        C{serverAddress}.
        """
        socketPath = serverAddress.name
        return UNIXClientEndpoint(reactor, socketPath)


class SendFileDescriptor(ConnectableProtocol):
    """
    L{SendFileDescriptor} passes one file descriptor to its peer, optionally
    followed by some ordinary bytes, and then closes the connection.

    @ivar reason: The reason the connection was lost, once C{connectionLost}
        has been called; L{None} before that.
    """

    reason = None

    def __init__(self, fd, data):
        """
        @param fd: The C{int} file descriptor to send over the connection.

        @param data: C{bytes} to write after the descriptor, or L{None} (or
            an empty value) to write nothing.
        """
        self.fd, self.data = fd, data

    def connectionMade(self):
        """
        Send C{self.fd}, then C{self.data} when it is non-empty, and finally
        close the connection.
        """
        transport = self.transport
        transport.sendFileDescriptor(self.fd)
        payload = self.data
        if payload:
            transport.write(payload)
        transport.loseConnection()

    def connectionLost(self, reason):
        """
        Let the base class record the disconnect, then remember C{reason}.
        """
        ConnectableProtocol.connectionLost(self, reason)
        self.reason = reason


@implementer(IFileDescriptorReceiver)
class ReceiveFileDescriptor(ConnectableProtocol):
    """
    L{ReceiveFileDescriptor} provides an API for waiting for file descriptors to
    be received.

    @ivar reason: The reason the connection was lost, after C{connectionLost}
        is called.

    @ivar waiting: A L{Deferred} which fires with a file descriptor once one is
        received, or with a failure if the connection is lost with no descriptor
        arriving.
    """

    reason = None
    waiting = None

    def waitForDescriptor(self):
        """
        Return a L{Deferred} which will fire with the next file descriptor
        received, or with a failure if the connection is or has already been
        lost.
        """
        if self.reason is None:
            self.waiting = Deferred()
            return self.waiting
        else:
            return fail(self.reason)

    def _popWaiting(self):
        """
        Detach and return the pending L{Deferred} (possibly L{None}).

        The slot is cleared I{before} the caller fires the Deferred, so a
        callback that calls L{waitForDescriptor} again installs a fresh
        Deferred that is not immediately overwritten with L{None}.
        """
        waiting, self.waiting = self.waiting, None
        return waiting

    def fileDescriptorReceived(self, descriptor):
        """
        Fire the waiting Deferred, initialized by C{waitForDescriptor}, with the
        file descriptor just received.
        """
        # Raises AttributeError (as before) if nobody is waiting.
        self._popWaiting().callback(descriptor)

    def dataReceived(self, data):
        """
        Fail the waiting Deferred, if it has not already been fired by
        C{fileDescriptorReceived}.  The bytes sent along with a file descriptor
        are guaranteed to be delivered to the protocol's C{dataReceived} method
        only after the file descriptor has been delivered to the protocol's
        C{fileDescriptorReceived}.
        """
        waiting = self._popWaiting()
        if waiting is not None:
            waiting.errback(
                Failure(Exception(f"Received bytes ({data!r}) before descriptor."))
            )

    def connectionLost(self, reason):
        """
        Fail the waiting Deferred, initialized by C{waitForDescriptor}, if there
        is one.
        """
        ConnectableProtocol.connectionLost(self, reason)
        # Record the reason before firing so a callback that asks for another
        # descriptor gets an immediate failure instead of a Deferred that
        # would never fire.
        self.reason = reason
        waiting = self._popWaiting()
        if waiting is not None:
            waiting.errback(reason)


class UNIXTestsBuilder(UNIXFamilyMixin, ReactorBuilder, ConnectionTestsMixin):
    """
    Builder defining tests relating to L{IReactorUNIX}.
    """

    requiredInterfaces = (IReactorUNIX,)

    endpoints = UNIXCreator()

    def test_mode(self):
        """
        The UNIX socket created by L{IReactorUNIX.listenUNIX} is created with
        the mode specified.
        """
        self._modeTest("listenUNIX", self.mktemp(), ServerFactory())

    @skipIf(
        not platform.isLinux(),
        "Abstract namespace UNIX sockets only " "supported on Linux.",
    )
    def test_listenOnLinuxAbstractNamespace(self):
        """
        On Linux, a UNIX socket path may begin with C{'\\0'} to indicate
        a socket in the abstract namespace.  L{IReactorUNIX.listenUNIX}
        accepts such a path.
        """
        # Don't listen on a path longer than the maximum allowed.
        path = _abstractPath(self)
        reactor = self.buildReactor()
        port = reactor.listenUNIX("\0" + path, ServerFactory())
        self.assertEqual(port.getHost(), UNIXAddress("\0" + path))

    def test_listenFailure(self):
        """
        L{IReactorUNIX.listenUNIX} raises L{CannotListenError} if the
        underlying port's createInternetSocket raises a socket error.
        """

        def raiseSocketError(self):
            raise OSError("FakeBasePort forced socket.error")

        self.patch(base.BasePort, "createInternetSocket", raiseSocketError)
        reactor = self.buildReactor()
        with self.assertRaises(CannotListenError):
            reactor.listenUNIX("not-used", ServerFactory())

    @skipIf(
        not platform.isLinux(),
        "Abstract namespace UNIX sockets only " "supported on Linux.",
    )
    def test_connectToLinuxAbstractNamespace(self):
        """
        L{IReactorUNIX.connectUNIX} also accepts a Linux abstract namespace
        path.
        """
        path = _abstractPath(self)
        reactor = self.buildReactor()
        connector = reactor.connectUNIX("\0" + path, ClientFactory())
        self.assertEqual(connector.getDestination(), UNIXAddress("\0" + path))

    def test_addresses(self):
        """
        A client's transport's C{getHost} and C{getPeer} return L{UNIXAddress}
        instances which have the filesystem path of the host and peer ends of
        the connection.
        """

        class SaveAddress(ConnectableProtocol):
            def makeConnection(self, transport):
                self.addresses = dict(
                    host=transport.getHost(), peer=transport.getPeer()
                )
                transport.loseConnection()

        server = SaveAddress()
        client = SaveAddress()

        runProtocolsWithReactor(self, server, client, self.endpoints)

        self.assertEqual(server.addresses["host"], client.addresses["peer"])
        self.assertEqual(server.addresses["peer"], client.addresses["host"])

    @skipIf(not sendmsg, sendmsgSkipReason)
    def test_sendFileDescriptor(self):
        """
        L{IUNIXTransport.sendFileDescriptor} accepts an integer file descriptor
        and sends a copy of it to the process reading from the connection.
        """
        from socket import fromfd

        s = socket()
        # The transport only sends a copy; the original must still be closed.
        self.addCleanup(s.close)
        s.bind(("", 0))
        server = SendFileDescriptor(s.fileno(), b"junk")

        client = ReceiveFileDescriptor()
        d = client.waitForDescriptor()

        def checkDescriptor(descriptor):
            received = fromfd(descriptor, AF_INET, SOCK_STREAM)
            # Thanks for the free dup, fromfd()
            close(descriptor)

            try:
                # If the sockets have the same local address, they're probably
                # the same.
                self.assertEqual(s.getsockname(), received.getsockname())

                # But it would be cheating for them to be identified by the
                # same file descriptor.  The point was to get a copy, as we
                # might get if there were two processes involved here.
                self.assertNotEqual(s.fileno(), received.fileno())
            finally:
                # fromfd() returned its own dup; release it whatever happens.
                received.close()

        d.addCallback(checkDescriptor)
        d.addErrback(err, "Sending file descriptor encountered a problem")
        d.addBoth(lambda ignored: server.transport.loseConnection())

        runProtocolsWithReactor(self, server, client, self.endpoints)

    @skipIf(not sendmsg, sendmsgSkipReason)
    def test_sendFileDescriptorTriggersPauseProducing(self):
        """
        If a L{IUNIXTransport.sendFileDescriptor} call fills up
        the send buffer, any registered producer is paused.
        """

        class DoesNotRead(ConnectableProtocol):
            def connectionMade(self):
                self.transport.pauseProducing()

        class SendsManyFileDescriptors(ConnectableProtocol):
            paused = False

            def connectionMade(self):
                self.socket = socket()
                self.transport.registerProducer(self, True)

                def sender():
                    self.transport.sendFileDescriptor(self.socket.fileno())
                    self.transport.write(b"x")

                self.task = LoopingCall(sender)
                self.task.clock = self.transport.reactor
                self.task.start(0).addErrback(err, "Send loop failure")

            def stopProducing(self):
                self._disconnect()

            def resumeProducing(self):
                self._disconnect()

            def pauseProducing(self):
                self.paused = True
                self.transport.unregisterProducer()
                self._disconnect()

            def _disconnect(self):
                self.task.stop()
                self.transport.abortConnection()
                self.other.transport.abortConnection()

        server = SendsManyFileDescriptors()
        client = DoesNotRead()
        server.other = client
        runProtocolsWithReactor(self, server, client, self.endpoints)

        self.assertTrue(server.paused, "sendFileDescriptor producer was not paused")

    @skipIf(not sendmsg, sendmsgSkipReason)
    def test_fileDescriptorOverrun(self):
        """
        If L{IUNIXTransport.sendFileDescriptor} is used to queue a greater
        number of file descriptors than the number of bytes sent using
        L{ITransport.write}, the connection is closed and the protocol connected
        to the transport has its C{connectionLost} method called with a failure
        wrapping L{FileDescriptorOverrun}.
        """
        cargo = socket()
        self.addCleanup(cargo.close)
        server = SendFileDescriptor(cargo.fileno(), None)

        client = ReceiveFileDescriptor()
        result = []
        d = client.waitForDescriptor()
        d.addBoth(result.append)
        d.addBoth(lambda ignored: server.transport.loseConnection())

        runProtocolsWithReactor(self, server, client, self.endpoints)

        self.assertIsInstance(result[0], Failure)
        result[0].trap(ConnectionClosed)
        self.assertIsInstance(server.reason.value, FileDescriptorOverrun)

    def _sendmsgMixinFileDescriptorReceivedDriver(self, ancillaryPacker):
        """
        Drive _SendmsgMixin via sendmsg socket calls to check that
        L{IFileDescriptorReceiver.fileDescriptorReceived} is called once
        for each file descriptor received in the ancillary messages.

        @param ancillaryPacker: A callable that will be given a list of
            two file descriptors and should return a two-tuple where:
            The first item is an iterable of zero or more (cmsg_level,
            cmsg_type, cmsg_data) tuples in the same order as the given
            list for actual sending via sendmsg; the second item is an
            integer indicating the expected number of FDs to be received.
        """
        # Strategy:
        # - Create a UNIX socketpair.
        # - Associate one end to a FakeReceiver and FakeProtocol.
        # - Call sendmsg on the other end to send FDs as ancillary data.
        #   Ancillary data is obtained calling ancillaryPacker with
        #   the two FDs associated to two temp files (using the socket
        #   FDs for this fails the device/inode verification tests on
        #   macOS 10.10, so temp files are used instead).
        # - Call doRead in the FakeReceiver.
        # - Verify results on FakeProtocol.
        #   Using known device/inodes to verify correct order.

        # TODO: replace FakeReceiver test approach with one based in
        # IReactorSocket.adoptStreamConnection once AF_UNIX support is
        # implemented; see https://twistedmatrix.com/trac/ticket/5573.

        from socket import socketpair

        from twisted.internet.unix import _SendmsgMixin
        from twisted.python.sendmsg import sendmsg

        def deviceInodeTuple(fd):
            fs = fstat(fd)
            return (fs.st_dev, fs.st_ino)

        @implementer(IFileDescriptorReceiver)
        class FakeProtocol(ConnectableProtocol):
            def __init__(self):
                self.fds = []
                self.deviceInodesReceived = []

            def fileDescriptorReceived(self, fd):
                self.fds.append(fd)
                self.deviceInodesReceived.append(deviceInodeTuple(fd))
                close(fd)

        class FakeReceiver(_SendmsgMixin):
            bufferSize = 1024

            def __init__(self, skt, proto):
                self.socket = skt
                self.protocol = proto

            def _dataReceived(self, data):
                pass

            def getHost(self):
                pass

            def getPeer(self):
                pass

            def _getLogPrefix(self, o):
                pass

        sendSocket, recvSocket = socketpair(AF_UNIX, SOCK_STREAM)
        self.addCleanup(sendSocket.close)
        self.addCleanup(recvSocket.close)

        proto = FakeProtocol()
        receiver = FakeReceiver(recvSocket, proto)

        # Temp files give us two FDs to send/receive/verify.
        fileOneFD, fileOneName = mkstemp()
        fileTwoFD, fileTwoName = mkstemp()
        self.addCleanup(unlink, fileOneName)
        self.addCleanup(unlink, fileTwoName)
        # sendmsg only passes copies; the descriptors opened by mkstemp must
        # still be closed here or they leak for the life of the test process.
        self.addCleanup(close, fileOneFD)
        self.addCleanup(close, fileTwoFD)

        dataToSend = b"some data needs to be sent"
        fdsToSend = [fileOneFD, fileTwoFD]
        ancillary, expectedCount = ancillaryPacker(fdsToSend)
        sendmsg(sendSocket, dataToSend, ancillary)

        receiver.doRead()

        # Verify that fileDescriptorReceived was called twice.
        self.assertEqual(len(proto.fds), expectedCount)

        # Verify that received FDs are different from the sent ones.
        self.assertFalse(set(fdsToSend).intersection(set(proto.fds)))

        # Verify that FDs were received in the same order, if any.
        if proto.fds:
            deviceInodesSent = [deviceInodeTuple(fd) for fd in fdsToSend]
            self.assertEqual(deviceInodesSent, proto.deviceInodesReceived)

    @skipIf(not sendmsg, sendmsgSkipReason)
    def test_multiFileDescriptorReceivedPerRecvmsgOneCMSG(self):
        """
        _SendmsgMixin handles multiple file descriptors per recvmsg, calling
        L{IFileDescriptorReceiver.fileDescriptorReceived} once per received
        file descriptor. Scenario: single CMSG with two FDs.
        """
        from twisted.python.sendmsg import SCM_RIGHTS

        def ancillaryPacker(fdsToSend):
            ancillary = [(SOL_SOCKET, SCM_RIGHTS, pack("ii", *fdsToSend))]
            expectedCount = 2
            return ancillary, expectedCount

        self._sendmsgMixinFileDescriptorReceivedDriver(ancillaryPacker)

    @skipIf(
        platform.isMacOSX(),
        "Multi control message ancillary sendmsg not supported on Mac.",
    )
    @skipIf(not sendmsg, sendmsgSkipReason)
    def test_multiFileDescriptorReceivedPerRecvmsgTwoCMSGs(self):
        """
        _SendmsgMixin handles multiple file descriptors per recvmsg, calling
        L{IFileDescriptorReceiver.fileDescriptorReceived} once per received
        file descriptor. Scenario: two CMSGs with one FD each.
        """
        from twisted.python.sendmsg import SCM_RIGHTS

        def ancillaryPacker(fdsToSend):
            ancillary = [(SOL_SOCKET, SCM_RIGHTS, pack("i", fd)) for fd in fdsToSend]
            expectedCount = 2
            return ancillary, expectedCount

        self._sendmsgMixinFileDescriptorReceivedDriver(ancillaryPacker)

    @skipIf(not sendmsg, sendmsgSkipReason)
    def test_multiFileDescriptorReceivedPerRecvmsgBadCMSG(self):
        """
        _SendmsgMixin handles multiple file descriptors per recvmsg, calling
        L{IFileDescriptorReceiver.fileDescriptorReceived} once per received
        file descriptor. Scenario: unsupported CMSGs.
        """
        # Given that we can't just send random/invalid ancillary data via the
        # packer for it to be sent via sendmsg -- the kernel would not accept
        # it -- we'll temporarily replace recvmsg with a fake one that produces
        # a non-supported ancillary message level/type. This being said, from
        # the perspective of the ancillaryPacker, all that is required is to
        # let the test driver know that 0 file descriptors are expected.
        from twisted.python import sendmsg

        def ancillaryPacker(fdsToSend):
            ancillary = []
            expectedCount = 0
            return ancillary, expectedCount

        def fakeRecvmsgUnsupportedAncillary(skt, *args, **kwargs):
            data = b"some data"
            ancillary = [(None, None, b"")]
            flags = 0
            return sendmsg.ReceivedMessage(data, ancillary, flags)

        events = []
        addObserver(events.append)
        self.addCleanup(removeObserver, events.append)

        self.patch(sendmsg, "recvmsg", fakeRecvmsgUnsupportedAncillary)
        self._sendmsgMixinFileDescriptorReceivedDriver(ancillaryPacker)

        # Verify the expected message was logged.  Not every legacy log event
        # carries a "format" key (plain log.msg("...") events do not), so
        # look it up defensively rather than risking a KeyError.
        expectedMessage = "received unsupported ancillary data"
        found = any(expectedMessage in (e.get("format") or "") for e in events)
        self.assertTrue(found, "Expected message not found in logged events")

    @skipIf(not sendmsg, sendmsgSkipReason)
    def test_avoidLeakingFileDescriptors(self):
        """
        If associated with a protocol which does not provide
        L{IFileDescriptorReceiver}, file descriptors received by the
        L{IUNIXTransport} implementation are closed and a warning is emitted.
        """
        # To verify this, establish a connection.  Send one end of the
        # connection over the IUNIXTransport implementation.  After the copy
        # should no longer exist, close the original.  If the opposite end of
        # the connection decides the connection is closed, the copy does not
        # exist.
        from socket import socketpair

        probeClient, probeServer = socketpair()
        self.addCleanup(probeServer.close)
        # probeClient is closed explicitly below; this only covers an early
        # failure (closing an already-closed socket object is a no-op).
        self.addCleanup(probeClient.close)

        events = []
        addObserver(events.append)
        self.addCleanup(removeObserver, events.append)

        class RecordEndpointAddresses(SendFileDescriptor):
            def connectionMade(self):
                self.hostAddress = self.transport.getHost()
                self.peerAddress = self.transport.getPeer()
                SendFileDescriptor.connectionMade(self)

        server = RecordEndpointAddresses(probeClient.fileno(), b"junk")
        client = ConnectableProtocol()

        runProtocolsWithReactor(self, server, client, self.endpoints)

        # Get rid of the original reference to the socket.
        probeClient.close()

        # A non-blocking recv will return "" if the connection is closed, as
        # desired.  If the connection has not been closed, because the
        # duplicate file descriptor is still open, it will fail with EAGAIN
        # instead.
        probeServer.setblocking(False)
        self.assertEqual(b"", probeServer.recv(1024))

        # This is a surprising circumstance, so it should be logged.
        expectedFormat = (
            "%(protocolName)s (on %(hostAddress)r) does not "
            "provide IFileDescriptorReceiver; closing file "
            "descriptor received (from %(peerAddress)r)."
        )
        clsName = "ConnectableProtocol"

        # Reverse host and peer, since the log event is from the client
        # perspective.
        expectedEvent = dict(
            hostAddress=server.peerAddress,
            peerAddress=server.hostAddress,
            protocolName=clsName,
            format=expectedFormat,
        )

        for logEvent in events:
            for k, v in expectedEvent.items():
                if v != logEvent.get(k):
                    break
            else:
                # No mismatches were found, stop looking at events
                break
        else:
            # No fully matching events were found, fail the test.
            self.fail(
                "Expected event (%s) not found in logged events (%s)"
                % (
                    expectedEvent,
                    pformat(
                        events,
                    ),
                )
            )

    @skipIf(not sendmsg, sendmsgSkipReason)
    def test_descriptorDeliveredBeforeBytes(self):
        """
        L{IUNIXTransport.sendFileDescriptor} sends file descriptors before
        L{ITransport.write} sends normal bytes.
        """

        @implementer(IFileDescriptorReceiver)
        class RecordEvents(ConnectableProtocol):
            def connectionMade(self):
                ConnectableProtocol.connectionMade(self)
                self.events = []

            def fileDescriptorReceived(innerSelf, descriptor):
                self.addCleanup(close, descriptor)
                innerSelf.events.append(type(descriptor))

            def dataReceived(self, data):
                self.events.extend(data)

        cargo = socket()
        self.addCleanup(cargo.close)
        server = SendFileDescriptor(cargo.fileno(), b"junk")
        client = RecordEvents()

        runProtocolsWithReactor(self, server, client, self.endpoints)

        self.assertEqual(int, client.events[0])
        self.assertEqual(b"junk", bytes(client.events[1:]))


class UNIXDatagramTestsBuilder(UNIXFamilyMixin, ReactorBuilder):
    """
    Builder for tests of L{IReactorUNIXDatagram} implementations.
    """

    requiredInterfaces = (interfaces.IReactorUNIXDatagram,)

    # No test_connectMode counterpart exists: connectUNIXDatagram has ignored
    # its mode argument ever since that API was introduced.
    def test_listenMode(self):
        """
        L{IReactorUNIXDatagram.listenUNIXDatagram} creates its socket file
        with the requested permission mode.
        """
        socketPath = self.mktemp()
        self._modeTest("listenUNIXDatagram", socketPath, DatagramProtocol())

    @skipIf(
        not platform.isLinux(),
        "Abstract namespace UNIX sockets only " "supported on Linux.",
    )
    def test_listenOnLinuxAbstractNamespace(self):
        """
        On Linux, L{IReactorUNIXDatagram.listenUNIXDatagram} accepts a path
        beginning with a NUL byte, which denotes the abstract socket
        namespace, and reports that same path from C{getHost}.
        """
        abstractName = "\0" + _abstractPath(self)
        reactor = self.buildReactor()
        port = reactor.listenUNIXDatagram(abstractName, DatagramProtocol())
        self.assertEqual(port.getHost(), UNIXAddress(abstractName))


class SocketUNIXMixin:
    """
    Mixin which uses L{IReactorSocket.adoptStreamPort} to hand out listening
    UNIX ports.
    """

    requiredInterfaces: Optional[Sequence[Type[Interface]]] = (
        IReactorUNIX,
        IReactorSocket,
    )

    def getListeningPort(self, reactor, factory):
        """
        Get a UNIX port from a reactor, wrapping an already-initialized file
        descriptor.

        A listening, non-blocking C{AF_UNIX} socket is created and bound to a
        fresh path in the current directory, then handed to
        C{reactor.adoptStreamPort}.  The local socket object is closed on
        every exit path -- including a failing C{bind}/C{listen} -- so it
        never leaks; the reactor keeps its own copy of the descriptor.

        @param reactor: The reactor under test.
        @param factory: The server factory for the adopted port.
        @return: Whatever C{reactor.adoptStreamPort} returns.
        """
        with socket(AF_UNIX) as portSock:
            # self.mktemp() often returns a path which is too long to be used.
            path = mktemp(suffix=".sock", dir=".")
            portSock.bind(path)
            portSock.listen(3)
            portSock.setblocking(False)
            return reactor.adoptStreamPort(portSock.fileno(), portSock.family, factory)

    def connectToListener(self, reactor, address, factory):
        """
        Connect to a listening UNIX socket.

        @param reactor: The reactor under test.
        @type reactor: L{IReactorUNIX}

        @param address: The listening port's address.
        @type address: L{UNIXAddress}

        @param factory: The client factory.
        @type factory: L{ClientFactory}

        @return: The connector
        """
        return reactor.connectUNIX(address.name, factory)


class ListenUNIXMixin:
    """
    Mixin which uses L{IReactorUNIX.listenUNIX} to hand out listening UNIX
    ports.
    """

    def getListeningPort(self, reactor, factory):
        """
        Start listening via C{reactor.listenUNIX} on a fresh socket path in
        the current directory and return the resulting port.
        """
        # Relative, short path: self.mktemp() tends to exceed the UNIX socket
        # path length limit.
        socketPath = mktemp(suffix=".sock", dir=".")
        return reactor.listenUNIX(socketPath, factory)

    def connectToListener(self, reactor, address, factory):
        """
        Open a client connection to the UNIX socket at C{address}.

        @param reactor: The reactor under test.
        @type reactor: L{IReactorUNIX}

        @param address: The listening port's address.
        @type address: L{UNIXAddress}

        @param factory: The client factory.
        @type factory: L{ClientFactory}

        @return: The connector returned by C{reactor.connectUNIX}.
        """
        targetPath = address.name
        return reactor.connectUNIX(targetPath, factory)


class UNIXPortTestsMixin:
    """
    Log-message expectations for stream port tests run over UNIX sockets.
    """

    requiredInterfaces: Optional[Sequence[Type[Interface]]] = (IReactorUNIX,)

    def getExpectedStartListeningLogMessage(self, port, factory):
        """
        Build the message a UNIX port is expected to log when it begins
        listening: the factory followed by the repr of the socket path.
        """
        socketName = nativeString(port.getHost().name)
        return f"{factory} starting on {socketName!r}"

    def getExpectedConnectionLostLogMsg(self, port):
        """
        Build the message a UNIX port is expected to log once it is closed.
        """
        socketName = nativeString(port.getHost().name)
        return f"(UNIX Port {socketName} Closed)"


class UNIXPortTestsBuilder(
    ListenUNIXMixin,
    UNIXPortTestsMixin,
    ReactorBuilder,
    StreamTransportTestsMixin,
):
    """
    Tests for L{IReactorUNIX.listenUNIX}: the shared stream-transport tests,
    run against ports obtained from C{listenUNIX} (via L{ListenUNIXMixin}).
    """


class UNIXFDPortTestsBuilder(
    SocketUNIXMixin,
    UNIXPortTestsMixin,
    ReactorBuilder,
    StreamTransportTestsMixin,
):
    """
    Tests for L{IReactorSocket.adoptStreamPort} with UNIX sockets: the shared
    stream-transport tests, run against ports created by adopting an
    already-listening C{AF_UNIX} socket (via L{SocketUNIXMixin}).
    """


class UNIXAdoptStreamConnectionTestsBuilder(WriteSequenceTestsMixin, ReactorBuilder):
    """
    Tests for L{IReactorSocket.adoptStreamConnection} with C{AF_UNIX}
    connections, plus the write-sequence tests run over a connection whose
    server side was adopted (see L{getConnectedClientAndServer}).
    """

    requiredInterfaces = (
        IReactorFDSet,
        IReactorSocket,
        IReactorUNIX,
    )

    def test_buildProtocolReturnsNone(self):
        """
        L{IReactorSocket.adoptStreamConnection} returns None if the given
        factory's buildProtocol returns None.
        """

        # Build reactor before anything else: allow self.buildReactor()
        # to skip the test if any of the self.requiredInterfaces isn't
        # provided by the reactor (example: Windows), preventing later
        # failures unrelated to the test itself.
        reactor = self.buildReactor()

        from socket import socketpair

        class NoneFactory(ServerFactory):
            # Refuses every connection by building no protocol.
            def buildProtocol(self, address):
                return None

        s1, s2 = socketpair(AF_UNIX, SOCK_STREAM)
        s1.setblocking(False)
        self.addCleanup(s1.close)
        self.addCleanup(s2.close)

        s1FD = s1.fileno()
        factory = NoneFactory()
        result = reactor.adoptStreamConnection(s1FD, AF_UNIX, factory)
        self.assertIsNone(result)

    def test_ServerAddressUNIX(self):
        """
        The server-side transport of an adopted UNIX connection renders its
        session number and the listening port's path in both C{str()} and
        C{logstr}, and the server factory records a L{UNIXAddress} peer.
        """

        def connected(protocols):
            client, server, port = protocols
            try:
                # The port path as it is expected to appear in the transport's
                # textual representations.
                portPath = _coerceToFilesystemEncoding("", port.getHost().name)
                self.assertEqual(
                    "<AccumulatingProtocol #%s on %s>"
                    % (server.transport.sessionno, portPath),
                    str(server.transport),
                )

                self.assertEqual(
                    "AccumulatingProtocol,%s,%s"
                    % (server.transport.sessionno, portPath),
                    server.transport.logstr,
                )

                # NOTE(review): peerAddresses is presumably populated by
                # MyServerFactory (defined in test_tcp) -- confirm there.
                peerAddress = server.factory.peerAddresses[0]
                self.assertIsInstance(peerAddress, UNIXAddress)
            finally:
                # Be certain to drop the connection so the test completes.
                server.transport.loseConnection()

        reactor = self.buildReactor()
        d = self.getConnectedClientAndServer(
            reactor, interface=None, addressFamily=None
        )
        d.addCallback(connected)
        self.runReactor(reactor)

    def getConnectedClientAndServer(self, reactor, interface, addressFamily):
        """
        Return a L{Deferred} firing with a L{MyClientFactory} and
        L{MyServerFactory} connected pair, and the listening C{Port}. The
        particularity is that the server protocol has been obtained after doing
        a C{adoptStreamConnection} against the original server connection.

        @param reactor: The reactor under test; it is stopped once both the
            client and server connections have been lost, or if the returned
            Deferred fails.
        @param interface: Not used by this implementation.
        @param addressFamily: Not used by this implementation; C{AF_UNIX} is
            always used.
        """
        # Factory for the connection accepted by listenUNIX; its protocol is
        # only used to obtain the accepted descriptor for adoption.
        firstServer = MyServerFactory()
        firstServer.protocolConnectionMade = Deferred()

        # Factory whose protocol ends up driving the adopted connection.
        server = MyServerFactory()
        server.protocolConnectionMade = Deferred()
        server.protocolConnectionLost = Deferred()

        client = MyClientFactory()
        client.protocolConnectionMade = Deferred()
        client.protocolConnectionLost = Deferred()

        # self.mktemp() often returns a path which is too long to be used.
        path = mktemp(suffix=".sock", dir=".")
        port = reactor.listenUNIX(path, firstServer)

        def firstServerConnected(proto):
            # Take the original transport out of the reactor's reader/writer
            # sets, then hand its descriptor to adoptStreamConnection so the
            # `server` factory's protocol takes over the connection.
            reactor.removeReader(proto.transport)
            reactor.removeWriter(proto.transport)
            reactor.adoptStreamConnection(proto.transport.fileno(), AF_UNIX, server)

        firstServer.protocolConnectionMade.addCallback(firstServerConnected)

        # Once both ends report connectionLost, stop the reactor.
        lostDeferred = gatherResults(
            [client.protocolConnectionLost, server.protocolConnectionLost]
        )

        def stop(result):
            if reactor.running:
                reactor.stop()
            return result

        lostDeferred.addBoth(stop)

        # Returned to the caller; a failure in the caller's callbacks also
        # stops the reactor via this errback.
        deferred = Deferred()
        deferred.addErrback(stop)

        startDeferred = gatherResults(
            [client.protocolConnectionMade, server.protocolConnectionMade]
        )

        def start(protocols):
            client, server = protocols
            deferred.callback((client, server, port))

        startDeferred.addCallback(start)

        reactor.connectUNIX(port.getHost().name, client)
        return deferred


# Publish the TestCase classes generated by each builder at module level so
# the test runner can discover them.
for _builder in (
    UNIXTestsBuilder,
    UNIXDatagramTestsBuilder,
    UNIXPortTestsBuilder,
    UNIXFDPortTestsBuilder,
    UNIXAdoptStreamConnectionTestsBuilder,
):
    globals().update(_builder.makeTestCaseClasses())
del _builder


class UnixClientTestsBuilder(ReactorBuilder, StreamClientTestsMixin):
    """
    Builder for tests of L{IReactorUNIX.connectUNIX}.
    """

    requiredInterfaces = (IReactorUNIX,)

    _path = None

    @property
    def path(self):
        """
        The socket path shared by C{listen} and C{connect}.

        It is generated lazily with C{_abstractPath} on first access and then
        reused.  No leading C{"\\0"} is added, so this is an ordinary relative
        filesystem path rather than a Linux abstract-namespace address.

        @return: A C{str} path.
        """
        cached = self._path
        if cached is None:
            cached = self._path = _abstractPath(self)
        return cached

    def listen(self, reactor, factory):
        """
        Start a UNIX server on C{self.path}.

        @param reactor: The reactor in which to create the UNIX port.

        @param factory: The server factory.

        @return: The UNIX port returned by C{reactor.listenUNIX}.
        """
        socketPath = self.path
        return reactor.listenUNIX(socketPath, factory)

    def connect(self, reactor, factory):
        """
        Start a UNIX client connecting to C{self.path}.

        @param reactor: The reactor in which to create the connection.

        @param factory: The client factory.

        @return: The connector returned by C{reactor.connectUNIX}.
        """
        socketPath = self.path
        return reactor.connectUNIX(socketPath, factory)


# Expose the TestCase classes produced by makeTestCaseClasses() at module
# level so the test runner can discover the UNIX client tests.
globals().update(UnixClientTestsBuilder.makeTestCaseClasses())

Zerion Mini Shell 1.0