from __future__ import annotations

import functools
import inspect
from typing import Union, Type, Callable, Sequence, Optional, Set, Iterator, Tuple

from torch import nn, Tensor

from analogvnn.backward.BackwardFunction import BackwardFunction
from analogvnn.backward.BackwardModule import BackwardModule
from analogvnn.graph.ArgsKwargs import ArgsKwargs, ArgsKwargsOutput
from analogvnn.utils.common_types import TENSORS

# Public API of this module: ``from <module> import *`` exports only ``Layer``.
__all__ = ['Layer']

def __nn_Module_init_updated__(function: Callable) -> Callable:
    """Wrapper for nn.Module.__init__ to support multiple parent classes at same time.

    The returned ``__init__`` looks up the class that follows ``nn.Module`` in the
    instance's MRO. While ``function`` runs, that class's ``__init__`` is replaced by
    a no-op, so any call to it from inside ``function`` does nothing. Afterwards the
    real ``__init__`` is restored and then called explicitly with no arguments. If
    ``nn.Module`` is directly followed by ``object``, ``function`` is just called.

    Args:
        function (Callable): nn.Module.__init__ function

    Returns:
        Callable: Wrapped function
    """

    # noinspection PyUnusedLocal
    def _temp(*args, **kwargs) -> ...:
        """No-op stand-in for the next MRO class's ``__init__`` while ``function`` runs."""

    @functools.wraps(function)
    def new_function(self, *args, **kwargs):
        super_init = None
        next_mro_index = self.__class__.__mro__.index(nn.Module) + 1
        next_mro_class = self.__class__.__mro__[next_mro_index]
        needs_patch = next_mro_class is not object

        if needs_patch:
            super_init = next_mro_class.__init__
            next_mro_class.__init__ = _temp

        try:
            function(self, *args, **kwargs)
        finally:
            # Always restore the real __init__, even when ``function`` raises.
            # Otherwise the class would be left with a permanent no-op __init__.
            if needs_patch:
                next_mro_class.__init__ = super_init

        if needs_patch:
            # Now initialize the next class in the MRO exactly once.
            super(nn.Module, self).__init__()

    return new_function

if not hasattr(nn.Module, 'call_super_init'):
    # This PyTorch build has no ``nn.Module.call_super_init`` switch. Patch
    # ``nn.Module.__init__`` so that it also initializes the class that follows
    # ``nn.Module`` in the MRO, which lets a subclass combine nn.Module with
    # another parent class (see ``__nn_Module_init_updated__``).
    nn.Module.__init__ = __nn_Module_init_updated__(nn.Module.__init__)

