
TransformSelector

TransformSelector(transform, backend='numpy', *args, **kwargs)

Factory function that creates a transform instance, selected by name, for the specified backend.

Parameters:

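| Name | Type | Description | Default |
|------|------|-------------|---------|
| transform | str | The transform to use. Supported options: ["bpd", "magnitude_phase", "spectrogram"]. | required |
| backend | str | The backend to use. Supported options: ["numpy", "torch", "tensorflow"]. Defaults to "numpy". The value is forwarded unchanged to the transform's constructor. | 'numpy' |
| *args | Any | Additional positional arguments passed to the selected transform's constructor. | () |
| **kwargs | Any | Additional keyword arguments passed to the selected transform's constructor. | {} |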
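Because backend is the second positional parameter, an extra positional argument meant for the transform reaches *args only when backend is also given positionally before it; otherwise it is silently bound to backend. The sketch below reproduces the same signature and forwarding pattern with a hypothetical RecordingTransform stub (not part of libsegmenter) so the binding can be inspected without installing the library. The keyword frame_size is likewise a made-up option used only for illustration.

```python
from typing import Any, Dict, Tuple


class RecordingTransform:
    """Hypothetical stand-in that records the arguments it was built with."""

    def __init__(self, *args: Any, backend: str = "numpy", **kwargs: Any) -> None:
        self.args: Tuple[Any, ...] = args
        self.kwargs: Dict[str, Any] = kwargs
        self.backend = backend


def selector(transform: str, backend: str = "numpy", *args: Any, **kwargs: Any) -> Any:
    # Same signature and forwarding pattern as TransformSelector.
    if transform == "recording":
        return RecordingTransform(*args, **kwargs, backend=backend)
    raise ValueError(f"The '{transform}' transform is not known.")


# Intended: pass 512 to the transform. Actually: 512 is bound to `backend`.
wrong = selector("recording", 512)

# Safe: give the backend positionally before extra positional arguments ...
right = selector("recording", "numpy", 512)

# ... or pass transform options by keyword (frame_size is hypothetical).
by_keyword = selector("recording", backend="torch", frame_size=512)
```

Passing transform options by keyword avoids the ambiguity entirely, which is usually the safer calling convention for a factory with this signature.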

Returns:

| Type | Description |
|------|-------------|
| Any | An instance of the transform corresponding to the chosen backend. |

Raises:

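| Type | Description |
|------|-------------|
| ValueError | If an unknown transform is specified. The backend is forwarded unchanged to the transform's constructor, so any error for an unsupported backend is raised there. |
| NotImplementedError | If the backend is not implemented. |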
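Since the factory itself only checks the transform name, a caller that wants to fail fast on both arguments could validate them against the documented options before calling it. This is a minimal sketch; the constants SUPPORTED_TRANSFORMS and SUPPORTED_BACKENDS and the helper check_transform_args are hypothetical names copied from the lists above, not something libsegmenter is known to export.

```python
from typing import Tuple

# Hypothetical constants mirroring the documented supported options;
# libsegmenter itself may not export names like these.
SUPPORTED_TRANSFORMS: Tuple[str, ...] = ("bpd", "magnitude_phase", "spectrogram")
SUPPORTED_BACKENDS: Tuple[str, ...] = ("numpy", "torch", "tensorflow")


def check_transform_args(transform: str, backend: str = "numpy") -> None:
    """Raise ValueError early if either name is outside the documented options."""
    if transform not in SUPPORTED_TRANSFORMS:
        raise ValueError(
            f"Unknown transform '{transform}'; expected one of {SUPPORTED_TRANSFORMS}."
        )
    if backend not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend '{backend}'; expected one of {SUPPORTED_BACKENDS}."
        )
```

The trade-off is that these lists must be kept in sync with the library by hand; relying on the errors raised by the factory and the transform constructors avoids that duplication.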

Source code in src/libsegmenter/TransformSelector.py
from typing import Any


def TransformSelector(
    transform: str, backend: str = "numpy", *args: Any, **kwargs: Any
) -> Any:
    """
    Factory function that creates a transform instance, selected by name,
    for the specified backend.

    Args:
        transform (str): The transform to use. Supported options:
            ["bpd", "magnitude_phase", "spectrogram"].
        backend (str, optional): The backend to use. Supported options:
            ["numpy", "torch", "tensorflow"]. Defaults to "numpy".
            The value is forwarded unchanged to the transform's constructor.
        *args (Any): Additional positional arguments to pass to the transform.
        **kwargs (Any): Additional keyword arguments to pass to the transform.

    Returns:
        An instance of the transform corresponding to the chosen backend.

    Raises:
        ValueError: If an unknown transform is specified. Any error for an
            unsupported backend is raised by the transform's constructor.
        NotImplementedError: If the backend is not implemented.

    """
    if transform == "spectrogram":
        from libsegmenter.transforms.Spectrogram import Spectrogram

        return Spectrogram(*args, **kwargs, backend=backend)

    if transform == "magnitude_phase":
        from libsegmenter.transforms.MagnitudePhase import MagnitudePhase

        return MagnitudePhase(*args, **kwargs, backend=backend)

    if transform == "bpd":
        from libsegmenter.transforms.BPD import BPD

        return BPD(*args, **kwargs, backend=backend)

    raise ValueError(f"The '{transform}' transform is not known.")