def infer(self, features, features_mask, sos_id, eos_id, max_steps, hidden=None):
    etc = MergeDict(weights={'enc': []})

    # Start every sequence from the start-of-sequence token.
    input = torch.full(
        (features.size(0), 1), fill_value=sos_id, dtype=torch.long, device=features.device)
    input_mask = None  # not used by the step-wise forward pass
    finished = input == eos_id
    all_logits = []

    for t in range(max_steps):
        logits, hidden, etc_self = self(input, features, input_mask, features_mask, hidden)
        etc['weights']['enc'].append(etc_self['weights']['enc'])

        # Greedy decoding: feed the most likely token back in at the next step.
        input = logits.argmax(2)
        all_logits.append(logits)

        finished = finished | (input == eos_id)
        if torch.all(finished):
            break

    all_logits = torch.cat(all_logits, 1)
    etc['weights']['enc'] = torch.cat(etc['weights']['enc'], 2)

    return all_logits, hidden, etc
def forward(self, input, input_mask):
    etc = MergeDict()

    input, etc.merge['conv'] = self.conv(input)
    input = input.permute(0, 2, 1)
    input, _ = self.rnn(input)

    return input, etc
def forward(self, input, features, input_mask, features_mask, hidden=None):
    etc = MergeDict(weights={})

    input = self.embedding(input)
    input = self.dropout(input)
    input, hidden = self.rnn(input, hidden)

    # Attend over the encoder features and add the context as a residual.
    context, etc['weights']['enc'] = self.attention(
        input, features, features_mask.unsqueeze(1))
    input = input + self.dropout(context)

    input = self.output(input)

    return input, hidden, etc
def infer(self, sigs, sigs_mask, **kwargs):
    spectras = self.spectra(sigs)
    spectras_mask = modules.downsample_mask(sigs_mask, spectras.size(3))
    etc = MergeDict(spectras=spectras[:32])

    features, etc.merge['encoder'] = self.encoder(spectras, spectras_mask)
    features_mask = modules.downsample_mask(spectras_mask, features.size(1))

    logits, _, etc.merge['decoder'] = self.decoder.infer(features, features_mask, **kwargs)

    return logits, etc
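# The downsample_mask calls above shrink a boolean padding mask along time so it matches
# a feature map produced by strided convolutions. The helper below is only a minimal
# sketch of what modules.downsample_mask could look like, not the project's actual
# implementation; it assumes a downsampled step stays valid if any source step it covers
# was valid.
import torch.nn.functional as F

def downsample_mask(mask, size):
    mask = mask.unsqueeze(1).float()          # (batch, time) -> (batch, 1, time)
    mask = F.adaptive_max_pool1d(mask, size)  # pool the mask down to the target length
    return mask.squeeze(1).bool()             # back to a boolean (batch, size) mask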
def forward(self, input, input_mask):
    etc = MergeDict(weights={})

    input, etc.merge['conv'] = self.conv(input)
    input_mask = modules.downsample_mask(input_mask, input.size(2))
    input = input.permute(0, 2, 1)
    input = self.encoding(input)
    input = self.dropout(input)

    # Self-attention over the conv features with a residual connection.
    context, etc['weights']['self'] = self.self_attention(input, input, input_mask.unsqueeze(1))
    input = input + self.dropout(context)

    return input, etc
def forward(self, input, features, input_mask, features_mask, hidden=None):
    etc = MergeDict(weights={})

    input = self.embedding(input)
    input = self.encoding(input)
    input = self.dropout(input)

    # Masked self-attention: each position may only attend to itself and earlier positions.
    subseq_attention_mask = attention.build_subseq_attention_mask(input.size(1), input.device)
    context, etc['weights']['self'] = self.self_attention(
        input, input, input_mask.unsqueeze(1) & subseq_attention_mask)
    input = input + self.dropout(context)

    # Encoder-decoder attention over the features.
    context, etc['weights']['enc'] = self.attention(input, features, features_mask.unsqueeze(1))
    input = input + self.dropout(context)

    input = self.output(input)

    return input, hidden, etc
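# build_subseq_attention_mask above produces the usual causal ("subsequent") mask that
# keeps a position from attending to later positions. A minimal sketch under that
# assumption is shown here; the actual helper lives in the attention module and may differ.
import torch

def build_subseq_attention_mask(size, device):
    # Lower-triangular boolean matrix, broadcastable against a (batch, 1, time) input mask.
    return torch.tril(torch.ones(size, size, dtype=torch.bool, device=device)).unsqueeze(0)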
def infer(self, features, features_mask, sos_id, eos_id, max_steps, hidden=None, debug_input=None):
    etc = MergeDict(weights={'self': [], 'enc': []})

    if debug_input is None:
        input = torch.full(
            (features.size(0), 1), fill_value=sos_id, dtype=torch.long, device=features.device)
    else:
        input = debug_input[:, :1]

    self_features = None
    self_features_mask = None
    finished = input == eos_id
    all_logits = []

    for t in range(max_steps):
        input = self.embedding(input)
        input = self.encoding(input, t)
        input = self.dropout(input)

        # Cache the embedded steps so self-attention can look back over the whole prefix.
        if self_features is None:
            self_features = input
            self_features_mask = ~finished
        else:
            self_features = torch.cat([self_features, input], 1)
            self_features_mask = torch.cat([self_features_mask, ~finished], 1)

        context, weights = self.self_attention(
            input, self_features, self_features_mask.unsqueeze(1))
        etc['weights']['self'].append(weights)
        input = input + self.dropout(context)

        context, weights = self.attention(input, features, features_mask.unsqueeze(1))
        etc['weights']['enc'].append(weights)
        input = input + self.dropout(context)

        logits = self.output(input)

        # Greedy decoding, unless a reference sequence is supplied for debugging.
        if debug_input is None:
            input = logits.argmax(2)
        else:
            input = debug_input[:, t + 1:t + 2]

        all_logits.append(logits)

        finished = finished | (input == eos_id)
        if torch.all(finished):
            break

    all_logits = torch.cat(all_logits, 1)
    # Pad the per-step self-attention weights to a common key length before concatenating.
    etc['weights']['self'] = torch.cat(
        [F.pad(w, (0, t + 1 - w.size(3))) for w in etc['weights']['self']], 2)
    etc['weights']['enc'] = torch.cat(etc['weights']['enc'], 2)

    return all_logits, hidden, etc
def test_merge_dict():
    d = MergeDict(k1=1)
    d.merge['name'] = MergeDict(k2={'key': 'value'})

    assert d['k1'] == 1
    assert d['k2']['name/key'] == 'value'
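# MergeDict itself is not shown in this section. The sketch below is inferred from the
# test above and from the etc.merge[...] assignments in the encoder/decoder code: assigning
# another dict under a name copies its entries into the parent, prefixing the keys of
# nested dicts with 'name/', so encoder and decoder attention weights can share a single
# 'weights' entry. Treat it as an assumption, not the project's implementation.
class MergeDict(dict):
    class _Merge(object):
        def __init__(self, parent):
            self.parent = parent

        def __setitem__(self, name, other):
            for key, value in other.items():
                if isinstance(value, dict):
                    # Nested dicts are merged key-by-key under a namespaced key.
                    self.parent.setdefault(key, {}).update(
                        {'{}/{}'.format(name, k): v for k, v in value.items()})
                else:
                    # Plain values get a namespaced top-level key (assumed behaviour).
                    self.parent['{}/{}'.format(name, key)] = value

    @property
    def merge(self):
        return MergeDict._Merge(self)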
def forward(self, input):
    input = self.conv(input)
    input = self.project(input)
    input = input.squeeze(2)

    return input, MergeDict()