Skip to content

General-purpose

beignet.transforms.Transform

Bases: Module

Source code in src/beignet/transforms/_transform.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
class Transform(Module):
    # Class attribute defining transformed types. Other types are
    # passed-through without any transformation
    #
    # We support both Types and callables that are able to do further checks
    # on the type of the input.
    _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (
        Tensor,
        str,
    )

    def __init__(self):
        super().__init__()

    def _check_inputs(self, inputs: List[Any]):
        """
        Parameters
        ----------
        inputs : List[Any]
            The inputs to be checked.
        """
        raise NotImplementedError

    def _get_params(self, inputs: List[Any]) -> Dict[str, Any]:
        """
        Parameters
        ----------
        inputs: List[Any]
            The list of input objects.

        Returns
        -------
        Dict[str, Any]
            A dictionary containing the parameters of the method.
        """
        return dict()

    def _transform(self, input: Any, parameters: Dict[str, Any]) -> Any:
        """
        Parameters
        ----------
        input : Any
            The input data to be transformed.

        parameters : Dict[str, Any]
            A dictionary containing the parameters for the transformation.

        Returns
        -------
        Any
            The transformed output data.

        Note
        ----
        This method is expected to be implemented by a subclass of the
        ``Transform`` class. It raises a ``NotImplementedError`` to indicate
        that the subclass must provide its own implementation of the
        ``_transform`` method.
        """
        raise NotImplementedError

    def forward(self, *inputs: Any) -> Any:
        """
        Parameters
        ----------
        inputs : tuple
            The input tensors to be transformed. Can have multiple tensors.

        Returns
        -------
        tensor
            The transformed input tensors.
        """
        if len(inputs) > 1:
            flattened, spec = torch.utils._pytree.tree_flatten(inputs)
        else:
            flattened, spec = torch.utils._pytree.tree_flatten(inputs[0])

        self._check_inputs(flattened)

        transformables = self._transformables(flattened)

        inputs = []

        for x, transformable in zip(flattened, transformables, strict=False):
            if transformable:
                inputs = [*inputs, x]

        inputs = self._get_params(inputs)

        ys = []

        for x, transformable in zip(flattened, transformables, strict=False):
            if transformable:
                y = self._transform(x, inputs)
            else:
                y = x

            ys = [*ys, y]

        return torch.utils._pytree.tree_unflatten(ys, spec)

    def _transformables(self, inputs: List[Any]) -> List[bool]:
        """
        Parameters
        ----------
        inputs : List[Any]
            List of input objects to be checked for transformation.

        Returns
        -------
        transformables : List[bool]
            List indicating whether each input object is transformable or not.
            ``True`` if an object can be transformed, ``False`` otherwise.
        """
        # NOTE:
        #   Heuristic for transforming anonymous tensor inputs:
        #
        #       1.  Anonymous tensors, i.e., non-:class:`Feature` tensors,
        #           are passed through if there is an explicit feature in the
        #           sample.
        #
        #       2.  If there is no explicit feature the sample, only the first
        #           encountered anonymous tensor is transformed, while the rest
        #           are passed. The order is defined by the returned
        #           `flat_inputs` of `tree_flatten`, which recurses
        #           depth-first through the input.
        #
        #   The heuristic should work well for most people in practice. The
        #   only case where it doesn't is if someone tries to transform
        #   multiple anonymous features at the same time, expecting them all
        #   to be treated as named features.
        transformables = []

        transform_anonymous_feature = False

        for input in inputs:
            for t in [str]:
                if isinstance(input, t) if isinstance(t, type) else t(input):
                    transform_anonymous_feature = True

        for input in inputs:
            transformable = True

            checked = False

            for t in self._transformed_types:
                if isinstance(input, t) if isinstance(t, type) else t(input):
                    checked = True

            if not checked:
                transformable = False
            elif isinstance(input, Tensor) and not isinstance(input, Feature):
                if transform_anonymous_feature:
                    transform_anonymous_feature = False
                else:
                    transformable = False

            transformables = [*transformables, transformable]

        return transformables

    def extra_repr(self) -> str:
        """
        Returns
        -------
        str
            A string representation of the extra configuration attributes of
            the Transform module.
        """
        extra = []

        for name, value in self.__dict__.items():
            if name.startswith("_") or name == "training":
                continue

            if not isinstance(
                value,
                (Enum, bool, float, int, list, str, tuple),
            ):
                continue

            extra.append(f"{name}={value}")

        return ", ".join(extra)
extra_repr
extra_repr()

Returns:

Type Description
str

A string representation of the extra configuration attributes of the Transform module.

