Source code for mtc.base.struct
import io
import textwrap
import types
import typing
from typing import NamedTuple, Self
from .parser import Parser, ParserError
class Field(NamedTuple):
name: str
data_type: type[Parser]
[docs]
class StructMeta(type):
"""
The metaclass for :class:`Struct`. This is what enables the dataclass-like behavior of :class:`Struct`, but from
inheritance instead of a class decorator. It reads the class annotation and instantiates the fields accordingly, in
the order defined. It also defines :attr:`__slots__` on the inherited classes to reduce access time and memory usage.
All the metadata processed here is stored in the :attr:`_fields` attribute. For example, if you define a class like
.. code-block::
class HashEmptyInput(Struct):
hash_head: HashHead
index: UInt64
level: UInt8
Then HashEmptyInput._fields will be
.. code-block::
HashEmptyInput._fields = [
Field(name = 'hash_head', data_type = HashHead),
Field(name = 'index', data_type = UInt64),
Field(name = 'level', data_type = UInt8)
]
where :class:`Field` is a named tuple.
"""
def __new__(cls, name, bases, attrs, **kwargs):
annotations = attrs.get("__annotations__")
if annotations is None:
raise AttributeError("Struct is defined without any field")
fields = []
slots = []
for field_name, data_type in annotations.items():
if field_name == "_fields":
continue
if isinstance(data_type, types.UnionType):
for t in typing.get_args(data_type):
if not issubclass(t, Parser):
raise TypeError("Member of union must be a subclass of parser")
else:
if not isinstance(data_type, type):
raise TypeError("Struct fields must be a class or a union")
if not issubclass(data_type, Parser):
raise TypeError("Struct fields must be a subclass of parser")
fields.append(Field(field_name, data_type))
slots.append(field_name)
# use slots to reduce memory footprint and slightly increase access speed
cls_ = super().__new__(cls, name, bases, {**attrs, "__slots__": slots}, **kwargs)
cls_._fields = fields # type: ignore[attr-defined]
return cls_
[docs]
class Struct(Parser, metaclass=StructMeta):
"""
Implements a struct similar to how it works in C. With this class, you can define structs as simple as
.. code-block::
class Assertion(Struct):
subject_type: SubjectType
subject_info: SubjectInfo
claims: ClaimList
"""
_fields: list[Field] = []
def __init__(self, /, *value: Parser) -> None:
super().__setattr__("_bytes_cache", None)
super().__setattr__("value", list(value))
[docs]
@classmethod
def parse(cls, stream: io.BufferedIOBase) -> Self:
parsed = []
for f in cls._fields:
if isinstance(f.data_type, types.UnionType):
for d_type in typing.get_args(f.data_type):
initial = stream.tell()
try:
res = d_type.parse(stream)
except ParserError:
# revert attempt
stream.seek(initial)
else:
parsed.append(res)
break
else:
raise cls.ParsingError(initial, initial, "Cannot decode data as any datatype of the union")
else:
parsed.append(f.data_type.parse(stream))
return cls(*parsed)
[docs]
@classmethod
def skip(cls, stream: io.BufferedIOBase) -> None:
for f in cls._fields:
if isinstance(f.data_type, types.UnionType):
raise NotImplementedError(f"Skipping unions must be implemented in subclass (missing in {cls.__name__})")
else:
f.data_type.skip(stream)
[docs]
def to_bytes(self) -> bytes:
if self._bytes_cache is None:
# using BytesIO because repeated byte concatenation is very slow
bio = io.BytesIO()
for v in self.value:
bio.write(v.to_bytes())
super().__setattr__("_bytes_cache", bio.getvalue())
return self._bytes_cache # type:ignore[return-value]
[docs]
def print(self) -> str:
header = "-" * 20 + f"Struct {self.__class__.__name__} ({len(self)})" + "-" * 20 + "\n"
footer = "-" * 18 + f"End struct {self.__class__.__name__}" + "-" * 18
inner = ""
for v in self.value:
inner += v.print() + "\n"
return header + textwrap.indent(inner, "\t") + footer
def __getattr__(self, item: str):
for i, f in enumerate(self._fields):
if f.name == item:
return self.value[i]
else:
raise AttributeError
def __setattr__(self, key, value):
if self._bytes_cache is not None:
raise AttributeError("Cannot set attrs after to_bytes() is called")
for i, f in enumerate(self._fields):
if f.name == key:
self.value[i] = value
super().__setattr__(key, value)
[docs]
def validate(self) -> None:
"""
Checks if all fields passed into the struct initializer are of the correct type in the correct order
"""
if len(self.value) != len(self._fields):
raise ValueError("Input to a struct must have the same length as struct definition")
for i, v in enumerate(self.value):
name, data_type = self._fields[i]
if not isinstance(v, data_type):
raise ValueError(
f"Item {i} of input to {self.__class__.__name__} is not of type {data_type.__name__} (found {v.__class__.__name__})")
__all__ = ["Struct"]