示例#1
0
    def infer(self,
              features,
              features_mask,
              sos_id,
              eos_id,
              max_steps,
              hidden=None):
        """Greedily decode token ids, one step at a time.

        Every batch row starts from ``sos_id``. The decoder (``self``) is
        called once per step, and the argmax token of each step is fed back
        as the next input. Decoding stops when every row has emitted
        ``eos_id`` at least once, or after ``max_steps`` steps. Rows that
        finish early keep being decoded until the whole batch is done.

        Args:
            features: encoder output. Only ``features.size(0)`` and
                ``features.device`` are read here to build the start tokens.
                The tensor itself is passed on to ``self(...)``.
            features_mask: mask for ``features``, passed through to ``self``.
            sos_id: token id that starts every sequence.
            eos_id: token id that marks a row as finished.
            max_steps: upper bound on decoding steps; must be >= 1.
            hidden: optional initial state passed to the first ``self`` call.

        Returns:
            ``(all_logits, hidden, etc)``:
            - ``all_logits`` is the per-step logits concatenated along dim 1.
            - ``hidden`` is the value returned by the last ``self`` call.
            - ``etc['weights']['enc']`` is the per-step encoder-attention
              weights concatenated along dim 2.

        Raises:
            ValueError: if ``max_steps`` < 1. Nothing would be decoded, and
                the concatenations at the end would fail on empty lists.
        """
        if max_steps < 1:
            raise ValueError(
                'max_steps must be >= 1, got {}'.format(max_steps))

        etc = MergeDict(weights={'enc': []})

        # One start token per batch row, shape (features.size(0), 1).
        tokens = torch.full((features.size(0), 1),
                            fill_value=sos_id,
                            dtype=torch.long,
                            device=features.device)
        # No mask is supplied for the generated tokens.
        input_mask = None
        finished = tokens == eos_id

        all_logits = []
        for _ in range(max_steps):
            logits, hidden, etc_self = self(tokens, features, input_mask,
                                            features_mask, hidden)
            etc['weights']['enc'].append(etc_self['weights']['enc'])

            # Greedy choice over dim 2 (presumably the vocabulary axis of
            # (batch, 1, vocab) logits -- TODO confirm).
            tokens = logits.argmax(2)
            all_logits.append(logits)

            # A row counts as finished once it has produced eos at any step.
            finished = finished | (tokens == eos_id)
            if torch.all(finished):
                break

        all_logits = torch.cat(all_logits, 1)
        etc['weights']['enc'] = torch.cat(etc['weights']['enc'], 2)

        return all_logits, hidden, etc
示例#2
0
    def forward(self, input, input_mask):
        """Encode ``input`` with ``self.conv`` followed by ``self.rnn``.

        ``input_mask`` is accepted for interface compatibility but is not
        read by this implementation.

        Returns:
            ``(rnn_out, etc)``:
            - ``rnn_out`` is the first output of ``self.rnn``; its second
              output is discarded.
            - ``etc`` is a MergeDict with the conv block's own ``etc`` merged
              in under the name ``'conv'``.
        """
        etc = MergeDict()

        conv_out, etc.merge['conv'] = self.conv(input)
        # Swap dims 1 and 2 before the RNN (presumably
        # (batch, channels, time) -> (batch, time, channels) -- TODO confirm).
        rnn_in = conv_out.permute(0, 2, 1)
        rnn_out, _ = self.rnn(rnn_in)

        return rnn_out, etc
示例#3
0
    def forward(self, input, features, input_mask, features_mask, hidden=None):
        """Run one decoder pass: embed, RNN, attend over ``features``, project.

        ``input_mask`` is accepted but not used here.

        Returns:
            ``(logits, hidden, etc)``:
            - ``logits`` is the output of ``self.output``.
            - ``hidden`` is the RNN state returned by ``self.rnn``.
            - ``etc['weights']['enc']`` holds the attention weights returned
              by ``self.attention``.
        """
        etc = MergeDict(weights={})

        embedded = self.dropout(self.embedding(input))
        rnn_out, hidden = self.rnn(embedded, hidden)

        # features_mask gets an extra singleton dim at position 1 before
        # being handed to the attention module.
        attn_mask = features_mask.unsqueeze(1)
        context, etc['weights']['enc'] = self.attention(
            rnn_out, features, attn_mask)

        # Residual connection: the (dropped-out) context is added to the RNN
        # output before the final projection.
        logits = self.output(rnn_out + self.dropout(context))

        return logits, hidden, etc
