Skip to content

SegmenterTorch

SegmenterTorch

Bases: Module

A PyTorch-based segmenter for input data using windowing techniques.

Supports Weighted Overlap-Add (WOLA) and Overlap-Add (OLA) methods.

Attributes:

Name Type Description
window Window

A class containing hop size and windows.

Source code in src/libsegmenter/backends/SegmenterTorch.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class SegmenterTorch(torch.nn.Module):
    """
    A PyTorch-based segmenter for input data using windowing techniques.

    Supports Weighted Overlap-Add (WOLA) and Overlap-Add (OLA) methods.

    Attributes:
        window (Window): A class containing hop size and windows.

    """

    def __init__(self, window: Window) -> None:
        """
        Initializes the SegmenterTorch instance.

        Args:
            window (Window): A window object containing segmentation parameters.

        """
        # nn.Module.__init__ must run before attribute assignment, since
        # Module.__setattr__ depends on registries created there.
        super().__init__()  # type: ignore

        self.window = window

    def segment(self, x: torch.Tensor) -> torch.Tensor:
        """
        Segments the input tensor into overlapping windows.

        Args:
            x (torch.Tensor): Input tensor (1D or 2D).

        Returns:
            Segmented tensor of shape (batch_size, num_segments, segment_size).

        Raises:
            ValueError: If types are incorrect.
            ValueError: If input dimensions are invalid.

        """
        if x.ndim not in {1, 2}:
            raise ValueError(f"Only supports 1D or 2D inputs, provided {x.ndim}D.")

        batch_size = x.shape[0] if x.ndim == 2 else None
        num_samples = x.shape[-1]

        if batch_size is None:
            x = x.reshape(1, -1)  # Convert to batch format for consistency

        segment_size = self.window.analysis_window.shape[-1]
        num_segments = compute_num_segments(
            num_samples, self.window.hop_size, segment_size
        )

        if num_segments <= 0:
            raise ValueError(
                "Input signal is too short for segmentation with the given parameters."
            )

        # Windowing. `as_tensor` avoids the unconditional copy (and UserWarning)
        # `torch.tensor` produces when the stored window is already a tensor.
        analysis_window = torch.as_tensor(
            self.window.analysis_window, device=x.device, dtype=x.dtype
        )

        # (num_segments, 1) + (segment_size,) broadcasts to a full
        # (num_segments, segment_size) gather-index matrix.
        idxs = torch.arange(num_segments, device=x.device) * self.window.hop_size
        frame_idxs = idxs.unsqueeze(1) + torch.arange(segment_size, device=x.device)
        y = x[:, frame_idxs] * analysis_window

        return (
            y.squeeze(0) if batch_size is None else y
        )  # Remove batch dimension if needed

    def unsegment(self, y: torch.Tensor) -> torch.Tensor:
        """
        Reconstructs the original signal from segmented data.

        Args:
            y (torch.Tensor): Segmented tensor (2D or 3D).

        Returns:
            Reconstructed 1D or 2D signal.

        Raises:
            ValueError: If types are incorrect.
            ValueError: If input dimensions are invalid.

        """
        if self.window.synthesis_window is None:
            raise ValueError("Given windowing scheme does not support unsegmenting.")

        if y.ndim not in {2, 3}:
            raise ValueError(f"Only supports 2D or 3D inputs, provided {y.ndim}D.")

        batch_size = y.shape[0] if y.ndim == 3 else None
        num_segments = y.shape[-2]
        segment_size = y.shape[-1]

        if batch_size is None:
            y = y.reshape(1, num_segments, -1)  # Convert to batch format

        num_samples = compute_num_samples(
            num_segments, self.window.hop_size, segment_size
        )

        if num_samples <= 0:
            raise ValueError(
                "Invalid segment structure, possibly due to incorrect windowing "
                + "parameters."
            )

        # allocate memory for the reconstructed signal
        x = torch.zeros(
            (batch_size if batch_size is not None else 1, num_samples),
            device=y.device,
            dtype=y.dtype,
        )

        # `as_tensor` avoids the unconditional copy (and UserWarning)
        # `torch.tensor` produces when the stored window is already a tensor.
        synthesis_window = torch.as_tensor(
            self.window.synthesis_window, device=y.device, dtype=y.dtype
        )

        # Overlap-add: scatter every windowed sample back to its output
        # position; scatter_add_ accumulates overlapping contributions.
        frame_idxs = (
            torch.arange(num_segments, device=y.device) * self.window.hop_size
        ).unsqueeze(1) + torch.arange(segment_size, device=y.device)
        frame_idxs = frame_idxs.flatten()
        x.scatter_add_(
            1,
            frame_idxs.unsqueeze(0).expand(x.shape[0], -1),
            (y * synthesis_window).reshape(x.shape[0], -1),
        )

        return x.squeeze(0) if batch_size is None else x

__init__(window)

Initializes the SegmenterTorch instance.

Parameters:

Name Type Description Default
window Window

A window object containing segmentation parameters.

