Skip to content

SegmenterTensorFlow

SegmenterTensorFlow

Bases: Layer

A TensorFlow-based segmenter for input data using windowing techniques.

Supports Weighted Overlap-Add (WOLA) and Overlap-Add (OLA) methods.

Attributes:

Name Type Description
window Window

A class containing the hop size and the analysis/synthesis windows.

Source code in src/libsegmenter/backends/SegmenterTensorFlow.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
class SegmenterTensorFlow(tf.keras.layers.Layer):
    """
    A TensorFlow-based segmenter for input data using windowing techniques.

    Supports Weighted Overlap-Add (WOLA) and Overlap-Add (OLA) methods.

    Attributes:
        window (Window): A class containing the hop size and the analysis and
            synthesis windows.

    """

    def __init__(self, window: Window) -> None:
        """
        Initializes the SegmenterTensorFlow instance.

        Args:
            window (Window): A window object containing segmentation parameters.
                The methods of this class read its ``hop_size``,
                ``analysis_window`` and ``synthesis_window`` attributes.

        """
        super().__init__()  # type: ignore

        self.window = window

    def segment(self, x: tf.Tensor) -> tf.Tensor:
        """
        Segments the input tensor into overlapping windows.

        Each segment is a slice of ``x`` of the analysis window's length, taken
        every ``hop_size`` samples and multiplied by the analysis window.

        Args:
            x (tf.Tensor): Input tensor (1D or 2D).

        Returns:
            Segmented tensor of shape (batch_size, num_segments, segment_size)
            for 2D input, or (num_segments, segment_size) for 1D input.

        Raises:
            ValueError: If the input is not 1D or 2D, or if the computed number
                of segments is not positive.

        """
        rank = len(x.shape)
        if rank not in {1, 2}:
            raise ValueError(f"Only supports 1D or 2D inputs, provided {rank}D.")

        # Decide batched-ness from the rank, not from the static batch size: a
        # 2D input whose batch dimension is statically unknown (None) must not
        # be mistaken for an unbatched signal.
        is_batched = rank == 2
        num_samples = x.shape[-1]
        hop_size = self.window.hop_size
        segment_size = self.window.analysis_window.shape[-1]

        if not is_batched:
            x = tf.reshape(x, (1, -1))  # Convert to batch format

        num_segments = compute_num_segments(num_samples, hop_size, segment_size)

        if num_segments <= 0:
            raise ValueError(
                "Input signal is too short for segmentation with the given parameters."
            )

        # Windowing
        analysis_window = tf.convert_to_tensor(
            self.window.analysis_window, dtype=x.dtype
        )

        # Build every windowed frame once and stack along a new segment axis,
        # instead of copying the whole output tensor for each segment.
        frames = [
            x[:, k * hop_size : k * hop_size + segment_size] * analysis_window
            for k in range(num_segments)
        ]
        X = tf.stack(frames, axis=1)

        return X if is_batched else tf.squeeze(X, axis=0)

    def unsegment(self, X: tf.Tensor) -> tf.Tensor:
        """
        Reconstructs the original signal from segmented data.

        Every segment is multiplied by the synthesis window and added into the
        output at offset ``k * hop_size`` (overlap-add).

        Args:
            X (tf.Tensor): Segmented tensor (2D or 3D).

        Returns:
            Reconstructed 1D signal for 2D input, or 2D signal for 3D input.

        Raises:
            ValueError: If the window has no synthesis window, if the input is
                not 2D or 3D, or if the computed number of samples is not
                positive.

        """
        if self.window.synthesis_window is None:
            raise ValueError("Given windowing scheme does not support unsegmenting.")

        rank = len(X.shape)
        if rank not in {2, 3}:
            raise ValueError(f"Only supports 2D or 3D inputs, provided {rank}D.")

        # Batched-ness comes from the rank so that an unknown static batch
        # dimension is still handled as a batch.
        is_batched = rank == 3
        num_segments = X.shape[-2]
        segment_size = X.shape[-1]
        hop_size = self.window.hop_size

        if not is_batched:
            X = tf.reshape(X, (1, num_segments, -1))  # Convert to batch format

        num_samples = compute_num_samples(num_segments, hop_size, segment_size)

        if num_samples <= 0:
            raise ValueError(
                "Invalid segment structure, possibly due to incorrect windowing "
                + "parameters."
            )

        # Keep the converted window: the multiplication below must use a tensor
        # of the same dtype as X rather than the raw window object.
        synthesis_window = tf.convert_to_tensor(
            self.window.synthesis_window, dtype=X.dtype
        )
        weighted = X * synthesis_window

        # Accumulate in (num_samples, batch) layout so that a single scatter per
        # segment adds a whole row per sample index for every batch entry.
        x_t = tf.zeros((num_samples, tf.shape(X)[0]), dtype=X.dtype)

        # Overlap-add method for reconstructing the original signal
        for k in range(num_segments):
            start = k * hop_size
            sample_idx = tf.reshape(
                tf.range(start, start + segment_size), shape=(segment_size, 1)
            )
            x_t = tf.tensor_scatter_nd_add(
                x_t, sample_idx, tf.transpose(weighted[:, k, :])
            )

        x = tf.transpose(x_t)

        return x if is_batched else tf.squeeze(x, axis=0)

__init__(window)

Initializes the SegmenterTensorFlow instance.

Parameters:

Name Type Description Default
window Window

A window object containing segmentation parameters.

