Skip to content

General-purpose

beignet.features.Feature

Bases: Tensor

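Feature: base tensor subclass for beignet features. Operations on a Feature return plain tensors, except for the operators listed in Feature._NO_WRAPPING_EXCEPTIONS.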

Source code in src/beignet/features/_feature.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
class Feature(Tensor):
    """
    Base :class:`~torch.Tensor` subclass for beignet features.

    Operations on a :class:`Feature` do not propagate the subclass type to
    their outputs (see :meth:`__torch_function__`). The only exceptions are
    the operators listed in :attr:`Feature._NO_WRAPPING_EXCEPTIONS`, most of
    which re-wrap their output with :meth:`wrap_like`. Subclasses must
    therefore implement :meth:`wrap_like`.
    """

    # Cached ``beignet.transforms.functional`` module. It is imported lazily
    # by the ``_f`` property and shared across all instances.
    __f: Optional[ModuleType] = None

    @staticmethod
    def _to_tensor(
        data: Any,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> Tensor:
        """
        Convert ``data`` to a :class:`~torch.Tensor` via :func:`torch.as_tensor`.

        Parameters
        ----------
        data : Any
            Anything accepted by :func:`torch.as_tensor`.
        dtype : torch.dtype, optional
            Desired data type, forwarded to :func:`torch.as_tensor`.
        device : torch.device | str | int, optional
            Desired device, forwarded to :func:`torch.as_tensor`.
        requires_grad : bool, optional
            Whether the result should record gradients. If ``None``, the
            flag is inherited from ``data`` when it is a tensor; otherwise
            it is ``False``.

        Returns
        -------
        Tensor
            The converted tensor with ``requires_grad`` set accordingly.
        """
        if requires_grad is None:
            requires_grad = (
                data.requires_grad if isinstance(data, Tensor) else False
            )

        tensor = torch.as_tensor(data, dtype=dtype, device=device)

        return tensor.requires_grad_(requires_grad)

    @classmethod
    def wrap_like(cls: Type[F], other: F, tensor: Tensor) -> F:
        """
        Wrap ``tensor`` as an instance of ``cls``, modeled on ``other``.

        The entries of :attr:`_NO_WRAPPING_EXCEPTIONS` call this with the
        :class:`Feature` the operation was invoked on as ``other`` and the
        operation's output as ``tensor``.

        Raises
        ------
        NotImplementedError
            Always. Subclasses must override this method.
        """
        raise NotImplementedError

    # Operators whose outputs keep the feature type. ``clone``, ``detach``
    # and ``to`` re-wrap their output with ``wrap_like`` so that a subclass
    # can rebuild itself from the input instance.
    #
    # NOTE:
    #   We don’t need to wrap the output of ``tensor.requires_grad_``,
    #   because it’s an inplace operation and automatically retains its type.
    _NO_WRAPPING_EXCEPTIONS = {
        Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
        Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
        Tensor.requires_grad_: lambda cls, input, output: output,
        Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
    }

    @classmethod
    def __torch_function__(
        cls,
        func: Callable[..., Tensor],
        types: Tuple[Type[Tensor], ...],
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> Tensor:
        """
        Dispatch ``func`` without coercing its output into ``cls``.

        By default, operations on a :class:`~torch.Tensor` subclass retain
        the custom tensor type. For :class:`Feature`, this creates two
        problems:

            1.  :class:`Feature` may require metadata and the default wrapping,
                i.e., ``return cls(func(*args, **kwargs))``, will fail.

            2.  For most operations, there is no way of knowing if an input
                type is still valid for the output type.

        To address these two problems, :class:`Feature` disables automatic
        output wrapping for most operators. The exceptions are listed in
        :attr:`Feature._NO_WRAPPING_EXCEPTIONS`.

        Parameters
        ----------
        func : Callable[..., Tensor]
            The torch function or tensor method being overridden.
        types : Tuple[Type[Tensor], ...]
            Types of the arguments that implement ``__torch_function__``.
        args : Sequence[Any]
            Positional arguments for ``func``.
        kwargs : Mapping[str, Any], optional
            Keyword arguments for ``func``.

        Returns
        -------
        Tensor
            The output of ``func``. If ``func`` is a listed exception and
            ``args[0]`` is an instance of ``cls``, the exception's wrapper
            decides the result. Otherwise, an output that is an instance of
            ``cls`` is converted back to a plain :class:`~torch.Tensor`.
            ``NotImplemented`` is returned if ``cls`` is not a subclass of
            every type in ``types``.
        """
        # NOTE:
        #   ``super().__torch_function__`` has no hook to prevent the
        #   coercing of the output type into the input type so this
        #   functionality is reimplemented.
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        with DisableTorchFunctionSubclass():
            output = func(*args, **(kwargs or {}))

            wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)

            # NOTE:
            #   Besides ``func`` being an exception, this class method
            #   requires the input operand, i.e. ``args[0]``, to be an
            #   instance of the class that ``__torch_function__`` invoked.
            #   The ``__torch_function__`` protocol will invoke this method
            #   on each type in the computation by walking the method
            #   resolution order (MRO). For example,
            #   ``Tensor(...).to(beignet.features.Foo( ...))`` invokes
            #   ``beignet.features.Foo.__torch_function__`` with
            #   ``args = (Tensor(), beignet.features.Foo())``.
            #   Without this guard, ``Tensor`` would be wrapped into
            #   ``beignet.features.Foo``. ``args`` may also be empty when
            #   ``func`` is called with keyword arguments only, in which
            #   case there is no input operand to wrap like.
            if wrapper is not None and args and isinstance(args[0], cls):
                return wrapper(cls, args[0], output)

            # NOTE:
            #   Because inplace ``func``’s, canonically identified with a
            #   trailing underscore in their name, e.g., ``.add_(...)``,
            #   retain their input type, they need to be unwrapped.
            if isinstance(output, cls):
                return output.as_subclass(Tensor)

            return output

    def _make_repr(self, **kwargs: Any) -> str:
        """
        Return the tensor ``repr`` with ``key=value`` pairs appended.

        The pairs from ``kwargs`` are inserted before the closing
        parenthesis of the tensor ``repr``. With no ``kwargs``, the tensor
        ``repr`` is returned unchanged instead of one ending in a dangling
        ``", )"``.
        """
        tensor_repr = super().__repr__()

        if not kwargs:
            return tensor_repr

        items = ", ".join(f"{key}={value}" for key, value in kwargs.items())

        return f"{tensor_repr[:-1]}, {items})"

    @property
    def _f(self) -> ModuleType:
        """The ``beignet.transforms.functional`` module, imported lazily."""
        # NOTE:
        #   Lazy import of ``beignet.transforms.functional`` to bypass the
        #   ``ImportError`` raised by the circular import. The
        #   ``beignet.transforms.functional`` import is deferred until the
        #   functional module is referenced and it’s shared across all
        #   instances of the class.
        if Feature.__f is None:
            import beignet.transforms.functional

            Feature.__f = beignet.transforms.functional

        return Feature.__f

    # The metadata properties below read the underlying ``Tensor`` attribute
    # with subclass ``__torch_function__`` dispatch disabled, so these
    # lookups do not route through ``Feature.__torch_function__``.

    @property
    def device(self) -> _device:  # type: ignore[override]
        """The device of the underlying tensor."""
        with DisableTorchFunctionSubclass():
            return super().device

    @property
    def ndim(self) -> int:  # type: ignore[override]
        """The number of dimensions of the underlying tensor."""
        with DisableTorchFunctionSubclass():
            return super().ndim

    @property
    def dtype(self) -> _dtype:  # type: ignore[override]
        """The data type of the underlying tensor."""
        with DisableTorchFunctionSubclass():
            return super().dtype

    @property
    def shape(self) -> _size:  # type: ignore[override]
        """The shape of the underlying tensor."""
        with DisableTorchFunctionSubclass():
            return super().shape

    def __deepcopy__(self: F, memo: Dict[int, Any]) -> F:
        """
        Return a detached clone of ``self`` with the same ``requires_grad``.

        ``memo`` is accepted for the :func:`copy.deepcopy` protocol but is
        not used.
        """
        # NOTE:
        #   Detach, because ``deepcopy(Tensor)``, unlike ``Tensor.clone``,
        #   isn’t added to the computational graph.

        # NOTE:
        #   Because a side-effect of detaching is clearing
        #   ``Tensor.requires_grad``, it’s refilled before returning.

        # NOTE:
        #   Deep-copying of metadata isn’t explicitly handled.
        return (
            self.detach()
            .clone()
            .requires_grad_(
                self.requires_grad,
            )
        )  # type: ignore[return-value]
__torch_function__ classmethod
__torch_function__(func, types, args=(), kwargs=None)

By default, operations on a :class:`~Tensor` subclass retain the custom tensor type. For :class:`Feature`, this creates two problems:

1.  :class:`Feature` may require metadata and the default wrapping,
    i.e., ``return cls(func(*args, **kwargs))``, will fail.

2.  For most operations, there is no way of knowing if an input
    type is still valid for the output type.

To address these two problems, :class:`Feature` disables automatic output wrapping for most operators. The exceptions are listed in :attr:`Feature._NO_WRAPPING_EXCEPTIONS`.

Source code in src/beignet/features/_feature.py
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
@classmethod
def __torch_function__(
    cls,
    func: Callable[..., Tensor],
    types: Tuple[Type[Tensor], ...],
    args: Sequence[Any] = (),
    kwargs: Optional[Mapping[str, Any]] = None,
) -> Tensor:
    """
    Dispatch ``func`` without coercing its output into ``cls``.

    By default, operations on a :class:`~torch.Tensor` subclass retain the
    custom tensor type. For :class:`Feature`, this creates two problems:

        1.  :class:`Feature` may require metadata and the default wrapping,
            i.e., ``return cls(func(*args, **kwargs))``, will fail.

        2.  For most operations, there is no way of knowing if an input
            type is still valid for the output type.

    To address these two problems, :class:`Feature` disables automatic
    output wrapping for most operators. The exceptions are listed in
    :attr:`Feature._NO_WRAPPING_EXCEPTIONS`.

    Parameters
    ----------
    func : Callable[..., Tensor]
        The torch function or tensor method being overridden.
    types : Tuple[Type[Tensor], ...]
        Types of the arguments that implement ``__torch_function__``.
    args : Sequence[Any]
        Positional arguments for ``func``.
    kwargs : Mapping[str, Any], optional
        Keyword arguments for ``func``.

    Returns
    -------
    Tensor
        The output of ``func``. If ``func`` is a listed exception and
        ``args[0]`` is an instance of ``cls``, the exception's wrapper decides
        the result. Otherwise, an output that is an instance of ``cls`` is
        converted back to a plain :class:`~torch.Tensor`. ``NotImplemented``
        is returned if ``cls`` is not a subclass of every type in ``types``.
    """
    # NOTE:
    #   ``super().__torch_function__`` has no hook to prevent the
    #   coercing of the output type into the input type so this
    #   functionality is reimplemented.
    if not all(issubclass(cls, t) for t in types):
        return NotImplemented

    with DisableTorchFunctionSubclass():
        output = func(*args, **(kwargs or {}))

        wrapper = cls._NO_WRAPPING_EXCEPTIONS.get(func)

        # NOTE:
        #   Besides ``func`` being an exception, this class method
        #   requires the input operand, i.e. ``args[0]``, to be an
        #   instance of the class that ``__torch_function__`` invoked.
        #   The ``__torch_function__`` protocol will invoke this method
        #   on each type in the computation by walking the method
        #   resolution order (MRO). For example,
        #   ``Tensor(...).to(beignet.features.Foo( ...))`` invokes
        #   ``beignet.features.Foo.__torch_function__`` with
        #   ``args = (Tensor(), beignet.features.Foo())``.
        #   Without this guard, ``Tensor`` would be wrapped into
        #   ``beignet.features.Foo``. ``args`` may also be empty when
        #   ``func`` is called with keyword arguments only, in which case
        #   there is no input operand to wrap like.
        if wrapper is not None and args and isinstance(args[0], cls):
            return wrapper(cls, args[0], output)

        # NOTE:
        #   Because inplace ``func``’s, canonically identified with a
        #   trailing underscore in their name, e.g., ``.add_(...)``,
        #   retain their input type, they need to be unwrapped.
        if isinstance(output, cls):
            return output.as_subclass(Tensor)

        return output