Source code for runtimepy.net.udp.tftp.endpoint

"""
A module implementing an interface for individual tftp endpoints.
"""

# built-in
import asyncio
from contextlib import AsyncExitStack, suppress
import logging
from pathlib import Path
from typing import BinaryIO, Callable, Optional, Union

# third-party
from vcorelib.asyncio.poll import repeat_until
from vcorelib.logging import LoggerMixin, LoggerType
from vcorelib.math import RateLimiter
from vcorelib.paths.info import FileInfo

# internal
from runtimepy.net import IpHost
from runtimepy.net.udp.tftp.enums import TftpErrorCode
from runtimepy.net.udp.tftp.io import tftp_chunks
from runtimepy.primitives import Double

TftpDataSender = Callable[[int, bytes, Union[IpHost, tuple[str, int]]], None]
TftpAckSender = Callable[[int, Union[IpHost, tuple[str, int]]], None]
TftpErrorSender = Callable[
    [TftpErrorCode, str, Union[IpHost, tuple[str, int]]], None
]

TFTP_MAX_BLOCK = 512

DALLY_PERIOD = 0.05
DALLY_TIMEOUT = 0.25


[docs] class TftpEndpoint(LoggerMixin): """A data structure for endpoint-related runtime storage.""" def __init__( self, root: Path, logger: LoggerType, addr: IpHost, data_sender: TftpDataSender, ack_sender: TftpAckSender, error_sender: TftpErrorSender, period: Double, timeout: Double, ) -> None: """Initialize instance.""" super().__init__(logger=logger) self._path = root self.addr = addr self.data_sender = data_sender self.ack_sender = ack_sender self.error_sender = error_sender # Avoid concurrency bugs when actively writing or reading. self.lock = asyncio.Lock() # Message receiving. self.awaiting_acks: dict[int, asyncio.Event] = {} self.awaiting_blocks: dict[int, asyncio.Event] = {} self.blocks: dict[int, bytes] = {} # Can be upgraded via RFC 2347. self.max_block_size = TFTP_MAX_BLOCK # Runtime settings. self.period = period self.timeout = timeout self.log_limiter = RateLimiter.from_s(1.0)
[docs] def update_from_other(self, other: "TftpEndpoint") -> "TftpEndpoint": """Update this endpoint's attributes with attributes of another's.""" self.logger.info("Updating address to '%s'.", other.addr) self.addr = other.addr return self
[docs] def chunk_sender(self, block: int, data: bytes) -> Callable[[], None]: """Create a method that sends a specific block of data.""" def sender() -> None: """Send a block of data.""" self.data_sender(block, data, self.addr) return sender
def _ack_sender(self, block: int) -> Callable[[], None]: """ Create a method that sends an acknowledgement for a specific block number. """ def sender() -> None: """Send an acknowledgement.""" self.ack_sender(block, self.addr) return sender
[docs] def set_root(self, path: Path) -> None: """Set a new root path for this instance.""" self._path = path
[docs] def handle_data(self, block: int, data: bytes) -> None: """Handle a data payload.""" if block in self.awaiting_blocks: self.blocks[block] = data self.awaiting_blocks[block].set() del self.awaiting_blocks[block] else: self.error_sender( TftpErrorCode.UNKNOWN_ID, "Not expecting any data (got " f"block={block} - {len(data)} bytes)", self.addr, )
[docs] def handle_ack(self, block: int) -> None: """Handle a block acknowledgement.""" if block in self.awaiting_acks: self.awaiting_acks[block].set() del self.awaiting_acks[block] else: self.governed_log( self.log_limiter, "Not expecting any ack (got %d).", block, level=logging.ERROR, )
# Sending an error seems to cause more harm than good. # self.error_sender(TftpErrorCode.UNKNOWN_ID, msg, self.addr) def __str__(self) -> str: """Get this instance as a string.""" return str(self.addr)
[docs] def handle_error(self, error_code: TftpErrorCode, message: str) -> None: """Handle a tftp error message.""" self.governed_log( self.log_limiter, "%s '%s' %s.", self, error_code.name, message, level=logging.ERROR, )
[docs] async def ingest_file(self, stream: BinaryIO) -> bool: """Ingest incoming file data and write to a stream.""" keep_going = True idx = 1 curr_size = 0 written = 0 while keep_going: # Set up event trigger for expected data payload. event = asyncio.Event() self.awaiting_blocks[idx] = event keep_going = ( await repeat_until( # Acknowledge the previous message until we get new # data. self._ack_sender(idx - 1), event, self.period.value, self.timeout.value, ) and idx in self.blocks ) if keep_going: # Write chunk. data = self.blocks[idx] curr_size = len(data) # If this occurs, it's probably RFC 2348 (using this assertion # to determine practical need for that support). assert curr_size <= self.max_block_size, curr_size stream.write(data) written += curr_size # We only expect future iterations if data payloads are # saturated. keep_going = curr_size >= self.max_block_size # Ensure state is cleaned up. self.blocks.pop(idx, None) self.awaiting_blocks.pop(idx, None) if keep_going: idx += 1 # Send the final acknowledgement for a bit ("dally" per rfc). success = written > 0 and curr_size < self.max_block_size if success: await repeat_until( self._ack_sender(idx), asyncio.Event(), DALLY_PERIOD, DALLY_TIMEOUT, ) return success
async def _process_write_request(self, path: Path, mode: str) -> None: """Process a write request.""" async with AsyncExitStack() as stack: # Claim write lock and ignore cancellation. stack.enter_context(suppress(asyncio.CancelledError)) await stack.enter_async_context(self.lock) path_fd = stack.enter_context(path.open("wb")) with self.log_time( "Ingesting (%s) '%s'", mode, path, reminder=True ): success = await self.ingest_file(path_fd) self.logger.info( "%s to write (%s) '%s' from %s:%d.", "Succeeded" if success else "Failed", mode, FileInfo.from_file(path), self.addr[0], self.addr[1], )
[docs] def handle_write_request( self, filename: str, mode: str ) -> Optional[asyncio.Task[None]]: """Handle a write request.""" path = self.get_path(filename) # Ensure we can service this request. if not self._check_permission(path, "wb"): return None return asyncio.create_task(self._process_write_request(path, mode))
[docs] async def serve_file(self, path: Path) -> bool: """Serve file chunks via this endpoint.""" # Set up (outgoing) transaction. success = True idx = 1 with self.log_time( "Serving '%s'", FileInfo.from_file(path), reminder=True ): for chunk in tftp_chunks(path, self.max_block_size): # Validate index. Remove at some point? assert idx not in self.awaiting_acks, idx assert idx < 2**16, idx # Prepare event trigger. event = asyncio.Event() self.awaiting_acks[idx] = event if not await repeat_until( self.chunk_sender(idx, chunk), event, self.period.value, self.timeout.value, ): success = False self.awaiting_acks.pop(idx, None) break idx += 1 return success
async def _process_read_request(self, path: Path, mode: str) -> None: """ Service a read request by sending file chunk data. """ async with AsyncExitStack() as stack: # Claim read lock and ignore cancellation. stack.enter_context(suppress(asyncio.CancelledError)) await stack.enter_async_context(self.lock) success = await self.serve_file(path) self.logger.info( "%s to serve (%s) '%s' to %s:%d.", "Succeeded" if success else "Failed", mode, FileInfo.from_file(path), self.addr[0], self.addr[1], )
[docs] def get_path(self, filename: str) -> Path: """Get a path from a filename.""" return self._path.joinpath(filename)
[docs] def handle_read_request( self, filename: str, mode: str ) -> Optional[asyncio.Task[None]]: """Handle a read-request message.""" path = self.get_path(filename) # Ensure we can service this request. if not self._check_exists(path) or not self._check_permission( path, "rb" ): return None return asyncio.create_task(self._process_read_request(path, mode))
def _check_permission(self, path: Path, mode: str) -> bool: """ Check if a path can be opened in the provided mode, send an error if not. """ result = False try: with path.open(mode): pass result = True except PermissionError: self.error_sender( TftpErrorCode.ACCESS_VIOLATION, f"Can't open={mode} '{path}'", self.addr, ) return result def _check_exists(self, path: Path) -> bool: """Check if a file exists, send an error if not.""" result = path.is_file() if not result: self.error_sender( TftpErrorCode.FILE_NOT_FOUND, f"Path '{path}' is not a file", self.addr, ) return result