Exemplo n.º 1
0
def get_model(dataloader):
    """Build the model selected by ``hp.model_name`` together with its trainer.

    The module that defines the model (and its ``Trainer``) is imported
    lazily, so only the selected architecture's module is loaded. If a file
    exists at ``hp.model_path``, its contents are loaded into the model via
    ``load_state_dict``.

    Args:
        dataloader: Passed unchanged as the first argument to ``Trainer``.

    Returns:
        A ``(model, trainer)`` tuple.

    Raises:
        ValueError: If ``hp.model_name`` is not one of the supported names.
    """
    if hp.model_name == "CNN":
        from models.CNN import Trainer, CNN
        model = CNN()
    elif hp.model_name == "RNN":
        from models.RNN import Trainer, RNN
        model = RNN(hp.n_features, 64, 3, hp.num_classes, hp.device, classes=None)
    elif hp.model_name == "AudioEncoder":
        from models.AudioEncoder import Trainer, AudioEncoder
        model = AudioEncoder()
    elif hp.model_name == "GAN":
        from models.GAN import Gan, Trainer
        model = Gan()
    elif hp.model_name == "Test":
        from models.TestModul import TestModule, Trainer
        model = TestModule()
    else:
        # Without this branch `model` and `Trainer` would be unbound below and
        # the failure would surface as an unrelated UnboundLocalError.
        raise ValueError(
            "unknown model name: {!r} (expected one of 'CNN', 'RNN', "
            "'AudioEncoder', 'GAN', 'Test')".format(hp.model_name)
        )

    # Resume from an existing checkpoint when one is present.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if os.path.isfile(hp.model_path):
        model.load_state_dict(torch.load(hp.model_path))
        print("model loaded from: {}".format(hp.model_path))

    trainer = Trainer(dataloader, model)

    return model, trainer
