def get_model(dataloader):
    """Build the model selected by hp.model_name, optionally resume it from
    hp.model_path, and wrap it in the matching Trainer."""
    if hp.model_name == "CNN":
        from models.CNN import Trainer, CNN
        model = CNN()
    elif hp.model_name == "RNN":
        from models.RNN import Trainer, RNN
        model = RNN(hp.n_features, 64, 3, hp.num_classes, hp.device, classes=None)
    elif hp.model_name == "AudioEncoder":
        from models.AudioEncoder import Trainer, AudioEncoder
        model = AudioEncoder()
    elif hp.model_name == "GAN":
        from models.GAN import Gan, Trainer
        model = Gan()
    elif hp.model_name == "Test":
        from models.TestModul import TestModule, Trainer
        model = TestModule()
    else:
        raise ValueError("Unknown model name: {}".format(hp.model_name))

    # Resume from an existing checkpoint if one is present.
    if os.path.isfile(hp.model_path):
        model.load_state_dict(torch.load(hp.model_path))
        print("model loaded from: {}".format(hp.model_path))

    trainer = Trainer(dataloader, model)
    return model, trainer
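
# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original pipeline): the same
# checkpoint-resume pattern get_model uses, shown on a tiny stand-in model.
# The nn.Linear model and the default path below are assumptions made purely
# for demonstration.
# -----------------------------------------------------------------------------
def _example_resume_from_checkpoint(ckpt_path="example_checkpoint.pt"):
    import os
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)           # stand-in for CNN()/RNN()/AudioEncoder()
    if os.path.isfile(ckpt_path):     # resume only when a checkpoint exists
        model.load_state_dict(torch.load(ckpt_path))
        print("model loaded from: {}".format(ckpt_path))
    return model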
class WCRNN(nn.Module):  # , BaseFairseqModel
    """Wav2vec + CRNN sound event detection model: a wav2vec encoder feeds a
    CNN front-end, a bidirectional GRU, and a frame-wise dense classifier."""

    def __init__(self, w2v_cfg, n_in_channel, nclass, attention=False, activation="Relu",
                 dropout=0, train_cnn=True, rnn_type='BGRU', n_RNN_cell=64,
                 n_layers_RNN=1, dropout_recurrent=0, cnn_integration=False, **kwargs):
        super(WCRNN, self).__init__()

        self.w2v = w2v_encoder(w2v_cfg)  # Wav2Vec2Config
        # self.w2v = Wav2VecEncoder(Wav2Vec2SedConfig, None)
        self.pooling = nn.Sequential(nn.MaxPool2d((1, 4), (1, 4)))

        self.n_in_channel = n_in_channel
        self.attention = attention
        self.cnn_integration = cnn_integration

        n_in_cnn = n_in_channel
        if cnn_integration:
            n_in_cnn = 1

        self.cnn = CNN(n_in_cnn, activation, dropout, **kwargs)

        # Optionally freeze the CNN front-end.
        if not train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False
        self.train_cnn = train_cnn

        if rnn_type == 'BGRU':
            nb_in = self.cnn.nb_filters[-1]
            if self.cnn_integration:
                # self.fc = nn.Linear(nb_in * n_in_channel, nb_in)
                nb_in = nb_in * n_in_channel
            self.rnn = BidirectionalGRU(nb_in, n_RNN_cell,
                                        dropout=dropout_recurrent,
                                        num_layers=n_layers_RNN)
        else:
            raise NotImplementedError("Only BGRU is supported for CRNN for now")

        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(n_RNN_cell * 2, nclass)
        self.sigmoid = nn.Sigmoid()

        if self.attention:
            self.dense_softmax = nn.Linear(n_RNN_cell * 2, nclass)
            self.softmax = nn.Softmax(dim=-1)

    def load_cnn(self, state_dict):
        self.cnn.load_state_dict(state_dict)
        if not self.train_cnn:
            for param in self.cnn.parameters():
                param.requires_grad = False

    def load_state_dict(self, state_dict, strict=True):
        # The checkpoint is stored per sub-module (see state_dict/save below).
        self.w2v.load_state_dict(state_dict["w2v"])
        self.cnn.load_state_dict(state_dict["cnn"])
        self.rnn.load_state_dict(state_dict["rnn"])
        self.dense.load_state_dict(state_dict["dense"])

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        state_dict = {
            "w2v": self.w2v.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars),
            "cnn": self.cnn.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars),
            "rnn": self.rnn.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars),
            "dense": self.dense.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars),
        }
        return state_dict

    def save(self, filename):
        parameters = {
            'w2v': self.w2v.state_dict(),
            'cnn': self.cnn.state_dict(),
            'rnn': self.rnn.state_dict(),
            'dense': self.dense.state_dict(),
        }
        torch.save(parameters, filename)

    def forward(self, audio):
        # Collapse singleton dims, e.g. (batch, 1, samples) -> (batch, samples).
        x = audio.squeeze()

        # wav2vec features; move the batch dimension first.
        feature = self.w2v(x)
        x = feature['x']
        x = x.transpose(1, 0)
        x = x.unsqueeze(1)

        # input size: (batch_size, n_channels, n_frames, n_freq)
        if self.cnn_integration:
            bs_in, nc_in = x.size(0), x.size(1)
            x = x.view(bs_in * nc_in, 1, *x.shape[2:])

        # conv features
        before = x
        x = self.cnn(x)
        bs, chan, frames, freq = x.size()
        if self.cnn_integration:
            x = x.reshape(bs_in, chan * nc_in, frames, freq)

        if freq != 1:
            warnings.warn(
                f"Output shape is: {(bs, frames, chan * freq)}, from {freq} remaining freq bins"
            )
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)
            x = x.permute(0, 2, 1)  # [bs, frames, chan]

        # recurrent features
        x = self.rnn(x)
        x = self.dropout(x)

        # Frame-level (strong) predictions: [bs, frames, nclass].
        strong = self.dense(x)
        strong = self.sigmoid(strong)

        if self.attention:
            # Attention-weighted pooling over frames for clip-level (weak) predictions.
            sof = self.dense_softmax(x)  # [bs, frames, nclass]
            sof = self.softmax(sof)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / (sof.sum(1) + 1e-08)  # [bs, nclass]
        else:
            weak = strong.mean(1)

        return strong, weak
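
# -----------------------------------------------------------------------------
# Illustrative sketch (assumed example sizes, not repo values): how
# WCRNN.forward rearranges the CNN output (bs, chan, frames, freq) into the
# (bs, frames, features) layout expected by the bidirectional GRU.
# -----------------------------------------------------------------------------
def _example_cnn_to_rnn_layout():
    import torch

    bs, chan, frames, freq = 2, 128, 157, 1          # assumed example sizes
    x = torch.randn(bs, chan, frames, freq)

    if freq != 1:
        # Keep leftover frequency bins by folding them into the channel axis.
        x = x.permute(0, 2, 1, 3).contiguous().view(bs, frames, chan * freq)
    else:
        # Usual case: drop the singleton frequency axis.
        x = x.squeeze(-1).permute(0, 2, 1)           # [bs, frames, chan]

    return x.shape                                   # torch.Size([2, 157, 128])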