"""
A module implementing an interface for individual tftp endpoints.
"""
# built-in
import asyncio
from contextlib import AsyncExitStack, suppress
import logging
from pathlib import Path
from typing import BinaryIO, Callable, Optional, Union
# third-party
from vcorelib.asyncio.poll import repeat_until
from vcorelib.logging import LoggerMixin, LoggerType
from vcorelib.math import RateLimiter
from vcorelib.paths.info import FileInfo
# internal
from runtimepy.net import IpHost
from runtimepy.net.udp.tftp.enums import TftpErrorCode
from runtimepy.net.udp.tftp.io import tftp_chunks
from runtimepy.primitives import Double
# Callback for transmitting a DATA packet: (block number, payload, destination).
TftpDataSender = Callable[[int, bytes, Union[IpHost, tuple[str, int]]], None]
# Callback for transmitting an ACK packet: (block number, destination).
TftpAckSender = Callable[[int, Union[IpHost, tuple[str, int]]], None]
# Callback for transmitting an ERROR packet: (code, message, destination).
TftpErrorSender = Callable[
    [TftpErrorCode, str, Union[IpHost, tuple[str, int]]], None
]
# Default tftp block size; endpoints may negotiate a larger one (RFC 2347).
TFTP_MAX_BLOCK = 512
# Re-send interval and total duration (seconds) for the final-acknowledgement
# "dally" phase at the end of an incoming transfer.
DALLY_PERIOD = 0.05
DALLY_TIMEOUT = 0.25
class TftpEndpoint(LoggerMixin):
    """A data structure for endpoint-related runtime storage."""
    def __init__(
        self,
        root: Path,
        logger: LoggerType,
        addr: IpHost,
        data_sender: TftpDataSender,
        ack_sender: TftpAckSender,
        error_sender: TftpErrorSender,
        period: Double,
        timeout: Double,
    ) -> None:
        """Initialize instance."""

        super().__init__(logger=logger)

        self._path = root
        self.addr = addr

        # Callbacks used to transmit outgoing tftp messages.
        self.data_sender = data_sender
        self.ack_sender = ack_sender
        self.error_sender = error_sender

        # Serializes transfers so reads and writes can't interleave.
        self.lock = asyncio.Lock()

        # In-flight message state: events keyed by block number, plus any
        # received data payloads awaiting consumption.
        self.awaiting_acks: dict[int, asyncio.Event] = {}
        self.awaiting_blocks: dict[int, asyncio.Event] = {}
        self.blocks: dict[int, bytes] = {}

        # May be raised later via option negotiation (RFC 2347).
        self.max_block_size = TFTP_MAX_BLOCK

        # Retry timing knobs plus a limiter to throttle repeated log output.
        self.period = period
        self.timeout = timeout
        self.log_limiter = RateLimiter.from_s(1.0)
    def update_from_other(self, other: "TftpEndpoint") -> "TftpEndpoint":
        """Adopt another endpoint's address and return this instance."""

        new_addr = other.addr
        self.logger.info("Updating address to '%s'.", new_addr)
        self.addr = new_addr
        return self
    def chunk_sender(self, block: int, data: bytes) -> Callable[[], None]:
        """Build a zero-argument callable that transmits one data block."""

        def send_chunk() -> None:
            """Transmit the captured block to this endpoint's address."""
            self.data_sender(block, data, self.addr)

        return send_chunk
    def _ack_sender(self, block: int) -> Callable[[], None]:
        """
        Build a zero-argument callable that acknowledges the given block
        number to this endpoint's address.
        """
        return lambda: self.ack_sender(block, self.addr)
    def set_root(self, path: Path) -> None:
        """Point this endpoint at a new filesystem root directory."""
        self._path = path
    def handle_data(self, block: int, data: bytes) -> None:
        """Handle a data payload."""

        event = self.awaiting_blocks.pop(block, None)

        # A payload we weren't waiting for is a protocol error.
        if event is None:
            self.error_sender(
                TftpErrorCode.UNKNOWN_ID,
                "Not expecting any data (got "
                f"block={block} - {len(data)} bytes)",
                self.addr,
            )
            return

        # Stash the payload and wake whoever is waiting on this block.
        self.blocks[block] = data
        event.set()
    def handle_ack(self, block: int) -> None:
        """Handle a block acknowledgement."""

        event = self.awaiting_acks.pop(block, None)
        if event is not None:
            event.set()
            return

        # NOTE: responding with an error packet here seems to cause more harm
        # than good, so just log the anomaly (rate limited).
        self.governed_log(
            self.log_limiter,
            "Not expecting any ack (got %d).",
            block,
            level=logging.ERROR,
        )
    def __str__(self) -> str:
        """Render this endpoint as its peer address."""
        return f"{self.addr}"
    def handle_error(self, error_code: TftpErrorCode, message: str) -> None:
        """Log an error message received from the peer (rate limited)."""

        details = (self, error_code.name, message)
        self.governed_log(
            self.log_limiter, "%s '%s' %s.", *details, level=logging.ERROR
        )
    async def ingest_file(self, stream: BinaryIO) -> bool:
        """
        Ingest incoming file data and write to a stream.

        Blocks are consumed one at a time: the previous block's
        acknowledgement is re-sent until the next data payload arrives or
        the transfer times out. Returns True when a complete transfer
        (terminated by a short final block) was written to the stream.
        """
        keep_going = True
        idx = 1  # tftp data block numbering starts at 1
        curr_size = 0
        written = 0
        while keep_going:
            # Set up event trigger for expected data payload.
            event = asyncio.Event()
            self.awaiting_blocks[idx] = event
            keep_going = (
                await repeat_until(
                    # Acknowledge the previous message until we get new
                    # data.
                    self._ack_sender(idx - 1),
                    event,
                    self.period.value,
                    self.timeout.value,
                )
                and idx in self.blocks
            )
            if keep_going:
                # Write chunk.
                data = self.blocks[idx]
                curr_size = len(data)
                # If this occurs, it's probably RFC 2348 (using this assertion
                # to determine practical need for that support).
                assert curr_size <= self.max_block_size, curr_size
                stream.write(data)
                written += curr_size
                # We only expect future iterations if data payloads are
                # saturated.
                keep_going = curr_size >= self.max_block_size
            # Ensure state is cleaned up.
            self.blocks.pop(idx, None)
            self.awaiting_blocks.pop(idx, None)
            if keep_going:
                idx += 1
        # Send the final acknowledgement for a bit ("dally" per rfc).
        # Note: 'idx' was not incremented for the final (short) block, so this
        # acknowledges the last payload actually received.
        success = written > 0 and curr_size < self.max_block_size
        if success:
            await repeat_until(
                self._ack_sender(idx),
                asyncio.Event(),
                DALLY_PERIOD,
                DALLY_TIMEOUT,
            )
        return success
    async def _process_write_request(self, path: Path, mode: str) -> None:
        """Process a write request."""

        # Cancellation is an expected way for a transfer to end; swallow it
        # so cleanup (file close, lock release) still runs normally.
        with suppress(asyncio.CancelledError):
            # One transfer at a time per endpoint.
            async with self.lock:
                with path.open("wb") as path_fd:
                    with self.log_time(
                        "Ingesting (%s) '%s'", mode, path, reminder=True
                    ):
                        success = await self.ingest_file(path_fd)

                    self.logger.info(
                        "%s to write (%s) '%s' from %s:%d.",
                        "Succeeded" if success else "Failed",
                        mode,
                        FileInfo.from_file(path),
                        self.addr[0],
                        self.addr[1],
                    )
    def handle_write_request(
        self, filename: str, mode: str
    ) -> Optional[asyncio.Task[None]]:
        """Handle a write request."""

        path = self.get_path(filename)

        # Only start a transfer task if the destination is writable.
        if self._check_permission(path, "wb"):
            return asyncio.create_task(
                self._process_write_request(path, mode)
            )
        return None
    async def serve_file(self, path: Path) -> bool:
        """Serve file chunks via this endpoint."""

        success = True

        with self.log_time(
            "Serving '%s'", FileInfo.from_file(path), reminder=True
        ):
            # Block numbers start at 1 and advance once per acknowledged
            # chunk.
            for idx, chunk in enumerate(
                tftp_chunks(path, self.max_block_size), start=1
            ):
                # Sanity-check the block number. Remove at some point?
                assert idx not in self.awaiting_acks, idx
                assert idx < 2**16, idx

                # Re-send this chunk until the peer acknowledges it (or we
                # give up).
                acked = asyncio.Event()
                self.awaiting_acks[idx] = acked
                if not await repeat_until(
                    self.chunk_sender(idx, chunk),
                    acked,
                    self.period.value,
                    self.timeout.value,
                ):
                    self.awaiting_acks.pop(idx, None)
                    success = False
                    break

        return success
    async def _process_read_request(self, path: Path, mode: str) -> None:
        """
        Service a read request by sending file chunk data.
        """

        # Cancellation aborts the transfer without surfacing an error.
        with suppress(asyncio.CancelledError):
            # One transfer at a time per endpoint.
            async with self.lock:
                result = await self.serve_file(path)
                self.logger.info(
                    "%s to serve (%s) '%s' to %s:%d.",
                    "Succeeded" if result else "Failed",
                    mode,
                    FileInfo.from_file(path),
                    self.addr[0],
                    self.addr[1],
                )
    def get_path(self, filename: str) -> Path:
        """Resolve a request's filename against this endpoint's root."""
        return self._path / filename
    def handle_read_request(
        self, filename: str, mode: str
    ) -> Optional[asyncio.Task[None]]:
        """Handle a read-request message."""

        path = self.get_path(filename)

        # The file must exist and be readable before we start serving it
        # (each check sends its own error packet on failure).
        if self._check_exists(path) and self._check_permission(path, "rb"):
            return asyncio.create_task(
                self._process_read_request(path, mode)
            )
        return None
    def _check_permission(self, path: Path, mode: str) -> bool:
        """
        Check if a path can be opened in the provided mode, send an error if
        not.
        """

        try:
            # Opening (and immediately closing) the file is a simple,
            # portable permission probe.
            path.open(mode).close()
        except PermissionError:
            self.error_sender(
                TftpErrorCode.ACCESS_VIOLATION,
                f"Can't open={mode} '{path}'",
                self.addr,
            )
            return False

        return True
    def _check_exists(self, path: Path) -> bool:
        """Check if a file exists, send an error if not."""

        if path.is_file():
            return True

        self.error_sender(
            TftpErrorCode.FILE_NOT_FOUND,
            f"Path '{path}' is not a file",
            self.addr,
        )
        return False