Exemplo n.º 2
0
class WCRNN(nn.Module):  #, BaseFairseqModel):
    """Encoder + CNN + bidirectional GRU classifier with frame/clip outputs.

    The input is passed through the encoder built by ``w2v_encoder``, then a
    ``CNN``, a ``BidirectionalGRU`` and a dense layer followed by a sigmoid.
    ``forward`` returns per-frame ("strong") scores and a pooled ("weak")
    score per class.

    Args:
        w2v_cfg: Configuration passed to ``w2v_encoder``.
        n_in_channel: Number of input channels given to the CNN (the CNN gets
            a single channel instead when ``cnn_integration`` is true).
        nclass: Output size of the dense classification layer(s).
        attention: If true, add a softmax head used to weight the frame
            scores when computing the weak output; otherwise the weak output
            is the mean over frames.
        activation: Forwarded to ``CNN``.
        dropout: Forwarded to ``CNN`` and used for the dropout applied before
            the dense layer.
        train_cnn: If false, the CNN parameters are frozen
            (``requires_grad = False``).
        rnn_type: Only ``'BGRU'`` is supported.
        n_RNN_cell: Hidden size passed to ``BidirectionalGRU``; the dense
            layers take ``2 * n_RNN_cell`` input features.
        n_layers_RNN: Number of recurrent layers.
        dropout_recurrent: Dropout passed to ``BidirectionalGRU``.
        cnn_integration: If true, channels are folded into the batch axis
            before the CNN and unfolded into the channel axis afterwards.
        **kwargs: Forwarded to ``CNN``.

    Raises:
        NotImplementedError: If ``rnn_type`` is not ``'BGRU'``.
    """

    def __init__(self,
                 w2v_cfg,
                 n_in_channel,
                 nclass,
                 attention=False,
                 activation="Relu",
                 dropout=0,
                 train_cnn=True,
                 rnn_type='BGRU',
                 n_RNN_cell=64,
                 n_layers_RNN=1,
                 dropout_recurrent=0,
                 cnn_integration=False,
                 **kwargs):
        super(WCRNN, self).__init__()

        self.w2v = w2v_encoder(w2v_cfg)
        # NOTE(review): `pooling` is created but not used by `forward`.
        self.pooling = nn.Sequential(nn.MaxPool2d((1, 4), (1, 4)))

        self.n_in_channel = n_in_channel
        self.attention = attention
        self.cnn_integration = cnn_integration
        # With cnn_integration each channel is fed to the CNN separately.
        n_in_cnn = 1 if cnn_integration else n_in_channel
        self.cnn = CNN(n_in_cnn, activation, dropout, **kwargs)
        self.train_cnn = train_cnn
        if not train_cnn:
            self._freeze_cnn()

        if rnn_type == 'BGRU':
            nb_in = self.cnn.nb_filters[-1]
            if self.cnn_integration:
                # Per-channel CNN outputs are concatenated along the channel
                # axis in `forward`, so the RNN input grows accordingly.
                nb_in = nb_in * n_in_channel
            self.rnn = BidirectionalGRU(nb_in,
                                        n_RNN_cell,
                                        dropout=dropout_recurrent,
                                        num_layers=n_layers_RNN)
        else:
            # The exception was previously constructed but never raised, which
            # left the model without `self.rnn` and failed later in `forward`.
            raise NotImplementedError("Only BGRU supported for CRNN for now")

        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(n_RNN_cell * 2, nclass)
        self.sigmoid = nn.Sigmoid()
        if self.attention:
            self.dense_softmax = nn.Linear(n_RNN_cell * 2, nclass)
            self.softmax = nn.Softmax(dim=-1)

    def _freeze_cnn(self):
        """Disable gradients for every CNN parameter."""
        for param in self.cnn.parameters():
            param.requires_grad = False

    def load_cnn(self, state_dict):
        """Load CNN weights only, re-freezing them if ``train_cnn`` is false."""
        self.cnn.load_state_dict(state_dict)
        if not self.train_cnn:
            self._freeze_cnn()

    def load_state_dict(self, state_dict, strict=True):
        """Load weights from the nested dict produced by ``state_dict``/``save``.

        Args:
            state_dict: Mapping with the keys ``"w2v"``, ``"cnn"``, ``"rnn"``
                and ``"dense"``, each holding that sub-module's state dict.
            strict: Accepted for signature compatibility with
                ``nn.Module.load_state_dict``; it is not forwarded to the
                sub-modules.
        """
        # NOTE(review): `dense_softmax` (attention head) is neither saved nor
        # restored by this class.
        self.w2v.load_state_dict(state_dict["w2v"])
        self.cnn.load_state_dict(state_dict["cnn"])
        self.rnn.load_state_dict(state_dict["rnn"])
        self.dense.load_state_dict(state_dict["dense"])

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return a dict of per-sub-module state dicts.

        The result has the keys ``"w2v"``, ``"cnn"``, ``"rnn"`` and
        ``"dense"``; the arguments are forwarded unchanged to each
        sub-module's ``state_dict``.
        """
        state_dict = {
            "w2v":
            self.w2v.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            "cnn":
            self.cnn.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            "rnn":
            self.rnn.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars),
            'dense':
            self.dense.state_dict(destination=destination,
                                  prefix=prefix,
                                  keep_vars=keep_vars)
        }
        return state_dict

    def save(self, filename):
        """Write the w2v, cnn, rnn and dense state dicts to ``filename``."""
        parameters = {
            'w2v': self.w2v.state_dict(),
            'cnn': self.cnn.state_dict(),
            'rnn': self.rnn.state_dict(),
            'dense': self.dense.state_dict()
        }
        torch.save(parameters, filename)

    def forward(self, audio):
        """Compute frame-level and clip-level class scores.

        Args:
            audio: Input tensor; all size-1 dimensions are squeezed before it
                is passed to the encoder.

        Returns:
            A ``(strong, weak)`` tuple: ``strong`` holds the sigmoid scores
            per frame and ``weak`` is ``strong`` pooled over the frame axis
            (attention-weighted when ``attention`` is set, mean otherwise).
        """
        # NOTE(review): squeeze() also drops a batch dimension of size 1.
        x = audio.squeeze()
        feature = self.w2v(x)
        x = feature['x']
        # Swap the first two axes of the encoder output, then add a channel
        # axis. Presumably time-major -> batch-major; TODO confirm.
        x = x.transpose(1, 0)
        x = x.unsqueeze(1)

        # input size : (batch_size, n_channels, n_frames, n_freq)
        if self.cnn_integration:
            # Fold channels into the batch axis so the CNN sees one channel.
            bs_in, nc_in = x.size(0), x.size(1)
            x = x.view(bs_in * nc_in, 1, *x.shape[2:])

        # conv features
        x = self.cnn(x)
        bs, chan, frames, freq = x.size()
        if self.cnn_integration:
            # Unfold the channels back out of the batch axis.
            x = x.reshape(bs_in, chan * nc_in, frames, freq)

        if freq != 1:
            warnings.warn(
                f"Output shape is: {(bs, frames, chan * freq)}, from {freq} staying freq"
            )
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)
            x = x.permute(0, 2, 1)  # [bs, frames, chan]

        # rnn features
        x = self.rnn(x)
        x = self.dropout(x)
        strong = self.dense(x)  # [bs, frames, nclass]
        strong = self.sigmoid(strong)
        if self.attention:
            sof = self.dense_softmax(x)  # [bs, frames, nclass]
            sof = self.softmax(sof)
            # Clamp so the normalising sum below cannot collapse to zero.
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / (sof.sum(1) + 1e-08)  # [bs, nclass]
        else:
            weak = strong.mean(1)
        return strong, weak