required
Source code in src/libsegmenter/backends/SegmenterTensorFlow.py
37
38
39
40
41
42
43
44
45
46
47
def __init__(self, window: Window) -> None:
    """
    Initializes the SegmenterTensorFlow instance.

    Runs the parent class initializer and then stores ``window`` on the
    instance as ``self.window``.

    Args:
        window (Window): A window object containing segmentation parameters.

    """
    # Initialize the parent class before assigning any attributes.
    super(SegmenterTensorFlow, self).__init__()  # type: ignore

    self.window = window

segment(x)

Segments the input tensor into overlapping windows.

Parameters:

Name Type Description Default
x Tensor

Input tensor (1D or 2D).

required

Returns:

Type Description
Tensor

Segmented tensor of shape (batch_size, num_segments, segment_size) for 2D input, or (num_segments, segment_size) for 1D input.

Source code in src/libsegmenter/backends/SegmenterTensorFlow.py
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def segment(self, x: tf.Tensor) -> tf.Tensor:
    """
    Segments the input tensor into overlapping windows.

    Each segment is a slice of ``x`` of the analysis window's length, taken
    every ``hop_size`` samples and multiplied by the analysis window.

    Args:
        x (tf.Tensor): Input tensor (1D or 2D).

    Returns:
        Segmented tensor of shape (batch_size, num_segments, segment_size)
        for 2D input, or (num_segments, segment_size) for 1D input.

    Raises:
        ValueError: If the input is not 1D or 2D, or if the computed number
            of segments is not positive.

    """
    rank = len(x.shape)
    if rank not in {1, 2}:
        raise ValueError(f"Only supports 1D or 2D inputs, provided {rank}D.")

    # Decide batched-ness from the rank, not from the static batch size: a
    # 2D input whose batch dimension is statically unknown (None) must not
    # be mistaken for an unbatched signal.
    is_batched = rank == 2
    num_samples = x.shape[-1]
    hop_size = self.window.hop_size
    segment_size = self.window.analysis_window.shape[-1]

    if not is_batched:
        x = tf.reshape(x, (1, -1))  # Convert to batch format

    num_segments = compute_num_segments(num_samples, hop_size, segment_size)

    if num_segments <= 0:
        raise ValueError(
            "Input signal is too short for segmentation with the given parameters."
        )

    # Windowing
    analysis_window = tf.convert_to_tensor(
        self.window.analysis_window, dtype=x.dtype
    )

    # Build every windowed frame once and stack along a new segment axis,
    # instead of copying the whole output tensor for each segment.
    frames = [
        x[:, k * hop_size : k * hop_size + segment_size] * analysis_window
        for k in range(num_segments)
    ]
    X = tf.stack(frames, axis=1)

    return X if is_batched else tf.squeeze(X, axis=0)

unsegment(X)

Reconstructs the original signal from segmented data.

Parameters:

Name Type Description Default
X Tensor

Segmented tensor (2D or 3D).

required

Returns:

Type Description
Tensor

Reconstructed 1D or 2D signal.

Source code in src/libsegmenter/backends/SegmenterTensorFlow.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
def unsegment(self, X: tf.Tensor) -> tf.Tensor:
    """
    Reconstructs the original signal from segmented data.

    Every segment is multiplied by the synthesis window and added into the
    output at offset ``k * hop_size`` (overlap-add).

    Args:
        X (tf.Tensor): Segmented tensor (2D or 3D).

    Returns:
        Reconstructed 1D signal for 2D input, or 2D signal for 3D input.

    Raises:
        ValueError: If the window has no synthesis window, if the input is
            not 2D or 3D, or if the computed number of samples is not
            positive.

    """
    if self.window.synthesis_window is None:
        raise ValueError("Given windowing scheme does not support unsegmenting.")

    rank = len(X.shape)
    if rank not in {2, 3}:
        raise ValueError(f"Only supports 2D or 3D inputs, provided {rank}D.")

    # Batched-ness comes from the rank so that an unknown static batch
    # dimension is still handled as a batch.
    is_batched = rank == 3
    num_segments = X.shape[-2]
    segment_size = X.shape[-1]
    hop_size = self.window.hop_size

    if not is_batched:
        X = tf.reshape(X, (1, num_segments, -1))  # Convert to batch format

    num_samples = compute_num_samples(num_segments, hop_size, segment_size)

    if num_samples <= 0:
        raise ValueError(
            "Invalid segment structure, possibly due to incorrect windowing "
            + "parameters."
        )

    # Keep the converted window: the multiplication below must use a tensor
    # of the same dtype as X rather than the raw window object.
    synthesis_window = tf.convert_to_tensor(
        self.window.synthesis_window, dtype=X.dtype
    )
    weighted = X * synthesis_window

    # Accumulate in (num_samples, batch) layout so that a single scatter per
    # segment adds a whole row per sample index for every batch entry.
    x_t = tf.zeros((num_samples, tf.shape(X)[0]), dtype=X.dtype)

    # Overlap-add method for reconstructing the original signal
    for k in range(num_segments):
        start = k * hop_size
        sample_idx = tf.reshape(
            tf.range(start, start + segment_size), shape=(segment_size, 1)
        )
        x_t = tf.tensor_scatter_nd_add(
            x_t, sample_idx, tf.transpose(weighted[:, k, :])
        )

    x = tf.transpose(x_t)

    return x if is_batched else tf.squeeze(x, axis=0)