示例#4
0
    def infer(self, sigs, sigs_mask, **kwargs):
        """Run spectra -> encoder -> ``decoder.infer`` on raw signals.

        Args:
            sigs: input signals, fed to ``self.spectra``.
            sigs_mask: mask for ``sigs``. It is downsampled to the size of
                dim 3 of the spectra, and then to the size of dim 1 of the
                encoder output.
            **kwargs: forwarded unchanged to ``self.decoder.infer``.

        Returns:
            ``(logits, etc)``:
            - ``etc['spectras']`` holds ``spectras[:32]``, the first 32
              entries along dim 0 (presumably kept for visualisation --
              TODO confirm).
            - The encoder's and decoder's ``etc`` are merged in under
              ``'encoder'`` and ``'decoder'``.
            - The decoder's returned hidden state is discarded.
        """
        spec = self.spectra(sigs)
        # Dim 3 is used as the target length (presumably time -- TODO confirm).
        spec_mask = modules.downsample_mask(sigs_mask, spec.size(3))
        etc = MergeDict(spectras=spec[:32])

        enc_out, etc.merge['encoder'] = self.encoder(spec, spec_mask)
        enc_mask = modules.downsample_mask(spec_mask, enc_out.size(1))

        decoded = self.decoder.infer(enc_out, enc_mask, **kwargs)
        logits, _, etc.merge['decoder'] = decoded

        return logits, etc
示例#5
0
    def forward(self, input, input_mask):
        """Encode ``input``: conv, positional encoding, one self-attention.

        The self-attention is wrapped in a residual connection.

        Returns:
            ``(output, etc)``:
            - The conv block's ``etc`` is merged in under ``'conv'``.
            - ``etc['weights']['self']`` holds the self-attention weights.
        """
        etc = MergeDict(weights={})

        conv_out, etc.merge['conv'] = self.conv(input)
        # The mask is downsampled to dim 2 of the conv output, which is read
        # before the permute below.
        mask = modules.downsample_mask(input_mask, conv_out.size(2))
        seq = self.dropout(self.encoding(conv_out.permute(0, 2, 1)))

        context, etc['weights']['self'] = self.self_attention(
            seq, seq, mask.unsqueeze(1))
        output = seq + self.dropout(context)

        return output, etc
示例#6
0
    def forward(self, input, features, input_mask, features_mask, hidden=None):
        """Decoder pass: masked self-attention, then attention over features.

        ``hidden`` is not used; it is returned unchanged so the signature
        matches the recurrent decoder.

        Returns:
            ``(logits, hidden, etc)``. ``etc['weights']`` holds the
            ``'self'`` and ``'enc'`` attention weights.
        """
        etc = MergeDict(weights={})

        x = self.dropout(self.encoding(self.embedding(input)))

        # Combine the padding mask with the mask from
        # build_subseq_attention_mask (presumably a no-look-ahead / causal
        # mask, judging by its name -- TODO confirm).
        subseq_mask = attention.build_subseq_attention_mask(x.size(1), x.device)
        self_mask = input_mask.unsqueeze(1) & subseq_mask
        context, etc['weights']['self'] = self.self_attention(x, x, self_mask)
        x = x + self.dropout(context)

        enc_mask = features_mask.unsqueeze(1)
        context, etc['weights']['enc'] = self.attention(x, features, enc_mask)
        x = x + self.dropout(context)

        return self.output(x), hidden, etc