class Layer(nn.Module):
    """Base class for analog neural network modules.

    Attributes:
        _inputs (Union[None, ArgsKwargs]): Inputs of the layer (recorded while training).
        _outputs (Union[None, Tensor, Sequence[Tensor]]): Outputs of the layer (recorded while training).
        _backward_module (Optional[BackwardModule]): Backward module of the layer.
        _use_autograd_graph (bool): If True, the autograd graph is used to calculate the gradients.
        call_super_init (bool): If True, the super class __init__ of nn.Module is called
    """

    _inputs: Union[None, ArgsKwargs]
    _outputs: Union[None, Tensor, Sequence[Tensor]]
    _backward_module: Optional[BackwardModule]
    _use_autograd_graph: bool
    call_super_init: bool = True

    def __init__(self):
        """Initializes the layer with no recorded inputs/outputs and no backward module."""
        super().__init__()
        self._inputs = None
        self._outputs = None
        self._backward_module = None
        self._use_autograd_graph = False

    def __call__(self, *inputs, **kwargs):
        """Calls the forward pass of neural network layer.

        First installs the backward-module forward wrapper (see ``_forward_wrapper``).
        Then it runs ``nn.Module.__call__``. In training mode it also records the
        inputs and outputs of this call.

        Args:
            *inputs: Inputs of the forward pass.
            **kwargs: Keyword arguments of the forward pass.

        Returns:
            The outputs returned by ``nn.Module.__call__``.
        """
        self._forward_wrapper(self.forward)
        outputs = super().__call__(*inputs, **kwargs)
        if self.training:
            # Only the most recent call is kept. It is exposed through the
            # ``inputs``/``outputs`` properties, presumably for the backward
            # pass — confirm against BackwardModule.
            self._inputs = ArgsKwargs(args=inputs, kwargs=kwargs)
            self._outputs = outputs

        return outputs

    @property
    def use_autograd_graph(self) -> bool:
        """If True, the autograd graph is used to calculate the gradients.

        The value is forwarded as ``to_apply`` to ``backward_function.auto_apply``.

        Returns:
            bool: use_autograd_graph.
        """
        return self._use_autograd_graph

    @use_autograd_graph.setter
    def use_autograd_graph(self, use_autograd_graph: bool):
        """Sets the use_autograd_graph attribute.

        Args:
            use_autograd_graph (bool): use_autograd_graph.
        """
        self._use_autograd_graph = use_autograd_graph

    @property
    def inputs(self) -> ArgsKwargsOutput:
        """Inputs of the layer, as recorded by the last training-mode call.

        Returns:
            ArgsKwargsOutput: inputs.
        """
        return ArgsKwargs.from_args_kwargs_object(self._inputs)

    @property
    def outputs(self) -> Union[None, Tensor, Sequence[Tensor]]:
        """Outputs of the layer, as recorded by the last training-mode call.

        Returns:
            Union[None, Tensor, Sequence[Tensor]]: outputs.
        """
        return self._outputs

    @property
    def backward_function(self) -> Union[None, Callable, BackwardModule]:
        """Backward module of the layer.

        Returns:
            Union[None, Callable, BackwardModule]: the explicitly set backward module.
            If none is set, ``self`` when the layer is itself a ``BackwardModule``.
            Otherwise ``None``.
        """
        if self._backward_module is not None:
            return self._backward_module

        if isinstance(self, BackwardModule):
            return self

        return None

    @backward_function.setter
    def backward_function(self, function: Union[BackwardModule, Type[BackwardModule], Callable]):
        """Sets the backward_function attribute.

        Args:
            function (Union[BackwardModule, Type[BackwardModule], Callable]): backward_function.
        """
        self.set_backward_function(function)

    def set_backward_function(self, backward_class: Union[Callable, BackwardModule, Type[BackwardModule]]) -> Layer:
        """Sets the backward_function attribute.

        Args:
            backward_class (Union[Callable, BackwardModule, Type[BackwardModule]]): backward_function.
                - a ``BackwardModule`` subclass: instantiated with this layer;
                - a ``BackwardModule`` instance: bound to this layer via ``set_layer``;
                - any other callable: wrapped in ``BackwardFunction``;
                - this layer itself: nothing is changed.

        Returns:
            Layer: self.

        Raises:
            TypeError: If backward_class is not a callable or BackwardModule.
        """
        if backward_class is self:
            return self

        if inspect.isclass(backward_class) and issubclass(backward_class, BackwardModule):
            self._backward_module = backward_class(self)
        elif isinstance(backward_class, BackwardModule):
            backward_class.set_layer(self)
            self._backward_module = backward_class
        elif callable(backward_class):
            self._backward_module = BackwardFunction(backward_class, self)
        else:
            raise TypeError(f'Backward Module is not set for "{self}"')

        return self

    def named_registered_children(
            self,
            memo: Optional[Set[nn.Module]] = None
    ) -> Iterator[Tuple[str, nn.Module]]:
        """Returns an iterator over immediate registered children modules.

        The layer itself, its backward function and any module already present in
        ``memo`` are skipped.

        Args:
            memo: a memo to store the set of modules already added to the result.
                The layer and its backward function are added to it.

        Yields:
            (str, Module): Tuple containing a name and child module

        Note:
            Duplicate modules are returned only once.
        """
        if memo is None:
            memo = set()

        memo.add(self)
        memo.add(self.backward_function)

        for name, module in self.named_children():
            if module in memo:
                continue
            yield name, module

    def registered_children(self) -> Iterator[nn.Module]:
        r"""Returns an iterator over immediate registered children modules.

        Same selection as ``named_registered_children``, without the names.

        Yields:
            nn.Module: a module in the network

        Note:
            Duplicate modules are returned only once.
        """
        for _, module in self.named_registered_children():
            yield module

    def _forward_wrapper(self, function: Callable) -> Callable:
        """Wrapper for the forward function.

        If the backward function is a ``BackwardModule``, this method replaces
        ``self.forward`` with a wrapper that routes calls through
        ``backward_function.auto_apply``. Calling it on an already wrapped forward
        returns that wrapper unchanged.

        Args:
            function (Callable): Forward function.

        Returns:
            Callable: Wrapped function, or ``function`` itself when no wrapping applies.
        """
        # noinspection PyUnresolvedReferences
        if getattr(function, '__wrapper__', None) is Layer._forward_wrapper:
            return function

        if not isinstance(self.backward_function, BackwardModule):
            return function

        if not self.backward_function.has_forward():
            # Give the backward module the layer's own, still unwrapped, forward.
            self.backward_function.forward = self.forward

        @functools.wraps(function)
        def new_forward(*args, **kwargs):
            # Attributes are looked up at call time, so later changes to
            # backward_function / use_autograd_graph take effect.
            return self.backward_function.auto_apply(
                *args,
                to_apply=self.use_autograd_graph,
                **kwargs
            )

        # ``_call_impl_forward`` unwraps through ``__wrapped__``.
        new_forward.__wrapped__ = function
        # Marker used above to detect an already wrapped forward.
        new_forward.__wrapper__ = Layer._forward_wrapper
        self.forward = new_forward
        return new_forward

    def _call_impl_forward(self, *args: Tensor, **kwargs: Tensor) -> TENSORS:
        """Calls the forward pass of the layer.

        This bypasses the wrapper installed by ``_forward_wrapper``. It uses the
        backward module's own ``forward`` when it has one, otherwise
        ``self.forward``, and strips one ``__wrapped__`` level if present.

        Args:
            *args: Inputs of the forward pass.
            **kwargs: Keyword arguments of the forward pass.

        Returns:
            TENSORS: Outputs of the forward pass.
        """
        if isinstance(self.backward_function, BackwardModule) and self.backward_function.has_forward():
            forward_functions = self.backward_function.forward
        else:
            forward_functions = self.forward

        if hasattr(forward_functions, '__wrapped__'):
            forward_functions = forward_functions.__wrapped__

        return forward_functions(*args, **kwargs)