def make_checkpoint(epoch: int,
                    model: LSTM,
                    loss_function: Union[SplitCrossEntropyLoss,
                                         CrossEntropyLoss],
                    optimizer: torch.optim.Optimizer,
                    use_apex: bool = False,
                    amp=None,
                    prior: Union[str, nn.Module, None] = None,
                    **kwargs):
    """
    Packages network parameters into a picklable dictionary containing keys
    * epoch: current epoch
    * model: the network model
    * loss: the loss function
    * optimizer: the torch optimizer
    * use_apex: use nvidia apex for AMP or not
    * amp: the nvidia AMP object

    Parameters
    ----------
    epoch : int
        The current epoch of training
    model : LSTM
        The network model
    loss_function : SplitCrossEntropyLoss or CrossEntropyLoss
        The loss function
    optimizer : torch.optim.optimizer
        The optimizer function
    use_apex : bool
        If mixed precision mode is activated. If this is true, the `amp` argument should be supplied as well.
        The default value is False.
    amp :
        The nvidia apex amp object, should contain information about state of training
    kwargs :
        Not used

    Returns
    -------
    checkpoint: dict
        A picklable dict containing the checkpoint

    """
    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict(),
        'loss': loss_function.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    if use_apex:
        checkpoint['amp'] = amp.state_dict()

    if prior is not None and not isinstance(prior, str):
        checkpoint['prior'] = prior

    return checkpoint
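
A minimal usage sketch for the function above (hypothetical variable names; the returned dict is picklable, so torch.save can persist it directly):

# Hypothetical usage: package and persist the training state after an epoch.
# `model`, `loss_function`, and `optimizer` are assumed to be built elsewhere.
checkpoint = make_checkpoint(epoch=10,
                             model=model,
                             loss_function=loss_function,
                             optimizer=optimizer)
torch.save(checkpoint, 'checkpoint.pt')

# Restoring later:
state = torch.load('checkpoint.pt')
model.load_state_dict(state['model'])
optimizer.load_state_dict(state['optimizer'])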
Example #2
class SimpleDPLSTMTest(unittest.TestCase):
    def setUp(self):
        self.SEQ_LENGTH = 20
        self.INPUT_DIM = 25
        self.MINIBATCH_SIZE = 30
        self.LSTM_OUT_DIM = 12
        self.NUM_LAYERS = 1
        self.bidirectional = False
        self.batch_first = False

        self.num_directions = 2 if self.bidirectional else 1
        self.h_init = torch.randn(
            self.NUM_LAYERS * self.num_directions,
            self.MINIBATCH_SIZE,
            self.LSTM_OUT_DIM,
        )
        self.c_init = torch.randn(
            self.NUM_LAYERS * self.num_directions,
            self.MINIBATCH_SIZE,
            self.LSTM_OUT_DIM,
        )

        self.original_lstm = LSTM(
            self.INPUT_DIM,
            self.LSTM_OUT_DIM,
            batch_first=self.batch_first,
            num_layers=self.NUM_LAYERS,
            bidirectional=self.bidirectional,
        )
        self.dp_lstm = DPLSTM(
            self.INPUT_DIM,
            self.LSTM_OUT_DIM,
            batch_first=self.batch_first,
            num_layers=self.NUM_LAYERS,
            bidirectional=self.bidirectional,
        )

        self.dp_lstm.load_state_dict(self.original_lstm.state_dict())

    def _reset_seeds(self):
        torch.manual_seed(1337)
        torch.cuda.manual_seed(1337)

    def test_lstm_forward(self):
        x = (
            torch.randn(self.MINIBATCH_SIZE, self.SEQ_LENGTH, self.INPUT_DIM)
            if self.batch_first
            else torch.randn(self.SEQ_LENGTH, self.MINIBATCH_SIZE, self.INPUT_DIM)
        )
        hidden = (self.h_init, self.c_init)

        out, (hn, cn) = self.original_lstm(x, hidden)
        dp_out, (dp_hn, dp_cn) = self.dp_lstm(x, hidden)

        outputs_to_test = [
            (out, dp_out, "LSTM and DPLSTM output"),
            (hn, dp_hn, "LSTM and DPLSTM state `h`"),
            (cn, dp_cn, "LSTM and DPLSTM state `c`"),
        ]

        for output, dp_output, message in outputs_to_test:
            assert_allclose(
                actual=dp_output.expand_as(output),
                expected=output,
                atol=10e-6,
                rtol=10e-5,
                msg=f"Tensor value mismatch between {message}",
            )

    def test_lstm_backward(self):
        x = (
            torch.randn(self.MINIBATCH_SIZE, self.SEQ_LENGTH, self.INPUT_DIM)
            if self.batch_first
            else torch.randn(self.SEQ_LENGTH, self.MINIBATCH_SIZE, self.INPUT_DIM)
        )
        criterion = nn.MSELoss()

        hidden = (self.h_init, self.c_init)

        out, (hn, cn) = self.original_lstm(x, hidden)
        y = torch.zeros_like(out)
        loss = criterion(out, y)
        loss.backward()

        dp_out, (dp_hn, dp_cn) = self.dp_lstm(x, hidden)
        dp_loss = criterion(dp_out, y)
        dp_loss.backward()

        dp_lstm_params = dict(self.dp_lstm.named_parameters())
        for param_name, param in self.original_lstm.named_parameters():
            dp_param = dp_lstm_params[param_name]
            assert_allclose(
                actual=dp_param,
                expected=param,
                atol=10e-5,
                rtol=10e-3,
                msg=f"Tensor value mismatch in the parameter '{param_name}'",
            )
            assert_allclose(
                actual=dp_param.grad,
                expected=param.grad,
                atol=10e-6,
                rtol=10e-5,
                msg=f"Tensor value mismatch in the gradient of parameter '{param_name}'",
            )

    def test_lstm_param_update(self):
        x = (
            torch.randn(self.MINIBATCH_SIZE, self.SEQ_LENGTH, self.INPUT_DIM)
            if self.batch_first
            else torch.randn(self.SEQ_LENGTH, self.MINIBATCH_SIZE, self.INPUT_DIM)
        )
        criterion = nn.MSELoss()

        optimizer = torch.optim.SGD(self.original_lstm.parameters(), lr=0.5)
        dp_optimizer = torch.optim.SGD(self.dp_lstm.parameters(), lr=0.5)

        # Train original LSTM for one step
        logits, (h_n, c_n) = self.original_lstm(x)
        y = torch.zeros_like(logits)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        # Train DP LSTM for one step
        dp_logits, (dp_h_n, dp_c_n) = self.dp_lstm(x)
        dp_loss = criterion(dp_logits, y)
        dp_loss.backward()
        dp_optimizer.step()

        dp_lstm_params = dict(self.dp_lstm.named_parameters())
        for param_name, param in self.original_lstm.named_parameters():
            dp_param = dp_lstm_params[param_name]
            assert_allclose(
                actual=dp_param,
                expected=param,
                atol=10e-6,
                rtol=10e-5,
                msg=f"Tensor value mismatch in the parameter '{param_name}'",
            )
            assert_allclose(
                actual=dp_param.grad,
                expected=param.grad,
                atol=10e-6,
                rtol=10e-5,
                msg=f"Tensor value mismatch in the gradient of parameter '{param_name}'",
            )
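
For reference, a standalone sketch of the same drop-in parity idea outside the test harness (assuming DPLSTM is importable from opacus.layers; dimensions are illustrative):

import torch
import torch.nn as nn
from opacus.layers import DPLSTM  # assumed import path

lstm = nn.LSTM(input_size=25, hidden_size=12)
dp_lstm = DPLSTM(input_size=25, hidden_size=12)
dp_lstm.load_state_dict(lstm.state_dict())  # parameter names are compatible

x = torch.randn(20, 30, 25)  # (seq_len, batch, input_dim)
out, _ = lstm(x)
dp_out, _ = dp_lstm(x)
print(torch.allclose(out, dp_out, atol=1e-5, rtol=1e-4))
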
Example #3
class OpenUnmix(nn.Module):
    def __init__(
        self,
        n_fft=4096,
        n_hop=1024,
        input_is_spectrogram=False,
        hidden_size=512,
        nb_channels=2,
        sample_rate=44100,
        nb_layers=3,
        input_mean=None,
        input_scale=None,
        max_bin=None,
        unidirectional=False,
        power=1,
        first_iter=1,
    ):
        """
        Input: (nb_samples, nb_channels, nb_timesteps)
            or (nb_frames, nb_samples, nb_channels, nb_bins)
        Output: Power/Mag Spectrogram
                (nb_frames, nb_samples, nb_channels, nb_bins)
        """

        super(OpenUnmix, self).__init__()

        self.nb_output_bins = n_fft // 2 + 1
        if max_bin:
            self.nb_bins = max_bin
        else:
            self.nb_bins = self.nb_output_bins

        self.hidden_size = hidden_size

        self.stft = STFT(n_fft=n_fft, n_hop=n_hop)
        self.spec = Spectrogram(power=power, mono=(nb_channels == 1))
        self.register_buffer('sample_rate', torch.tensor(sample_rate))

        if input_is_spectrogram:
            self.transform = NoOp()
        else:
            self.transform = nn.Sequential(self.stft, self.spec)

        self.fc1 = Linear(
            self.nb_bins*nb_channels, hidden_size,
            bias=False
        )

        self.bn1 = BatchNorm1d(hidden_size)

        if unidirectional:
            lstm_hidden_size = hidden_size
        else:
            lstm_hidden_size = hidden_size // 2

        self.lstm = LSTM(
            input_size=hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=nb_layers,
            bidirectional=not unidirectional,
            batch_first=False,
            dropout=0.4,
        )

        self.state = self.lstm.state_dict()
        
        self.fc2 = Linear(
            in_features=hidden_size*2,
            out_features=hidden_size,
            bias=False
        )

        self.bn2 = BatchNorm1d(hidden_size)

        self.fc3 = Linear(
            in_features=hidden_size,
            out_features=self.nb_output_bins*nb_channels,
            bias=False
        )

        self.bn3 = BatchNorm1d(self.nb_output_bins*nb_channels)

        if input_mean is not None:
            input_mean = torch.from_numpy(
                -input_mean[:self.nb_bins]
            ).float()
        else:
            input_mean = torch.zeros(self.nb_bins)

        if input_scale is not None:
            input_scale = torch.from_numpy(
                1.0/input_scale[:self.nb_bins]
            ).float()
        else:
            input_scale = torch.ones(self.nb_bins)

        self.input_mean = Parameter(input_mean)
        self.input_scale = Parameter(input_scale)

        self.output_scale = Parameter(
            torch.ones(self.nb_output_bins).float()
        )
        self.output_mean = Parameter(
            torch.ones(self.nb_output_bins).float()
        )
    
    def forward(self, x, h_t_minus1, c_t_minus1):
        # check for waveform or spectrogram input:
        # a waveform (nb_samples, nb_channels, nb_timesteps) is transformed
        # to a spectrogram; spectrogram input passes through unchanged
        x = self.transform(x)

        nb_frames, nb_samples, nb_channels, nb_bins = x.shape

        mix = x.detach().clone()
        
        # crop
        x = x[..., :self.nb_bins]

        # shift and scale input to mean=0 std=1 (across all bins)
        x += self.input_mean
        x *= self.input_scale

        # to (nb_frames*nb_samples, nb_channels*nb_bins)
        # and encode to (nb_frames*nb_samples, hidden_size)
        x = self.fc1(x.reshape(-1, nb_channels*self.nb_bins))
        # normalize every instance in a batch
        x = self.bn1(x)
        x = x.reshape(nb_frames, nb_samples, self.hidden_size)
        # squash range to [-1, 1]
        x = torch.tanh(x)
        
        # apply the stacked LSTM layers
        # hidden and cell states default to zeros on the first iteration
        if h_t_minus1 is None:
            lstm_out = self.lstm(x)
        else:
            lstm_out = self.lstm(x, (h_t_minus1, c_t_minus1))
        
        # lstm skip connection
        x = torch.cat([x, lstm_out[0]], -1)

        # first dense stage + batch norm
        x = self.fc2(x.reshape(-1, x.shape[-1]))
        x = self.bn2(x)

        x = F.relu(x)

        # second dense stage + batch norm
        x = self.fc3(x)
        x = self.bn3(x)

        # reshape back to original dim
        x = x.reshape(nb_frames, nb_samples, nb_channels, self.nb_output_bins)

        # apply output scaling
        x *= self.output_scale
        x += self.output_mean

        # the output is non-negative, so apply ReLU and use it to scale the mix spectrogram
        x = F.relu(x) * mix
        
        # get the final hidden and cell states from the LSTM
        h_t_minus1, c_t_minus1 = lstm_out[1]

        return x, h_t_minus1, c_t_minus1
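
A hedged inference sketch for the stateful forward pass above, threading the LSTM hidden and cell states across consecutive audio chunks (audio_chunks and the constructor arguments are illustrative; the STFT/Spectrogram helpers from the surrounding module are assumed to be available):

# Hypothetical chunk-wise inference; LSTM state carries over between chunks.
unmix = OpenUnmix(nb_channels=2, hidden_size=512)
unmix.eval()

h, c = None, None
with torch.no_grad():
    for chunk in audio_chunks:  # each: (nb_samples, nb_channels, nb_timesteps)
        estimate, h, c = unmix(chunk, h, c)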