import os

# `CNN`, `load_model`, `save_model`, the config module `cfg` and the logger
# `LOG` are assumed to be project-local imports available in this module.


def get_model(state, args, init_model_name=None):
    if init_model_name is not None and os.path.exists(init_model_name):
        model, optimizer, state = load_model(init_model_name,
                                             return_optimizer=True,
                                             return_state=True)
    else:
        if "conv_dropout" in args:
            conv_dropout = args.conv_dropout
        else:
            conv_dropout = cfg.conv_dropout
        # Positional args for CNN: a single input channel. (The original code
        # used the set literal `{1}`, which unpacks but does not serialize
        # meaningfully in `state`; a list is intended here.)
        cnn_args = [1]

        if args.fixed_segment is not None:
            frames = cfg.frames
        else:
            frames = None

        nb_layers = 4
        cnn_kwargs = {
            "activation": cfg.activation,
            "conv_dropout": conv_dropout,
            "batch_norm": cfg.batch_norm,
            "kernel_size": nb_layers * [3],
            "padding": nb_layers * [1],
            "stride": nb_layers * [1],
            "nb_filters": [16, 16, 32, 65],
            "pooling": [(2, 2), (2, 2), (1, 4), (1, 2)],
            "aggregation": args.agg_time,
            "norm_out": args.norm_embed,
            "frames": frames,
        }
        # The time axis is pooled twice by a factor of 2, so 1/4 of the
        # input frames remain after the CNN.
        nb_frames_staying = cfg.frames // (2 ** 2)
        model = CNN(*cnn_args, **cnn_kwargs)
        # model.apply(weights_init)
        state.update({
            "model": {
                "name": model.__class__.__name__,
                "args": cnn_args,
                "kwargs": cnn_kwargs,
                "state_dict": model.state_dict(),
            },
            "nb_frames_staying": nb_frames_staying,
        })
        if init_model_name is not None:
            save_model(state, init_model_name)

    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    LOG.info("number of parameters in the model: {}".format(
        pytorch_total_params))
    return model, state
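def _get_model_usage_example():
    # Hedged usage sketch, not part of the original repo: `get_model` only
    # needs an argparse-style namespace exposing the attributes it reads
    # above (`agg_time`, `norm_embed`, `fixed_segment`, optionally
    # `conv_dropout`). The field values below are illustrative assumptions.
    from argparse import Namespace
    args = Namespace(agg_time="mean", norm_embed=False, fixed_segment=None)
    state = {}
    # Builds a fresh CNN and fills `state` with its config and weights;
    # passing an existing checkpoint path instead would reload it from disk.
    model, state = get_model(state, args)
    return model, state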
import warnings

import torch
import torch.nn as nn


class WCRNN(nn.Module):
    """CRNN variant whose input features come from a wav2vec (w2v) encoder
    instead of precomputed spectrograms."""

    def __init__(self, w2v_cfg, n_in_channel, nclass, attention=False,
                 activation="Relu", dropout=0, train_cnn=True,
                 rnn_type='BGRU', n_RNN_cell=64, n_layers_RNN=1,
                 dropout_recurrent=0, cnn_integration=False, **kwargs):
        super(WCRNN, self).__init__()
        # `w2v_encoder`, `CNN` and `BidirectionalGRU` are assumed
        # project-local imports.
        self.w2v = w2v_encoder(w2v_cfg)
        self.pooling = nn.Sequential(nn.MaxPool2d((1, 4), (1, 4)))
        self.n_in_channel = n_in_channel
        self.attention = attention
        self.cnn_integration = cnn_integration

        n_in_cnn = n_in_channel
        if cnn_integration:
            n_in_cnn = 1
        self.cnn = CNN(n_in_cnn, activation, dropout, **kwargs)

        if not train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False
        self.train_cnn = train_cnn

        if rnn_type == 'BGRU':
            nb_in = self.cnn.nb_filters[-1]
            if self.cnn_integration:
                # self.fc = nn.Linear(nb_in * n_in_channel, nb_in)
                nb_in = nb_in * n_in_channel
            self.rnn = BidirectionalGRU(nb_in,
                                        n_RNN_cell,
                                        dropout=dropout_recurrent,
                                        num_layers=n_layers_RNN)
        else:
            raise NotImplementedError("Only BGRU supported for CRNN for now")

        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(n_RNN_cell * 2, nclass)
        self.sigmoid = nn.Sigmoid()
        if self.attention:
            self.dense_softmax = nn.Linear(n_RNN_cell * 2, nclass)
            self.softmax = nn.Softmax(dim=-1)

    def load_cnn(self, state_dict):
        self.cnn.load_state_dict(state_dict)
        if not self.train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False

    def load_state_dict(self, state_dict, strict=True):
        self.w2v.load_state_dict(state_dict["w2v"])
        self.cnn.load_state_dict(state_dict["cnn"])
        self.rnn.load_state_dict(state_dict["rnn"])
        self.dense.load_state_dict(state_dict["dense"])

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        state_dict = {
            "w2v": self.w2v.state_dict(destination=destination, prefix=prefix,
                                       keep_vars=keep_vars),
            "cnn": self.cnn.state_dict(destination=destination, prefix=prefix,
                                       keep_vars=keep_vars),
            "rnn": self.rnn.state_dict(destination=destination, prefix=prefix,
                                       keep_vars=keep_vars),
            "dense": self.dense.state_dict(destination=destination,
                                           prefix=prefix, keep_vars=keep_vars),
        }
        return state_dict

    def save(self, filename):
        parameters = {
            'w2v': self.w2v.state_dict(),
            'cnn': self.cnn.state_dict(),
            'rnn': self.rnn.state_dict(),
            'dense': self.dense.state_dict(),
        }
        torch.save(parameters, filename)

    def forward(self, audio):
        # Raw waveform in, (strong, weak) predictions out.
        x = audio.squeeze()
        feature = self.w2v(x)
        x = feature['x']       # wav2vec features, (frames, batch, channels)
        x = x.transpose(1, 0)  # -> (batch, frames, channels)
        x = x.unsqueeze(1)     # -> (batch, 1, frames, channels)

        # input size : (batch_size, n_channels, n_frames, n_freq)
        if self.cnn_integration:
            bs_in, nc_in = x.size(0), x.size(1)
            x = x.view(bs_in * nc_in, 1, *x.shape[2:])

        # conv features
        x = self.cnn(x)
        bs, chan, frames, freq = x.size()
        if self.cnn_integration:
            x = x.reshape(bs_in, chan * nc_in, frames, freq)
        if freq != 1:
            warnings.warn(
                f"Output shape is: {(bs, frames, chan * freq)}, "
                f"from {freq} staying freq")
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)
            x = x.permute(0, 2, 1)  # [bs, frames, chan]

        # rnn features
        x = self.rnn(x)
        x = self.dropout(x)
        strong = self.dense(x)  # [bs, frames, nclass]
        strong = self.sigmoid(strong)
        if self.attention:
            sof = self.dense_softmax(x)  # [bs, frames, nclass]
            sof = self.softmax(sof)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / (sof.sum(1) + 1e-08)  # [bs, nclass]
        else:
            weak = strong.mean(1)
        return strong, weak
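def _wcrnn_smoke_test(w2v_cfg):
    # Hedged usage sketch, not from the repo: pushes a random waveform batch
    # through WCRNN and returns the two output shapes. `w2v_cfg` must be a
    # valid config for the `w2v_encoder` wrapper above; the CNN kwargs and
    # all sizes below are illustrative assumptions. The frequency pooling
    # (4 * 4 * 6 * 8 = 768) is chosen to collapse 768-dim wav2vec features
    # to a single bin so the GRU input matches nb_filters[-1].
    model = WCRNN(w2v_cfg, n_in_channel=1, nclass=10,
                  kernel_size=4 * [3], padding=4 * [1], stride=4 * [1],
                  nb_filters=[16, 16, 32, 64],
                  pooling=[(1, 4), (1, 4), (1, 6), (1, 8)])
    model.eval()
    audio = torch.randn(4, 1, 16000)  # (batch, channel, samples): 1 s at 16 kHz
    with torch.no_grad():
        strong, weak = model(audio)
    return strong.shape, weak.shape  # per-frame and clip-level predictions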
class CRNN(nn.Module):
    """CRNN for sound event detection on spectrogram-like input."""

    def __init__(self, n_in_channel, nclass, attention=False,
                 activation="Relu", dropout=0, train_cnn=True,
                 rnn_type='BGRU', n_RNN_cell=64, n_layers_RNN=1,
                 dropout_recurrent=0, **kwargs):
        super(CRNN, self).__init__()
        self.attention = attention
        self.cnn = CNN(n_in_channel, activation, dropout, **kwargs)
        if not train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False
        self.train_cnn = train_cnn

        if rnn_type == 'BGRU':
            self.rnn = BidirectionalGRU(self.cnn.nb_filters[-1],
                                        n_RNN_cell,
                                        dropout=dropout_recurrent,
                                        num_layers=n_layers_RNN)
        else:
            raise NotImplementedError("Only BGRU supported for CRNN for now")

        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(n_RNN_cell * 2, nclass)
        self.sigmoid = nn.Sigmoid()
        if self.attention:
            self.dense_softmax = nn.Linear(n_RNN_cell * 2, nclass)
            self.softmax = nn.Softmax(dim=-1)

    def load_cnn(self, parameters):
        self.cnn.load(parameters)
        if not self.train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False

    def load(self, filename=None, parameters=None):
        if filename is not None:
            parameters = torch.load(filename)
        if parameters is None:
            raise NotImplementedError(
                "load expects either a filename or a parameters dict "
                "(state_dict)")
        self.cnn.load(parameters=parameters["cnn"])
        self.rnn.load_state_dict(parameters["rnn"])
        self.dense.load_state_dict(parameters["dense"])

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        state_dict = {
            "cnn": self.cnn.state_dict(destination=destination, prefix=prefix,
                                       keep_vars=keep_vars),
            "rnn": self.rnn.state_dict(destination=destination, prefix=prefix,
                                       keep_vars=keep_vars),
            "dense": self.dense.state_dict(destination=destination,
                                           prefix=prefix, keep_vars=keep_vars),
        }
        return state_dict

    def save(self, filename):
        parameters = {
            'cnn': self.cnn.state_dict(),
            'rnn': self.rnn.state_dict(),
            'dense': self.dense.state_dict(),
        }
        torch.save(parameters, filename)

    def forward(self, x):
        # input size : (batch_size, n_channels, n_frames, n_freq)
        # conv features
        x = self.cnn(x)
        bs, chan, frames, freq = x.size()
        if freq != 1:
            warnings.warn("Output shape is: {}".format(
                (bs, frames, chan * freq)))
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)
            x = x.permute(0, 2, 1)  # [bs, frames, chan]

        # rnn features
        x = self.rnn(x)
        x = self.dropout(x)
        strong = self.dense(x)  # [bs, frames, nclass]
        strong = self.sigmoid(strong)
        if self.attention:
            sof = self.dense_softmax(x)  # [bs, frames, nclass]
            sof = self.softmax(sof)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / sof.sum(1)  # [bs, nclass]
        else:
            weak = strong.mean(1)
        return strong, weak
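def _crnn_smoke_test():
    # Hedged usage sketch, not from the repo: the CNN kwargs mirror the
    # 4-layer setup in get_model above; actual repo defaults may differ.
    # With this pooling, 64 frequency bins collapse to 1 (2 * 2 * 4 * 2 = 64)
    # and the 628 frames are pooled by 4, so strong is [4, 157, 10].
    model = CRNN(n_in_channel=1, nclass=10, attention=True,
                 kernel_size=4 * [3], padding=4 * [1], stride=4 * [1],
                 nb_filters=[16, 16, 32, 64],
                 pooling=[(2, 2), (2, 2), (1, 4), (1, 2)])
    model.eval()
    x = torch.randn(4, 1, 628, 64)  # (batch, channels, frames, freq)
    with torch.no_grad():
        strong, weak = model(x)
    return strong.shape, weak.shape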
import os

import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from tqdm import tqdm

# `get_notes`, `prep_sequences`, `MusicDataset`, `CNN` and `generate_midi`
# are assumed to be project-local imports.


def main():
    songs = get_notes()
    vocab_set = {note for song in songs for note in song}

    n_in, n_out = prep_sequences(songs, sequence_length=100)
    X_train, X_val, y_train, y_val = train_test_split(n_in, n_out,
                                                      test_size=0.2)
    train_ds = MusicDataset(X_train, y_train)
    val_ds = MusicDataset(X_val, y_val)
    train_dataloader = DataLoader(train_ds, batch_size=512, shuffle=True,
                                  num_workers=0)
    val_dataloader = DataLoader(val_ds, batch_size=512, shuffle=False,
                                num_workers=0)

    model = CNN(100, len(vocab_set))
    model.cuda()

    epochs = 25
    initial_lr = 0.001
    optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    loss_fn = CrossEntropyLoss()

    # Make sure the checkpoint and plot directories exist before saving.
    os.makedirs("weights", exist_ok=True)
    os.makedirs("plots", exist_ok=True)

    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    for epoch in tqdm(range(1, epochs + 1)):
        ### Train
        model.train()
        train_loss_total = 0.0
        num_steps = 0
        correct = 0
        for i, batch in enumerate(train_dataloader):
            X, y = batch[0].cuda(), batch[1].cuda()
            train_preds = model(X)
            loss = loss_fn(train_preds, y)
            train_loss_total += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            num_steps += 1

            train_preds = torch.max(train_preds, 1)[1]
            # .item() keeps `correct` a plain float so the accuracies can be
            # plotted directly later.
            correct += (train_preds == y).float().sum().item()

        train_loss_total_avg = train_loss_total / num_steps
        train_accuracy = correct / len(train_ds)
        train_accuracies.append(train_accuracy)
        train_losses.append(train_loss_total_avg)

        ### Validate
        model.eval()
        val_loss_total = 0.0
        num_steps = 0
        correct = 0
        for i, batch in enumerate(val_dataloader):
            with torch.no_grad():
                X, y = batch[0].cuda(), batch[1].cuda()
                val_preds = model(X)
                loss = loss_fn(val_preds, y)
                val_loss_total += loss.item()

                val_preds = torch.max(val_preds, 1)[1]
                correct += (val_preds == y).float().sum().item()
                num_steps += 1

        val_loss_total_avg = val_loss_total / num_steps
        val_accuracy = correct / len(val_ds)
        val_accuracies.append(val_accuracy)
        val_losses.append(val_loss_total_avg)

        scheduler.step()

        print('\nTrain loss: {:.4f}'.format(train_loss_total_avg))
        print('Train accuracy: {:.4f}'.format(train_accuracy))
        print('Val loss: {:.4f}'.format(val_loss_total_avg))
        print('Val accuracy: {:.4f}\n'.format(val_accuracy))

        torch.save(model.state_dict(),
                   "weights/model_params_epoch" + str(epoch))
        torch.save(optimizer.state_dict(),
                   "weights/optim_params_epoch" + str(epoch))

    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.plot(range(1, len(train_accuracies) + 1), train_accuracies)
    plt.plot(range(1, len(val_accuracies) + 1), val_accuracies)
    plt.savefig("plots/accuracies.png")
    plt.close()

    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.plot(range(1, len(train_losses) + 1), train_losses)
    plt.plot(range(1, len(val_losses) + 1), val_losses)
    plt.savefig("plots/losses.png")
    plt.close()

    generate_midi(model, val_ds, vocab_set, output_filename="output.mid")
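# Standard script entry point so the training run only starts when this file
# is executed directly, not when it is imported.
if __name__ == "__main__":
    main()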