required
Source code in src/libsegmenter/backends/SegmenterTorch.py
37
38
39
40
41
42
43
44
45
46
47
def __init__(self, window: Window) -> None:
    """
    Create a segmenter configured by the given window object.

    Args:
        window (Window): A window object containing segmentation parameters.

    """
    # nn.Module.__init__ must run before attribute assignment, since
    # Module.__setattr__ depends on registries created there.
    super().__init__()  # type: ignore

    self.window = window

segment(x)

Segments the input tensor into overlapping windows.

Parameters:

Name Type Description Default
x Tensor

Input tensor (1D or 2D).

required

Returns:

Type Description
Tensor

Segmented tensor of shape (batch_size, num_segments, segment_size).

Raises:

Type Description
ValueError

If types are incorrect.

ValueError

If input dimensions are invalid.

Source code in src/libsegmenter/backends/SegmenterTorch.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def segment(self, x: torch.Tensor) -> torch.Tensor:
    """
    Segments the input tensor into overlapping windows.

    Args:
        x (torch.Tensor): Input tensor (1D or 2D).

    Returns:
        Segmented tensor of shape (batch_size, num_segments, segment_size).

    Raises:
        ValueError: If types are incorrect.
        ValueError: If input dimensions are invalid.

    """
    if x.ndim not in {1, 2}:
        raise ValueError(f"Only supports 1D or 2D inputs, provided {x.ndim}D.")

    batch_size = x.shape[0] if x.ndim == 2 else None
    num_samples = x.shape[-1]

    if batch_size is None:
        x = x.reshape(1, -1)  # Convert to batch format for consistency

    segment_size = self.window.analysis_window.shape[-1]
    num_segments = compute_num_segments(
        num_samples, self.window.hop_size, segment_size
    )

    if num_segments <= 0:
        raise ValueError(
            "Input signal is too short for segmentation with the given parameters."
        )

    # Windowing. `as_tensor` avoids the unconditional copy (and UserWarning)
    # `torch.tensor` produces when the stored window is already a tensor.
    analysis_window = torch.as_tensor(
        self.window.analysis_window, device=x.device, dtype=x.dtype
    )

    # (num_segments, 1) + (segment_size,) broadcasts to a full
    # (num_segments, segment_size) gather-index matrix.
    idxs = torch.arange(num_segments, device=x.device) * self.window.hop_size
    frame_idxs = idxs.unsqueeze(1) + torch.arange(segment_size, device=x.device)
    y = x[:, frame_idxs] * analysis_window

    return (
        y.squeeze(0) if batch_size is None else y
    )  # Remove batch dimension if needed

unsegment(y)

Reconstructs the original signal from segmented data.

Parameters:

Name Type Description Default
y Tensor

Segmented tensor (2D or 3D).

required

Returns:

Type Description
Tensor

Reconstructed 1D or 2D signal.

Raises:

Type Description
ValueError

If types are incorrect.

ValueError

If input dimensions are invalid.

Source code in src/libsegmenter/backends/SegmenterTorch.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def unsegment(self, y: torch.Tensor) -> torch.Tensor:
    """
    Reconstructs the original signal from segmented data.

    Args:
        y (torch.Tensor): Segmented tensor (2D or 3D).

    Returns:
        Reconstructed 1D or 2D signal.

    Raises:
        ValueError: If types are incorrect.
        ValueError: If input dimensions are invalid.

    """
    if self.window.synthesis_window is None:
        raise ValueError("Given windowing scheme does not support unsegmenting.")

    if y.ndim not in {2, 3}:
        raise ValueError(f"Only supports 2D or 3D inputs, provided {y.ndim}D.")

    batch_size = y.shape[0] if y.ndim == 3 else None
    num_segments = y.shape[-2]
    segment_size = y.shape[-1]

    if batch_size is None:
        y = y.reshape(1, num_segments, -1)  # Convert to batch format

    num_samples = compute_num_samples(
        num_segments, self.window.hop_size, segment_size
    )

    if num_samples <= 0:
        raise ValueError(
            "Invalid segment structure, possibly due to incorrect windowing "
            + "parameters."
        )

    # allocate memory for the reconstructed signal
    x = torch.zeros(
        (batch_size if batch_size is not None else 1, num_samples),
        device=y.device,
        dtype=y.dtype,
    )

    # `as_tensor` avoids the unconditional copy (and UserWarning)
    # `torch.tensor` produces when the stored window is already a tensor.
    synthesis_window = torch.as_tensor(
        self.window.synthesis_window, device=y.device, dtype=y.dtype
    )

    # Overlap-add: scatter every windowed sample back to its output
    # position; scatter_add_ accumulates overlapping contributions.
    frame_idxs = (
        torch.arange(num_segments, device=y.device) * self.window.hop_size
    ).unsqueeze(1) + torch.arange(segment_size, device=y.device)
    frame_idxs = frame_idxs.flatten()
    x.scatter_add_(
        1,
        frame_idxs.unsqueeze(0).expand(x.shape[0], -1),
        (y * synthesis_window).reshape(x.shape[0], -1),
    )

    return x.squeeze(0) if batch_size is None else x