Source code in src/beignet/transforms/_transform.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def extra_repr(self) -> str:
    """
    Summarise the public configuration attributes of the module.

    Returns
    -------
    str
        ``name=value`` pairs, joined with ``", "``, for every entry of
        ``self.__dict__`` whose name neither starts with an underscore nor
        equals ``"training"`` and whose value is an ``Enum``, ``bool``,
        ``float``, ``int``, ``list``, ``str`` or ``tuple``.
    """
    reportable = (Enum, bool, float, int, list, str, tuple)

    return ", ".join(
        f"{name}={value}"
        for name, value in self.__dict__.items()
        if not name.startswith("_")
        and name != "training"
        and isinstance(value, reportable)
    )
forward
forward(*inputs)

Parameters:

Name Type Description Default
inputs tuple

The input tensors to be transformed. Can have multiple tensors.

()

Returns:

Type Description
tensor

The transformed input tensors.

Source code in src/beignet/transforms/_transform.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
def forward(self, *inputs: Any) -> Any:
    """
    Parameters
    ----------
    inputs : tuple
        The input tensors to be transformed. Can have multiple tensors.

    Returns
    -------
    tensor
        The transformed input tensors.
    """
    if len(inputs) > 1:
        flattened, spec = torch.utils._pytree.tree_flatten(inputs)
    else:
        flattened, spec = torch.utils._pytree.tree_flatten(inputs[0])

    self._check_inputs(flattened)

    transformables = self._transformables(flattened)

    inputs = []

    for x, transformable in zip(flattened, transformables, strict=False):
        if transformable:
            inputs = [*inputs, x]

    inputs = self._get_params(inputs)

    ys = []

    for x, transformable in zip(flattened, transformables, strict=False):
        if transformable:
            y = self._transform(x, inputs)
        else:
            y = x

        ys = [*ys, y]

    return torch.utils._pytree.tree_unflatten(ys, spec)

beignet.transforms.Lambda

Bases: Transform

Source code in src/beignet/transforms/_lambda.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class Lambda(Transform):
    """
    Transform that applies a user-supplied function to matching inputs.
    """

    # ``object`` matches every value in an ``isinstance`` check; the
    # per-instance ``_types`` tuple decides whether ``fn`` is applied.
    _transformed_types = (object,)

    def __init__(self, fn: Callable[[Any], Any], *types: Type):
        """
        Parameters
        ----------

        fn : Callable[[Any], Any]
            The function applied to each matching input.

        types : Type
            The types an input must be an instance of for ``fn`` to be
            applied to it. When omitted, the class-level
            ``_transformed_types`` are used instead.
        """
        super().__init__()

        self._fn = fn

        self._types = types if types else self._transformed_types

    def _transform(self, input: Any, parameters: Dict[str, Any]) -> Any:
        """
        Apply the wrapped function to ``input`` if its type matches.

        Parameters
        ----------

        input : Any
            The input value to be transformed.

        parameters : Dict[str, Any]
            Accepted for interface compatibility; not used here.

        Returns
        -------
        Any
            ``fn(input)`` when ``input`` is an instance of one of the
            configured types, otherwise ``input`` unchanged.

        """
        if not isinstance(input, self._types):
            return input

        return self._fn(input)

    def extra_repr(self) -> str:
        """
        Get a string representation of the Lambda transform.

        Returns
        -------
        str
            The function's ``__name__`` (when it has a non-empty one)
            followed by ``types=[...]`` listing the names of the configured
            types, joined with ``", "``.
        """
        fn_name = getattr(self._fn, "__name__", None)

        parts = [fn_name] if fn_name else []

        type_names = [t.__name__ for t in self._types]

        parts.append(f"types={type_names}")

        return ", ".join(parts)
__init__
__init__(fn, *types)

Parameters:

Name Type Description Default
fn Callable[[Any], Any]

The function to be used as the transformation function.

required
types Type

The types of the input arguments that will be passed to the function. If no types are provided, the default transformed types will be used.

()
Source code in src/beignet/transforms/_lambda.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def __init__(self, fn: Callable[[Any], Any], *types: Type):
    """
    Parameters
    ----------

    fn : Callable[[Any], Any]
        The function applied to each matching input.

    types : Type
        The types an input must be an instance of for ``fn`` to be applied
        to it. When omitted, the class-level ``_transformed_types`` are
        used instead.
    """
    super().__init__()

    self._fn = fn

    # An empty ``types`` tuple is falsy, which selects the class default.
    self._types = types if types else self._transformed_types
extra_repr
extra_repr()

Get a string representation of the Lambda transform.

Returns:

Type Description
str

A string representation of the Lambda transform, including the function name and types.

Source code in src/beignet/transforms/_lambda.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def extra_repr(self) -> str:
    """
    Get a string representation of the Lambda transform.

    Returns
    -------
    str
        The function's ``__name__`` (when it has a non-empty one) followed
        by ``types=[...]`` listing the names of the configured types,
        joined with ``", "``.
    """
    fn_name = getattr(self._fn, "__name__", None)

    parts = [fn_name] if fn_name else []

    type_names = [t.__name__ for t in self._types]

    parts.append(f"types={type_names}")

    return ", ".join(parts)