Example #1
def get_model(state, args, init_model_name=None):
    if init_model_name is not None and os.path.exists(init_model_name):
        model, optimizer, state = load_model(init_model_name,
                                             return_optimizer=True,
                                             return_state=True)
    else:
        if "conv_dropout" in args:
            conv_dropout = args.conv_dropout
        else:
            conv_dropout = cfg.conv_dropout
        cnn_args = [1]  # positional args for CNN: a single input channel

        if args.fixed_segment is not None:
            frames = cfg.frames
        else:
            frames = None

        nb_layers = 4
        cnn_kwargs = {
            "activation": cfg.activation,
            "conv_dropout": conv_dropout,
            "batch_norm": cfg.batch_norm,
            "kernel_size": nb_layers * [3],
            "padding": nb_layers * [1],
            "stride": nb_layers * [1],
            "nb_filters": [16, 16, 32, 65],
            "pooling": [(2, 2), (2, 2), (1, 4), (1, 2)],
            "aggregation": args.agg_time,
            "norm_out": args.norm_embed,
            "frames": frames,
        }
        # the two (2, 2) poolings each halve the time axis, so cfg.frames // 2**2 frames remain
        nb_frames_staying = cfg.frames // (2**2)
        model = CNN(*cnn_args, **cnn_kwargs)
        # model.apply(weights_init)
        state.update({
            'model': {
                "name": model.__class__.__name__,
                'args': cnn_args,
                "kwargs": cnn_kwargs,
                'state_dict': model.state_dict()
            },
            'nb_frames_staying': nb_frames_staying
        })
        if init_model_name is not None:
            save_model(state, init_model_name)
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    LOG.info(
        "number of parameters in the model: {}".format(pytorch_total_params))
    return model, state
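
A minimal sketch of how this factory could be called. The args fields and all values below are illustrative assumptions, not taken from the source:

# Hypothetical usage sketch; every value here is an assumption.
import argparse

state = {}
args = argparse.Namespace(conv_dropout=0.3, fixed_segment=None,
                          agg_time="mean", norm_embed=False)
model, state = get_model(state, args, init_model_name="cnn_init.pth")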
Example #2
class WCRNN(nn.Module):  #, BaseFairseqModel):
    def __init__(self,
                 w2v_cfg,
                 n_in_channel,
                 nclass,
                 attention=False,
                 activation="Relu",
                 dropout=0,
                 train_cnn=True,
                 rnn_type='BGRU',
                 n_RNN_cell=64,
                 n_layers_RNN=1,
                 dropout_recurrent=0,
                 cnn_integration=False,
                 **kwargs):
        super(WCRNN, self).__init__()

        self.w2v = w2v_encoder(w2v_cfg)  #Wav2Vec2Config)
        #self.w2v = Wav2VecEncoder(Wav2Vec2SedConfig, None)
        self.pooling = nn.Sequential(nn.MaxPool2d((1, 4), (1, 4)))

        self.n_in_channel = n_in_channel
        self.attention = attention
        self.cnn_integration = cnn_integration
        n_in_cnn = n_in_channel
        if cnn_integration:
            n_in_cnn = 1
        self.cnn = CNN(n_in_cnn, activation, dropout, **kwargs)
        if not train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False
        self.train_cnn = train_cnn
        if rnn_type == 'BGRU':
            nb_in = self.cnn.nb_filters[-1]
            if self.cnn_integration:
                # self.fc = nn.Linear(nb_in * n_in_channel, nb_in)
                nb_in = nb_in * n_in_channel
            self.rnn = BidirectionalGRU(nb_in,
                                        n_RNN_cell,
                                        dropout=dropout_recurrent,
                                        num_layers=n_layers_RNN)
        else:
            raise NotImplementedError("Only BGRU is supported for WCRNN for now")
        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(n_RNN_cell * 2, nclass)
        self.sigmoid = nn.Sigmoid()
        if self.attention:
            self.dense_softmax = nn.Linear(n_RNN_cell * 2, nclass)
            self.softmax = nn.Softmax(dim=-1)

    def load_cnn(self, state_dict):
        self.cnn.load_state_dict(state_dict)
        if not self.train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False

    def load_state_dict(self, state_dict, strict=True):
        self.w2v.load_state_dict(state_dict["w2v"], strict=strict)
        self.cnn.load_state_dict(state_dict["cnn"], strict=strict)
        self.rnn.load_state_dict(state_dict["rnn"], strict=strict)
        self.dense.load_state_dict(state_dict["dense"], strict=strict)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
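        # Override: return a nested dict keyed by sub-module instead of the usual
        # flat PyTorch state_dict, so load_state_dict() above can restore each
        # component (w2v, cnn, rnn, dense) independently.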
        state_dict = {
            "w2v":
            self.w2v.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            "cnn":
            self.cnn.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            "rnn":
            self.rnn.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            'dense':
            self.dense.state_dict(destination=destination,
                                  prefix=prefix,
                                  keep_vars=keep_vars)
        }
        return state_dict

    def save(self, filename):
        parameters = {
            'w2v': self.w2v.state_dict(),
            'cnn': self.cnn.state_dict(),
            'rnn': self.rnn.state_dict(),
            'dense': self.dense.state_dict()
        }
        torch.save(parameters, filename)

    def forward(self, audio):
        x = audio.squeeze()
        feature = self.w2v(x)
        x = feature['x']
        x = x.transpose(1, 0)
        x = x.unsqueeze(1)

        # input size : (batch_size, n_channels, n_frames, n_freq)
        if self.cnn_integration:
            bs_in, nc_in = x.size(0), x.size(1)
            x = x.view(bs_in * nc_in, 1, *x.shape[2:])

        # conv features
        before = x
        x = self.cnn(x)
        bs, chan, frames, freq = x.size()
        if self.cnn_integration:
            x = x.reshape(bs_in, chan * nc_in, frames, freq)

        if freq != 1:
            warnings.warn(
                f"Output shape is: {(bs, frames, chan * freq)}, from {freq} staying freq"
            )
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)
            x = x.permute(0, 2, 1)  # [bs, frames, chan]

        # rnn features
        x = self.rnn(x)
        x = self.dropout(x)
        strong = self.dense(x)  # [bs, frames, nclass]
        strong = self.sigmoid(strong)
        if self.attention:
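            # attention pooling: weights from a second classifier head aggregate
            # the frame-level (strong) scores into a clip-level (weak) prediction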
            sof = self.dense_softmax(x)  # [bs, frames, nclass]
            sof = self.softmax(sof)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / (sof.sum(1) + 1e-08)  # [bs, nclass]
        else:
            weak = strong.mean(1)
        return strong, weak
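
A minimal instantiation sketch. w2v_cfg, the CNN kwargs, and the input shape are all assumptions for illustration, not values from the source:

# Hypothetical usage; w2v_cfg is an assumed wav2vec 2.0 encoder config.
model = WCRNN(w2v_cfg, n_in_channel=1, nclass=10, attention=True,
              nb_filters=[16, 32, 64], pooling=[(1, 4), (1, 4), (1, 4)])
audio = torch.randn(4, 16000)   # (batch, samples); shapes depend on the encoder
strong, weak = model(audio)     # frame-level and clip-level predictions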
Example #3
class CRNN(nn.Module):
    def __init__(self,
                 n_in_channel,
                 nclass,
                 attention=False,
                 activation="Relu",
                 dropout=0,
                 train_cnn=True,
                 rnn_type='BGRU',
                 n_RNN_cell=64,
                 n_layers_RNN=1,
                 dropout_recurrent=0,
                 **kwargs):
        super(CRNN, self).__init__()
        self.attention = attention
        self.cnn = CNN(n_in_channel, activation, dropout, **kwargs)
        if not train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False
        self.train_cnn = train_cnn
        if rnn_type == 'BGRU':
            self.rnn = BidirectionalGRU(self.cnn.nb_filters[-1],
                                        n_RNN_cell,
                                        dropout=dropout_recurrent,
                                        num_layers=n_layers_RNN)
        else:
            raise NotImplementedError("Only BGRU is supported for CRNN for now")
        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(n_RNN_cell * 2, nclass)
        self.sigmoid = nn.Sigmoid()
        if self.attention:
            self.dense_softmax = nn.Linear(n_RNN_cell * 2, nclass)
            self.softmax = nn.Softmax(dim=-1)

    def load_cnn(self, parameters):
        self.cnn.load(parameters)
        if not self.train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False

    def load(self, filename=None, parameters=None):
        if filename is not None:
            parameters = torch.load(filename)
        if parameters is None:
            raise ValueError(
                "load() expects either a filename or a parameters dict (a state_dict)")

        self.cnn.load(parameters=parameters["cnn"])
        self.rnn.load_state_dict(parameters["rnn"])
        self.dense.load_state_dict(parameters["dense"])

    def state_dict(self, destination=None, prefix='', keep_vars=False):
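        # As in WCRNN, return a nested per-module dict rather than a flat
        # state_dict, matching the format consumed by load() above.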
        state_dict = {
            "cnn":
            self.cnn.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            "rnn":
            self.rnn.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            'dense':
            self.dense.state_dict(destination=destination,
                                  prefix=prefix,
                                  keep_vars=keep_vars)
        }
        return state_dict

    def save(self, filename):
        parameters = {
            'cnn': self.cnn.state_dict(),
            'rnn': self.rnn.state_dict(),
            'dense': self.dense.state_dict()
        }
        torch.save(parameters, filename)

    def forward(self, x):
        # input size : (batch_size, n_channels, n_frames, n_freq)
        # conv features
        x = self.cnn(x)
        bs, chan, frames, freq = x.size()
        if freq != 1:
            warnings.warn("Output shape is: {}".format(
                (bs, frames, chan * freq)))
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)
            x = x.permute(0, 2, 1)  # [bs, frames, chan]

        # rnn features
        x = self.rnn(x)
        x = self.dropout(x)
        strong = self.dense(x)  # [bs, frames, nclass]
        strong = self.sigmoid(strong)
        if self.attention:
            sof = self.dense_softmax(x)  # [bs, frames, nclass]
            sof = self.softmax(sof)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / sof.sum(1)  # [bs, nclass]
        else:
            weak = strong.mean(1)
        return strong, weak
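
A minimal end-to-end sketch; the CNN kwargs and the input shape are illustrative assumptions:

# Hypothetical usage; kwargs are passed through to the underlying CNN.
model = CRNN(n_in_channel=1, nclass=10, attention=True,
             nb_filters=[16, 32, 64], pooling=[(2, 2), (2, 2), (1, 2)])
x = torch.randn(4, 1, 128, 64)   # (batch, channel, frames, freq)
strong, weak = model(x)
print(strong.shape, weak.shape)  # per-frame vs. per-clip predictions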
Example #4
def main():
    songs = get_notes()

    vocab_set = set()
    for song in songs:
        for note in song:
            vocab_set.add(note)

    n_in, n_out = prep_sequences(songs, sequence_length=100)
    X_train, X_val, y_train, y_val = train_test_split(n_in,
                                                      n_out,
                                                      test_size=0.2)

    train_ds = MusicDataset(X_train, y_train)
    val_ds = MusicDataset(X_val, y_val)

    train_dataloader = DataLoader(train_ds,
                                  batch_size=512,
                                  shuffle=True,
                                  num_workers=0)
    val_dataloader = DataLoader(val_ds,
                                batch_size=512,
                                shuffle=False,
                                num_workers=0)

    model = CNN(100, len(vocab_set))
    model.cuda()
    epochs = 25
    initial_lr = 0.001
    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    loss_fn = CrossEntropyLoss()

    train_losses = []
    val_losses = []

    train_accuracies = []
    val_accuracies = []

    for epoch in tqdm(range(1, epochs + 1)):

        model.train()
        train_loss_total = 0.0
        num_steps = 0
        correct = 0
        ### Train
        for i, batch in enumerate(train_dataloader):
            X, y = batch[0].cuda(), batch[1].cuda()
            train_preds = model(X)

            loss = loss_fn(train_preds, y)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_steps += 1

            train_preds = torch.max(train_preds, 1)[1]
            correct += (train_preds == y).float().sum()

        train_loss_total_avg = train_loss_total / num_steps
        train_accuracy = (correct / len(train_ds)).item()  # .item(): CUDA tensors can't be plotted directly
        train_accuracies.append(train_accuracy)
        train_losses.append(train_loss_total_avg)

        model.eval()
        val_loss_total = 0.0
        num_steps = 0
        correct = 0
        for i, batch in enumerate(val_dataloader):
            with torch.no_grad():
                X, y = batch[0].cuda(), batch[1].cuda()

                val_preds = model(X)
                loss = loss_fn(val_preds, y)
                val_loss_total += loss.item()
                val_preds = torch.max(val_preds, 1)[1]
                correct += (val_preds == y).float().sum()

            num_steps += 1

        val_loss_total_avg = val_loss_total / num_steps
        val_accuracy = (correct / len(val_ds)).item()
        val_accuracies.append(val_accuracy)
        val_losses.append(val_loss_total_avg)

        scheduler.step()
        print('\nTrain loss: {:.4f}'.format(train_loss_total_avg))
        print('Train accuracy: {:.4f}'.format(train_accuracy))

        print('Val loss: {:.4f}'.format(val_loss_total_avg))
        print('Val accuracy: {:.4f}\n'.format(val_accuracy))

        torch.save(model.state_dict(),
                   "weights/model_params_epoch" + str(epoch))
        torch.save(optimizer.state_dict(),
                   "weights/optim_params_epoch" + str(epoch))

    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label="train")
    plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label="val")
    plt.legend()
    plt.savefig("plots/accuracies.png")
    plt.close()

    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.plot(range(1, len(train_losses) + 1), train_losses, label="train")
    plt.plot(range(1, len(val_losses) + 1), val_losses, label="val")
    plt.legend()
    plt.savefig("plots/losses.png")
    plt.close()

    generate_midi(model, val_ds, vocab_set, output_filename="output.mid")
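
Assuming the script is meant to be run directly, the usual entry-point guard (not shown in the source) completes it:

if __name__ == "__main__":
    main()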