示例#7
0
    def infer(self,
              features,
              features_mask,
              sos_id,
              eos_id,
              max_steps,
              hidden=None,
              debug_input=None):
        """Greedily decode, caching every step's input for self-attention.

        Each step does the following:
        1. Embed the latest token and add the positional encoding for
           step ``t``.
        2. Append it to the cache of earlier steps' inputs.
        3. Attend over that cache with ``self_attention``, then over
           ``features`` with ``attention``.
        4. Project to logits.

        The argmax token is fed back as the next input. If ``debug_input``
        is given, its next column is fed back instead.

        Args:
            features: encoder output. ``features.size(0)`` and
                ``features.device`` are used to build the start tokens.
            features_mask: mask for ``features``; ``unsqueeze(1)`` is applied
                before it is passed to ``self.attention``.
            sos_id: start token id (unused when ``debug_input`` is given).
            eos_id: token id that marks a row as finished.
            max_steps: maximum number of decoding steps; must be >= 1.
            hidden: not used by this method; returned unchanged.
            debug_input: optional token ids. Column 0 is the first input,
                and column ``t + 1`` is fed back after step ``t``.

        Returns:
            ``(all_logits, hidden, etc)``:
            - ``all_logits`` is the per-step logits concatenated along dim 1.
            - ``hidden`` is the untouched argument.
            - ``etc['weights']['self']`` and ``etc['weights']['enc']`` are
              the per-step attention weights, each concatenated along dim 2.

        Raises:
            ValueError: if ``max_steps`` < 1.
        """
        if max_steps < 1:
            # Without this guard the loop body never runs, ``t`` is unbound
            # for the padding below, and the concatenations would fail on
            # empty lists.
            raise ValueError(
                'max_steps must be >= 1, got {}'.format(max_steps))

        etc = MergeDict(weights={'self': [], 'enc': []})

        if debug_input is None:
            tokens = torch.full((features.size(0), 1),
                                fill_value=sos_id,
                                dtype=torch.long,
                                device=features.device)
        else:
            tokens = debug_input[:, :1]

        # Cache of every step's embedded input (concatenated on dim 1), plus a
        # matching mask. A position's mask entry is ~finished at the time it
        # was fed, so tokens fed after a row finished (its eos token
        # included) are masked out.
        self_features = None
        self_features_mask = None
        finished = tokens == eos_id

        all_logits = []
        for t in range(max_steps):
            x = self.embedding(tokens)
            x = self.encoding(x, t)
            x = self.dropout(x)

            if self_features is None:
                self_features = x
                self_features_mask = ~finished
            else:
                self_features = torch.cat([self_features, x], 1)
                self_features_mask = torch.cat([self_features_mask, ~finished],
                                               1)

            context, weights = self.self_attention(
                x, self_features, self_features_mask.unsqueeze(1))
            etc['weights']['self'].append(weights)
            x = x + self.dropout(context)

            context, weights = self.attention(x, features,
                                              features_mask.unsqueeze(1))
            etc['weights']['enc'].append(weights)
            x = x + self.dropout(context)

            logits = self.output(x)
            if debug_input is None:
                tokens = logits.argmax(2)
            else:
                # NOTE(review): past the last column this slice has size 0 on
                # dim 1. The ``|`` below then broadcasts ``finished`` to an
                # empty tensor, and torch.all on an empty tensor is True, so
                # the loop stops.
                tokens = debug_input[:, t + 1:t + 2]

            all_logits.append(logits)

            finished = finished | (tokens == eos_id)
            if torch.all(finished):
                break

        all_logits = torch.cat(all_logits, 1)
        # Step i attended over i + 1 cached positions. Right-pad the last dim
        # of every step's weights to t + 1 (the total number of steps) so the
        # steps can be concatenated. This assumes the weights are 4-D with
        # the key axis on dim 3 -- TODO confirm against self_attention.
        etc['weights']['self'] = torch.cat(
            [F.pad(w, (0, t + 1 - w.size(3))) for w in etc['weights']['self']],
            2)
        etc['weights']['enc'] = torch.cat(etc['weights']['enc'], 2)

        return all_logits, hidden, etc
示例#8
0
def test_merge_dict():
    """Check that merging a named child MergeDict prefixes its nested keys.

    The test asserts two things after the merge:
    - the parent's own entry ``k1`` is still readable;
    - the child's nested key ``key`` is readable as ``'name/key'`` under
      the child's top-level key ``k2``.
    """
    parent = MergeDict(k1=1)
    child = MergeDict(k2={'key': 'value'})
    parent.merge['name'] = child

    assert parent['k1'] == 1
    assert parent['k2']['name/key'] == 'value'
示例#9
0
    def forward(self, input):
        """Apply ``self.conv`` then ``self.project``, then squeeze dim 2.

        Returns:
            ``(output, MergeDict())``. No auxiliary info is collected.
        """
        projected = self.project(self.conv(input))
        # squeeze(2) removes dim 2 only if it has size 1; otherwise the
        # tensor is returned unchanged.
        return projected.squeeze(2), MergeDict()