def _forward_step(self,
                  input_: LT,
                  src_emb: FT,
                  state: LstmStatesByLayers,
                  src_states: FT,
                  mask_src: BT,
                  lang_emb: Optional[FT] = None,
                  prev_att: Optional[FT] = None) -> Tuple[LstmStatesByLayers, FT, FT, FT]:
    emb = self.char_emb(input_)
    if lang_emb is not None:
        emb = emb + lang_emb
    inp = torch.cat([emb, prev_att], dim=-1) if g.input_feeding else emb
    hid_rnn, next_state = self.cell(inp, state)  # hid_rnn has gone through dropout already.
    almt, ctx = self.attn.forward(hid_rnn, src_states, mask_src)  # So has src_states.
    with NoName(hid_rnn, ctx):
        cat = torch.cat([hid_rnn, ctx], dim=-1)
    hid_cat = self.hidden(cat)
    hid_cat = self.drop(hid_cat)
    with NoName(src_emb, hid_cat, almt):
        ctx_emb = (src_emb * almt.t().unsqueeze(dim=-1)).sum(dim=0)
        hid_res = self.nc_residual(ctx_emb, hid_cat).rename('batch', 'hidden')
    logit = self.char_emb.project(hid_res)
    log_prob = logit.log_softmax(dim=-1).refine_names('batch', 'unit')
    return next_state, log_prob, almt, hid_res

def gumbel_softmax(logits: FT,
                   temperature: float,
                   num_samples: Optional[int] = None) -> Tuple[FT, FT, LT, FT]:
    """Sample from the Gumbel-Softmax distribution and optionally discretize."""
    logits = logits.align_to('batch', 'length', 'label')
    y = gumbel_softmax_sample(logits, temperature, num_samples)
    y = y.align_to('batch', 'length', 'label', ...)
    max_values, max_inds = y.max(dim='label')
    y_one_hot = (max_values.align_as(y) == y).float()
    # Straight-through estimator: hard one-hot in the forward pass, soft gradients in the backward pass.
    y_one_hot = (y_one_hot - y).detach() + y
    bi = get_named_range(logits.size('batch'), 'batch').align_as(max_inds)
    li = get_named_range(logits.size('length'), 'length').align_as(max_inds)
    if num_samples is None:
        with NoName(max_inds, y_one_hot, bi, li):
            probs = y_one_hot[bi, li, max_inds]
            probs.rename_('batch', 'length')
    else:
        si = get_named_range(max_inds.size('sample'), 'sample').align_as(max_inds)
        with NoName(max_inds, y_one_hot, bi, li, si):
            probs = y_one_hot[bi, li, max_inds, si]
            probs.rename_('batch', 'length', 'sample')
    seq_probs = (1e-8 + probs).log().sum(dim='length').exp()
    return y, y_one_hot, max_inds, seq_probs

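# A minimal, self-contained sketch of the straight-through trick used above
# (`(y_one_hot - y).detach() + y`): the forward pass sees a hard one-hot sample while
# gradients flow through the soft sample. Plain (unnamed) tensors; the Gumbel noise line
# follows the standard -log(Exponential(1)) construction and is an assumption about what
# `gumbel_softmax_sample` does internally.
def _st_gumbel_sketch(temperature: float = 1.0):
    import torch
    import torch.nn.functional as F
    logits = torch.randn(4, 6, requires_grad=True)  # batch x label
    gumbels = -torch.empty_like(logits).exponential_().log()  # Gumbel(0, 1) noise
    y = F.softmax((logits + gumbels) / temperature, dim=-1)  # soft, differentiable sample
    y_hard = F.one_hot(y.argmax(dim=-1), num_classes=6).float()  # hard, non-differentiable
    y_st = (y_hard - y).detach() + y  # forward: y_hard; backward: gradients of y
    y_st.sum().backward()
    assert logits.grad is not None  # gradients reach the logits despite the argmax
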
def forward(self, h_t: FT, h_s: FT, mask_src: BT) -> Tuple[FT, FT]:
    Wh_s = self._get_Wh_s(h_s)
    with NoName(h_t):
        scores = (Wh_s * h_t).sum(dim=-1)
    scores = torch.where(mask_src, scores, torch.full_like(scores, -9999.9))
    almt_distr = nn.functional.log_softmax(scores, dim=0).exp()  # sl x bs
    with NoName(almt_distr):
        ctx = (almt_distr.unsqueeze(dim=-1) * h_s).sum(dim=0)  # bs x d
    almt_distr = almt_distr.t()
    return almt_distr, ctx

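# Shape sketch for the attention above, with plain tensors and toy sizes (a sketch, not the
# module itself): scores are normalized over the source-length dimension (dim 0), masked
# positions are pushed to -9999.9 before the softmax, and the context is an attention-weighted
# sum of source states.
def _attention_shape_sketch():
    import torch
    sl, bs, d = 5, 3, 8
    h_s = torch.randn(sl, bs, d)
    Wh_s = h_s  # stand-in for self._get_Wh_s(h_s)
    h_t = torch.randn(bs, d)
    mask_src = torch.ones(sl, bs, dtype=torch.bool)
    scores = (Wh_s * h_t).sum(dim=-1)  # sl x bs
    scores = torch.where(mask_src, scores, torch.full_like(scores, -9999.9))
    almt_distr = scores.log_softmax(dim=0).exp()  # sl x bs, sums to 1 over dim 0
    ctx = (almt_distr.unsqueeze(dim=-1) * h_s).sum(dim=0)  # bs x d
    assert ctx.shape == (bs, d) and almt_distr.t().shape == (bs, sl)
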
def forward(self, input_: LT, lengths: LT) -> Tuple[FT, LstmOutputTuple]:
    # input_: seq_length x batch_size
    # Note that input_size == hidden_size, and n_conv is the number of parallel convolutional layers.
    emb = self.embedding(input_)  # seq_length x batch_size x input_size
    with NoName(emb, lengths):
        # Reshape to batch_size x input_size x seq_length for CNN input.
        reshaped_emb = emb.permute(1, 2, 0)
        # Each conv layer's output is batch_size x hidden_size x seq_length.
        conv_outputs = [self.dropout(F.relu(conv(reshaped_emb))) for conv in self.conv_layers]
        # Stack the CNN outputs along the hidden_size dimension.
        x = torch.cat(conv_outputs, dim=1)  # batch_size x (n_conv * hidden_size) x seq_length
        x = x.permute(2, 0, 1)  # seq_length x batch_size x (n_conv * hidden_size)
        # Project the concatenated convolutional outputs into 2 * hidden_size dimensions so
        # that `output` looks as though it were the states of a bidirectional LSTM.
        output = self.W_output(x)  # seq_length x batch_size x 2*hidden_size
    # We don't try to reconstruct the state, so we just pass (None, None).
    return emb, (output, (None, None))

def forward(self,
            curr_ids: LT,
            end_ids: LT,
            steps: Optional[LT] = None,
            done: Optional[BT] = None) -> FT:
    """Get policy evaluation.

    If `done` is provided, we get values for s1 instead of s0. In that case,
    end states should have values set to 0. `steps` should start at 0.
    """
    state_repr = self.enc(curr_ids, end_ids)
    # NOTE(j_luo) If s1 is being evaluated, we should increment `steps`.
    if done is not None and g.use_finite_horizon:
        steps = steps + 1
    with NoName(state_repr, steps):
        if g.use_finite_horizon:
            rel_step = steps.float() / g.max_rollout_length
            state_repr = torch.cat([state_repr, rel_step.unsqueeze(dim=-1)], dim=-1)
        values = self.regressor(state_repr).squeeze(dim=-1)
    # Deal with special cases. We start with the final-step case, and then overwrite it if done.
    if g.use_finite_horizon:
        final_step = steps == g.max_rollout_length
        values = torch.where(final_step, torch.zeros_like(values), values)
    if done is not None:
        # NOTE(j_luo) Use the final reward for the value of the end state.
        values = torch.where(done, torch.full_like(values, g.final_reward), values)
    return values

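# A small sketch of why the two `torch.where` overwrites above are ordered: the final-step
# zeroing is applied first, and `done` then wins whenever both conditions hold. Toy numbers;
# `final_reward` is assumed to be 1.0 here.
def _value_overwrite_sketch():
    import torch
    values = torch.tensor([0.5, 0.7, 0.9])
    final_step = torch.tensor([True, True, False])
    done = torch.tensor([False, True, False])
    values = torch.where(final_step, torch.zeros_like(values), values)
    values = torch.where(done, torch.full_like(values, 1.0), values)
    assert torch.allclose(values, torch.tensor([0.0, 1.0, 0.9]))
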
def _restore_shape(tensor, bi, lsi, lei, viable, value: Optional[float] = None):
    bs = bi.size('batch')
    len_s = lsi.size('len_s')
    len_e = lei.size('len_e')

    shape = (bs, len_s, len_e)
    names = ('batch', 'len_s', 'len_e')
    if tensor.ndim > 1:
        shape += tensor.shape[1:]
        names += tensor.names[1:]

    with NoName(bi, lsi, lei, viable, tensor):
        v_bi = bi[viable]
        v_lsi = lsi[viable]
        v_lei = lei[viable]
        ret = get_zeros(*shape).to(tensor.dtype)
        if value is not None:
            ret.fill_(value)
        ret[v_bi, v_lsi, v_lei] = tensor

    ret.rename_(*names)
    return ret

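# Minimal sketch of the scatter-back pattern in `_restore_shape`: values computed only for
# viable cells are written back into a dense grid, with everything else left at the fill
# value. Plain tensors; the len_e axis is folded into len_s to keep the example 2D.
def _scatter_back_sketch():
    import torch
    viable = torch.tensor([[True, False], [False, True]])  # batch x len_s
    flat_values = torch.tensor([1.0, 2.0])  # one value per viable cell
    bi = torch.arange(2).unsqueeze(-1).expand(2, 2)  # batch indices
    si = torch.arange(2).unsqueeze(0).expand(2, 2)  # len_s indices
    full = torch.zeros(2, 2)
    full[bi[viable], si[viable]] = flat_values
    assert full.tolist() == [[1.0, 0.0], [0.0, 2.0]]
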
def _analyze_unsupervised(self, model_ret: DecipherModelReturn, batch: ContinuousIpaBatch) -> Metrics:
    metrics = Metrics()
    # TODO(j_luo) Check the sample scores for hyps that are dummies (i.e., the length of the
    # segment is too small to get beam_size hyps).
    is_unique = model_ret.packed_words.is_unique
    modified_logits = model_ret.probs.sample_log_probs * g.concentration + (~is_unique).float() * (-999.9)
    sample_scores = model_ret.scores.phi_score
    ptb_sample_scores = model_ret.ptb_scores.phi_score
    duplicates = model_ret.duplicates
    with NoName(ptb_sample_scores):
        ptb_sample_scores[duplicates] = -999.9
    bs = sample_scores.size('batch')
    ptb_sample_scores = ptb_sample_scores.unflatten('batch',
                                                    [('batch', bs), ('contrast', g.n_times * 2)])
    sample_scores = sample_scores.align_as(ptb_sample_scores)
    all_scores = torch.cat([sample_scores, ptb_sample_scores], dim='contrast')
    all_probs = all_scores.log_softmax(dim='contrast').exp()
    sample_probs = all_probs.align_to(..., 'contrast')[..., 0]
    utility = _compute_utility(modified_logits, sample_probs)
    total_loss = Metric('total_loss', -utility, batch.batch_size)
    metrics += total_loss
    return metrics

def retrieve(tensor, last_name: str = 'hidden') -> torch.Tensor:
    # `batch_i` and `beam_i` are captured from the enclosing scope.
    with NoName(tensor, batch_i, beam_i):
        ret = tensor[batch_i, beam_i]
    new_names = ('batch', 'beam')
    if last_name:
        new_names += (last_name, )
    return ret.refine_names(*new_names)

def forward(self,
            dense_feat_matrices: Dict[Category, FT],
            padding: Optional[BT] = None,
            masked_positions: Optional[LT] = None) -> FT:
    if padding is not None:
        padding = padding.align_to('batch', 'length')

    embs = list()
    for cat in Category:
        if cat.name in self.embed_layer and cat in dense_feat_matrices:
            sfm = dense_feat_matrices[cat]
            emb_param = self.embed_layer[cat.name]
            sfm = sfm.align_to('batch', 'length', ...)
            emb = sfm @ emb_param
            if padding is not None:
                emb.rename(None)[padding.rename(None)] = 0.0
            embs.append(emb)
    feat_emb = torch.cat(embs, dim=-1).refine_names('batch', 'length', self.char_emb_name)

    if masked_positions is not None:
        # NOTE: `masked_positions` requires `padding` to be provided as well.
        batch_i = get_range(padding.size('batch'), 1, 0)
        feat_emb = feat_emb.align_to('batch', 'char_emb', 'length')
        with NoName(feat_emb, masked_positions):
            feat_emb[batch_i, :, masked_positions] = 0.0
    return feat_emb

def forward(self,
            sot_id: int,
            src_emb: FT,
            src_outputs: FT,
            mask_src: BT,
            max_length: Optional[int] = None,
            target: Optional[LT] = None,
            lang_emb: Optional[FT] = None) -> Tuple[FT, FT]:
    # Prepare inputs.
    max_length = self._get_max_length(max_length, target)
    batch_size = mask_src.size('batch')
    input_ = self._prepare_first_input(sot_id, batch_size, mask_src.device)
    prev_att = get_zeros(batch_size, g.hidden_size) if g.input_feeding else None
    state = LstmStatesByLayers.zero_state(self.cell.num_layers,
                                          batch_size,
                                          self.attn.input_tgt_size,
                                          bidirectional=False)

    # Main loop.
    log_probs = list()
    almt_distrs = list()
    with ScopedCache('Wh_s'):
        for l in range(max_length):
            state, log_prob, almt_distr, prev_att = self._forward_step(
                input_, src_emb, state, src_outputs, mask_src,
                lang_emb=lang_emb, prev_att=prev_att)
            if target is None:
                input_ = log_prob.max(dim=-1)[1].rename('batch')
            else:
                input_ = target[l]

            log_probs.append(log_prob)
            almt_distrs.append(almt_distr)

    # Prepare outputs.
    with NoName(*log_probs), NoName(*almt_distrs):
        log_probs = torch.stack(log_probs).rename('pos', 'batch', 'unit')
        almt_distrs = torch.stack(almt_distrs).rename('tgt_pos', 'batch', 'src_pos')
    return log_probs, almt_distrs

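# A toy sketch of the decoding switch in the loop above: with a gold `target` the next input
# is the gold token (teacher forcing); without one it is the greedy argmax of the model's
# own prediction. The random log-probs here are a stand-in for `_forward_step`.
def _decoding_switch_sketch(teacher_forcing: bool = True):
    import torch
    vocab, bs, max_len = 7, 2, 4
    target = torch.randint(vocab, (max_len, bs)) if teacher_forcing else None
    input_ = torch.zeros(bs, dtype=torch.long)  # SOT ids
    for l in range(max_len):
        log_prob = torch.randn(bs, vocab).log_softmax(dim=-1)  # fake model output
        input_ = log_prob.max(dim=-1)[1] if target is None else target[l]
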
def _stack_beam(lst: List[torch.Tensor], last_name=None):
    new_names = ('batch', 'beam', 'pos')
    if last_name:
        new_names += (last_name, )
    with NoName(*lst):
        # NOTE(j_luo) Set dim=2 instead of -1 since some tensors might have an extra dimension.
        ret = torch.stack(lst, dim=2).refine_names(*new_names)
    return ret

def forward(self, ku_id_seqs: LT, lu_repr: FT) -> Tuple[FT, FT]:
    """Return a lu x ku representation and a bs x l x ku representation."""
    ku_char_weight = self.unit_aligner.weight
    ku_char_repr = ku_char_weight @ lu_repr
    ku_char_repr = ku_char_repr.refine_names('ku_char_emb', 'char_emb')
    with NoName(ku_char_repr, ku_id_seqs):
        _ku_repr = ku_char_repr[ku_id_seqs].rename('batch', 'length', 'char_emb')
    _ku_repr = _ku_repr.align_to('batch', 'char_emb', ...)
    with NoName(_ku_repr):
        ku_ctx_repr = self.conv(_ku_repr).rename('batch', 'char_emb', 'length')
    ku_ctx_repr = ku_ctx_repr.align_to(..., 'char_emb')
    ku_ctx_repr = self.dropout(ku_ctx_repr)
    return ku_char_repr, ku_ctx_repr

def evaluate(self, states, steps: Optional[Union[int, LT]] = None) -> List[float]:
    """Expand and evaluate the leaf nodes."""
    values = [None] * len(states)
    outstanding_idx = list()
    outstanding_states = list()
    # Deal with end states first.
    for i, state in enumerate(states):
        if state.stopped or state.done:
            # NOTE(j_luo) This value is used for backup. If we have already reached an end
            # state, the final reward is accounted for either by the step reward or by the
            # value network, so we set the value to 0.0 here.
            values[i] = 0.0
        else:
            outstanding_idx.append(i)
            outstanding_states.append(state)

    # Collect states that need evaluation.
    if outstanding_states:
        almts1 = almts2 = None
        if g.use_alignment:
            id_seqs, almts1, almts2 = parallel_stack_ids(outstanding_states, g.num_workers,
                                                         True, self.env.max_end_length)
            almts1 = get_tensor(almts1).rename('batch', 'word', 'pos')
            almts2 = get_tensor(almts2).rename('batch', 'word', 'pos')
        else:
            id_seqs = parallel_stack_ids(outstanding_states, g.num_workers,
                                         False, self.env.max_end_length)
        id_seqs = get_tensor(id_seqs).rename('batch', 'word', 'pos')
        if steps is not None and not isinstance(steps, int):
            steps = steps[outstanding_idx]

        # TODO(j_luo) Scoped might be wrong here.
        # with ScopedCache('state_repr'):
        # NOTE(j_luo) Don't forget to call exp().
        priors = self.agent.get_policy(id_seqs, almts=(almts1, almts2)).exp()
        with NoName(priors):
            meta_priors = priors[:, [0, 2, 3, 4, 5, 6]].cpu().numpy()
            special_priors = priors[:, 1].cpu().numpy()
        if g.use_value_guidance:
            agent_values = self.agent.get_values(id_seqs, steps=steps).cpu().numpy()
        else:
            agent_values = np.zeros([len(id_seqs)], dtype='float32')

        for i, state, mp, sp, v in zip(outstanding_idx, outstanding_states,
                                       meta_priors, special_priors, agent_values):
            # NOTE(j_luo) Values should be returned even if states are duplicates or
            # have been visited.
            values[i] = v
            # NOTE(j_luo) Skip duplicate states (due to exploration collapse) or
            # visited states (due to rollout truncation).
            if not state.is_leaf():
                continue
            self.env.evaluate(state, mp, sp)
    return values

def get_scores(self, batch: OnePairBatch, tgt_vocab_seqs: PaddedUnitSeqs,
               chunk_size: int = 100) -> FT:
    """Given a batch and a list of target tokens (provided as id sequences),
    return the scores produced by the model."""
    src_emb, (output, state) = self.encoder(batch.src_seqs.ids, batch.src_seqs.lengths)
    src_emb = src_emb.refine_names('pos', 'batch', 'src_emb')
    output = output.refine_names('pos', 'batch', 'output')
    batch_size = src_emb.size('batch')
    lang_emb = self._prepare_lang_emb(batch)

    def create_chunk(size, base, old_chunk, interleave: bool = True):
        if not interleave:
            return base.repeat(1, batch_size)
        if old_chunk is not None and old_chunk.size('batch') == batch_size * size:
            return old_chunk
        new_chunk = torch.repeat_interleave(base, size, dim='batch')
        return new_chunk

    chunk_src_emb = None
    chunk_output = None
    chunk_src_paddings = None
    scores = list()
    for split in pbar(tgt_vocab_seqs.split(chunk_size), desc='Get scores: chunk'):
        split: PaddedUnitSeqs
        bs_split = len(split)
        chunk_src_emb = create_chunk(bs_split, src_emb, chunk_src_emb)
        chunk_output = create_chunk(bs_split, output, chunk_output)
        chunk_src_paddings = create_chunk(bs_split, batch.src_seqs.paddings, chunk_src_paddings)
        chunk_target = create_chunk(None, split.ids, None, interleave=False)
        chunk_tgt_paddings = create_chunk(None, split.paddings, None, interleave=False)
        chunk_log_probs, _ = self.decoder(SOT_ID,
                                          chunk_src_emb,
                                          chunk_output,
                                          chunk_src_paddings,
                                          target=chunk_target,
                                          lang_emb=lang_emb)
        chunk_scores = chunk_log_probs.gather('unit', chunk_target)
        chunk_scores = (chunk_scores * chunk_tgt_paddings).sum('pos')
        with NoName(chunk_scores):
            scores.append(chunk_scores.view(batch_size, bs_split).refine_names('batch', 'tgt_vocab'))
    scores = torch.cat(scores, dim='tgt_vocab')
    return scores

def trace_back(self, *attr_names: str) -> Dict[str, torch.Tensor]:
    """Trace back some attributes by going backwards through the beam search procedure."""
    beam_i = get_named_range(self.beam_size, 'beam').expand_as(self.beam_ids)
    batch_i = get_named_range(self.batch_size, 'batch').expand_as(beam_i)
    beam = self
    ret = defaultdict(list)
    while beam.last_beam is not None:
        with NoName(beam.beam_ids, beam_i, batch_i):
            for attr_name in attr_names:
                attr = getattr(beam, attr_name)
                with NoName(attr):
                    ret[attr_name].append(attr[batch_i, beam_i])
            beam_i = beam.beam_ids[batch_i, beam_i]
        beam = beam.last_beam

    for attr_name in attr_names:
        # NOTE(j_luo) Reverse the list since we are going backwards.
        last_name = 'src_pos' if attr_name == 'almt' else None
        ret[attr_name] = _stack_beam(ret[attr_name][::-1], last_name=last_name)
    return ret

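# Backpointer-following sketch for `trace_back`, with plain tensors and a single batch
# element: each step's `beam_ids` points at the parent hypothesis in the previous beam, so
# walking them backwards and reversing recovers each hypothesis's full token path.
def _trace_back_sketch():
    import torch
    bs, k = 1, 3
    tokens = [torch.tensor([[5, 6, 7]]), torch.tensor([[8, 9, 10]])]  # per step: batch x beam
    beam_ids = torch.tensor([[0, 0, 2]])  # step-1 hypotheses' parents in step 0
    batch_i = torch.arange(bs).unsqueeze(-1).expand(bs, k)
    beam_i = torch.arange(k).unsqueeze(0).expand(bs, k)
    ret = [tokens[1][batch_i, beam_i]]
    beam_i = beam_ids[batch_i, beam_i]  # hop to each hypothesis's parent
    ret.append(tokens[0][batch_i, beam_i])
    paths = torch.stack(ret[::-1], dim=2)  # batch x beam x pos
    assert paths[0].tolist() == [[5, 8], [5, 9], [7, 10]]
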
def forward(self, inp: FT, sparse: bool = False, indices: Optional[NDA] = None) -> FT:
    is_2d = inp.ndim == 2
    if g.use_conditional and not is_2d:
        raise RuntimeError('Not sure why you end up here.')
    assert is_2d, 'Cannot deal with dense action space.'

    potentials = self.potential_block(inp)
    with NoName(potentials):
        potentials = potentials.view(-1, 7, len(self.env.abc))
    return potentials.rename('batch', 'phase', 'action')

def forward(self, input_: LT, lengths: LT) -> Tuple[FT, LstmOutputTuple]:
    emb = self.embedding(input_)
    with NoName(emb, lengths):
        packed_emb = pack_padded_sequence(emb, lengths, enforce_sorted=False)
        output, state = self.lstm(packed_emb)
        output = pad_packed_sequence(output)[0]
        # Dropout after the last output, different from the behavior of nn.LSTM.
        output = self.drop(output)
    return emb, (output, LstmStateTuple(state, bidirectional=self.lstm.bidirectional))

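# Usage sketch for the encoder above with a variable-length batch (hypothetical sizes):
# packing lets the LSTM skip pad positions, and enforce_sorted=False removes the need to
# pre-sort the batch by length.
def _packed_lstm_sketch():
    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
    lstm = torch.nn.LSTM(4, 5)
    emb = torch.randn(6, 2, 4)  # seq_length x batch_size x input_size
    lengths = torch.tensor([6, 3])  # true lengths; the second sequence is padded
    packed = pack_padded_sequence(emb, lengths, enforce_sorted=False)
    output, state = lstm(packed)
    output = pad_packed_sequence(output)[0]  # back to seq_length x batch_size x hidden_size
    assert output.shape == (6, 2, 5)
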
def forward(self, curr_ids: LT, end_ids: LT, almts: Optional[Tuple[LT, LT]] = None):
    if g.repr_mode != 'state' and almts is None:
        raise RuntimeError('Must pass `almts` if `repr_mode` is not "state".')

    if g.repr_mode != 'state':
        curr_almts, end_almts = almts
        assert curr_almts.shape == curr_ids.shape
        assert end_almts.shape[1:] == end_ids.shape
        # NOTE(j_luo) +1 for 0-indexing, +1 for storing fake scattered values.
        max_len = max(curr_almts.max(), end_almts.max()) + 2
        new_shape = curr_almts.shape[:-1] + (max_len, )
        aligned_curr_ids = get_zeros(*new_shape).long().fill_(PAD_ID)
        aligned_end_ids = get_zeros(*new_shape).long().fill_(PAD_ID)
        with NoName(curr_almts, curr_ids, end_almts, end_ids):
            curr_mask = curr_almts == -1
            curr_almts[curr_mask] = max_len - 1

            end_mask = end_almts == -1
            end_almts[end_mask] = max_len - 1

            aligned_curr_ids.scatter_(-1, curr_almts, curr_ids)
            aligned_end_ids.scatter_(-1, end_almts, end_ids.expand_as(end_almts))

        aligned_curr_ids = aligned_curr_ids.narrow(-1, 0, max_len - 1).rename('batch', 'word', 'pos')
        aligned_end_ids = aligned_end_ids.narrow(-1, 0, max_len - 1).rename('batch', 'word', 'pos')
        curr_char_emb = self._get_char_embedding(aligned_curr_ids)
        end_char_emb = self._get_char_embedding(aligned_end_ids)
        if g.repr_mode == 'char':
            state_repr = self._get_word_embedding_from_chars(curr_char_emb - end_char_emb).mean(dim='word')
        else:
            curr_word_emb = self._get_word_embedding_from_chars(curr_char_emb)
            end_word_emb = self._get_word_embedding_from_chars(end_char_emb)
            state_repr = (curr_word_emb - end_word_emb).mean(dim='word')
    else:
        word_repr = self._get_word_embedding(curr_ids)
        end_word_repr = self._get_word_embedding(end_ids)
        state_repr = (word_repr - end_word_repr).mean(dim='word')
    return state_repr

def split(self, size: int) -> List[PaddedUnitSeqs]:
    with NoName(self.ids, self.paddings):
        ids_lst = self.ids.split(size, dim=-1)
        paddings_lst = self.paddings.split(size, dim=-1)
    start = 0
    ret = list()
    for ids, paddings in zip(ids_lst, paddings_lst):
        length = ids.size(1)
        units = self.units[start: start + length]
        forms = self.forms[start: start + length]
        split = PaddedUnitSeqs(self.lang, forms, units, ids, paddings,
                               lang_id=self.lang_id)
        ret.append(split)
        start += length
    assert start == self.ids.size('batch')
    return ret

def search_by_probs(self, lengths: LT, label_log_probs: FT) -> Tuple[LT, FT]:
    max_length = lengths.max().item()
    samples = get_tensor(torch.LongTensor(list(product([B, I, O], repeat=max_length))))
    samples.rename_('sample', 'length')
    bs = label_log_probs.size('batch')
    samples = samples.align_to('batch', 'sample', 'length').expand(bs, -1, -1)
    sample_log_probs = label_log_probs.gather('label', samples)
    with NoName(lengths):
        length_mask = get_length_mask(lengths, max_length).rename('batch', 'length')
    # NOTE: `align_as` (not `align_to`, which expects names) broadcasts the mask over samples.
    length_mask = length_mask.align_as(sample_log_probs)
    sample_log_probs = (sample_log_probs * length_mask.float()).sum(dim='length')
    return samples, sample_log_probs

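# Worked sketch of the exhaustive search above: `product([B, I, O], repeat=max_length)`
# materializes all 3^max_length taggings, so this is only feasible for short segments.
# B, I, O are hypothetical label ids here.
def _bio_enumeration_sketch():
    import torch
    from itertools import product
    B, I, O = 0, 1, 2
    samples = torch.LongTensor(list(product([B, I, O], repeat=2)))  # sample x length
    assert samples.shape == (9, 2)  # 3^2 candidate taggings
    label_log_probs = torch.randn(2, 3).log_softmax(dim=-1)  # length x label
    # Score of a tagging = sum of per-position label log-probs.
    sample_log_probs = label_log_probs.gather(1, samples.t()).t().sum(dim=-1)
    assert sample_log_probs.shape == (9, )
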
def forward(self, input_: FT, state: LstmStatesByLayers,
            state_direction: Optional[str] = None) -> LstmOutputsByLayers:
    assert state.num_layers == self.num_layers

    new_states = list()
    for i in range(self.num_layers):
        h, c = state.get_layer(i, state_direction)
        with NoName(input_, h, c):
            new_h, new_c = self.cells[i](input_, (h, c))
            new_h.rename_(*h.names)
            new_c.rename_(*c.names)
        new_states.append((new_h, new_c))
        input_ = new_h.refine_names('batch', ...)
        # Note that the last layer also uses dropout, which is different from nn.LSTM.
        input_ = self.drop(input_)
    return input_, LstmStatesByLayers(new_states)

def forward(self, feat_matrix: LT, pos_to_predict: LT, source_padding: BT) -> FT:
    bs = source_padding.size('batch')
    batch_i = get_range(bs, 1, 0)
    feat_emb = self.feat_embedding(feat_matrix, source_padding, masked_positions=pos_to_predict)
    feat_emb = feat_emb.align_to('batch', 'char_emb', 'length')
    output = self.conv_layers(feat_emb.rename(None))
    output = output.refine_names('batch', 'char_conv_repr', 'length')  # size: bs x D x l
    output = self.linear(output.align_to(..., 'char_conv_repr'))  # size: bs x l x n_hid
    output = output.refine_names('batch', 'length', 'hidden_repr')
    output = nn.functional.leaky_relu(output, negative_slope=0.1)
    # NOTE(j_luo) This is actually quite wasteful because we are discarding all the
    # irrelevant information, which is computed anyway. This is equivalent to training
    # on ngrams.
    with NoName(output, pos_to_predict):
        h = output[batch_i, pos_to_predict]
    h = h.refine_names('batch', 'hidden_repr')  # size: bs x n_hid
    return h

def forward(self, feat_matrix: LT, padding: Optional[BT] = None,
            masked_positions: Optional[LT] = None) -> FT:
    feat_matrix = adv_index(feat_matrix, 'feat_group', self.c_idx)
    # Convert old-style ipa features to new-style ones.
    if g.new_style:
        new_feat_matrix = list()
        for c_idx, one_feat_group in zip(self.c_idx.unbind(dim=self.group_name),
                                         feat_matrix.unbind(dim=self.group_name)):
            one_feat_group = one_feat_group.rename(None)
            new_enum = get_new_style_enum(c_idx.item())
            l = new_enum.num_groups()
            if l > 1:
                new_feat_matrix.append(self.complex_conversions[one_feat_group][..., :l])
            else:
                new_feat_matrix.append(self.simple_conversions[one_feat_group].unsqueeze(dim=-1))
        new_feat_matrix = torch.cat(new_feat_matrix, dim=-1).refine_names(*feat_matrix.names)
        feat_matrix = new_feat_matrix

    feat_emb = embed(self.embed_layer, feat_matrix, self.feat_emb_name)
    feat_emb = feat_emb.flatten([self.group_name, self.feat_emb_name], self.char_emb_name)
    feat_emb = feat_emb.align_to('batch', 'length', self.char_emb_name)
    if padding is not None:
        padding = padding.align_to('batch', 'length')
        feat_emb.rename(None)[padding.rename(None)] = 0.0

    if masked_positions is not None:
        # NOTE: `masked_positions` requires `padding` to be provided as well.
        batch_i = get_range(padding.size('batch'), 1, 0)
        feat_emb = feat_emb.align_to('batch', 'char_emb', 'length')
        with NoName(feat_emb, masked_positions):
            feat_emb[batch_i, :, masked_positions] = 0.0
    return feat_emb

def search(self, lengths: LT, label_log_probs: FT,
           gold_tag_seqs: Optional[LT] = None) -> Tuple[LT, FT]:
    samples, sample_log_probs = self.search_by_probs(lengths, label_log_probs)
    if gold_tag_seqs is not None:
        gold_tag_seqs = gold_tag_seqs.align_as(samples)

        max_length = lengths.max().item()
        with NoName(lengths):
            length_mask = get_length_mask(lengths, max_length).rename('batch', 'length')
        gold_log_probs = label_log_probs.gather('label', gold_tag_seqs)
        gold_log_probs = (gold_log_probs * length_mask.align_as(gold_log_probs)).sum('length')

        samples = torch.cat([gold_tag_seqs, samples], dim='sample')
        sample_log_probs = torch.cat([gold_log_probs, sample_log_probs], dim='sample')
    return samples, sample_log_probs

def forward(self, positions: LT):
    with NoName(self.embeddings, positions):
        ret = self.embeddings[positions]
    new_names = positions.names + ('char_emb', )
    return ret.refine_names(*new_names)

def forward(self, batch: ExtractBatch) -> ExtractModelReturn:
    """
    The generating story is:

        v
        |
        w
        |
        x -- ww -- theta

    Pr(x) = sum_w Pr(w) Pr(ww)
          = sum_w Pr(w) theta^|ww|
          = sum_{w, v} Pr(w | v) Pr(v) theta^|ww|

    Terminologies:
        matched_: the prefix after selecting v
        score: after multiplication with |w|
        best_: the prefix after selecting w
    """
    # Prepare representations.
    alignment = None
    char_log_probs = None  # Only set for dense text input; passed through to `_extract_one_span`.
    if g.dense_input:
        # IDEA(j_luo) NoName shouldn't use reveal_name. Just keep the name in the context manager.
        with NoName(*self.unit_dense_feat_matrix.values()):
            unit_repr = torch.cat([self.unit_dense_feat_matrix[cat]
                                   for cat in self.effective_categories], dim=-1)
        unit_repr = unit_repr.rename('batch', 'length', 'char_emb').squeeze(dim='length')

        if g.input_format == 'text':
            ku_char_repr, word_repr = self.g2p(batch.unit_id_seqs, unit_repr)
            char_log_probs = (ku_char_repr @ unit_repr.t()).log_softmax(dim=-1)
            alignment = char_log_probs.exp()
        else:
            dfm = batch.dense_feat_matrix
            with Rename(*self.unit_dense_feat_matrix.values(), unit='batch'):
                adapted_dfm = self.adapter(dfm)
            with NoName(*adapted_dfm.values()):
                word_repr = torch.cat([adapted_dfm[cat]
                                       for cat in self.effective_categories], dim=-1)
            word_repr.rename_('batch', 'length', 'char_emb')
    else:
        with Rename(self.unit_feat_matrix, unit='batch'):
            word_repr = self.embedding(batch.feat_matrix, batch.source_padding)
            unit_repr = self.embedding(self.unit_feat_matrix)
        unit_repr = unit_repr.squeeze('length')
    unit_repr.rename_(batch='unit')

    # Main body: extract one span.
    extracted = Extracted(batch.batch_size)
    new_extracted = self._extract_one_span(batch, extracted, word_repr, unit_repr, char_log_probs)
    matches = new_extracted.matches
    len_e = matches.ll.size('len_e')
    vs = len(self.vocab)

    # Get the best score and span.
    # NOTE(j_luo) Some segments don't have any viable spans.
    flat_ll = matches.ll.flatten(['len_s', 'len_e', 'vocab'], 'cand')
    flat_viable = new_extracted.viable.expand_as(matches.ll).flatten(['len_s', 'len_e', 'vocab'], 'cand')
    flat_viable_ll = (~flat_viable) * (-9999.9) + flat_ll
    # Add probs for unextracted characters.
    unextracted = batch.lengths.align_as(new_extracted.len_candidates) - new_extracted.len_candidates
    unextracted = unextracted.expand_as(matches.ll)
    flat_unextracted = unextracted.flatten(['len_s', 'len_e', 'vocab'], 'cand')
    flat_unextracted_ll = flat_unextracted * math.log(g.unextracted_prob)
    flat_total_ll = flat_viable_ll + flat_unextracted_ll
    # Get the top candidates based on total scores.
    best_matched_ll, best_span_ind = flat_total_ll.max(dim='cand')
    start = best_span_ind // (len_e * vs)
    # NOTE(j_luo) Don't forget that the length is off by g.min_word_length - 1.
    end = best_span_ind % (len_e * vs) // vs + start + g.min_word_length - 1
    best_matched_vocab = best_span_ind % vs

    if self.training:
        any_viable = new_extracted.viable.any('len_s').any('len_e')
        best_matched_ll = flat_total_ll.logsumexp(dim='cand')
        best_matched_ll = best_matched_ll * any_viable

    ret = ExtractModelReturn(start, end, best_matched_ll, best_matched_vocab,
                             new_extracted, alignment)
    return ret

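# A toy sketch of the training/inference split at the end of `forward`: inference takes the
# single best candidate span (max), while training marginalizes over all candidates
# (logsumexp over 'cand'), matching the sum over w in the generating story. Non-viable
# candidates are effectively removed by the -9999.9 offset.
def _span_marginalization_sketch():
    import torch
    flat_total_ll = torch.tensor([[-1.0, -2.0, -9999.9]])  # batch x cand; last cand not viable
    best_ll, best_ind = flat_total_ll.max(dim=-1)  # inference: best span
    train_ll = flat_total_ll.logsumexp(dim=-1)  # training: marginal likelihood
    assert best_ind.item() == 0 and train_ll.item() > best_ll.item()
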
def _extract_one_span(self, batch: ExtractBatch, extracted: Extracted,
                      word_repr: FT, unit_repr: FT, char_log_probs: FT) -> Extracted:
    # Propose all span start/end positions.
    start_candidates = get_named_range(batch.max_length, 'len_s').align_to('batch', 'len_s', 'len_e')
    # Range from `min_word_length` to `max_word_length`.
    len_candidates = get_named_range(g.max_word_length + 1 - g.min_word_length, 'len_e') + g.min_word_length
    len_candidates = len_candidates.align_to('batch', 'len_s', 'len_e')
    # This is inclusive.
    end_candidates = start_candidates + len_candidates - 1

    # Only keep the viable/valid spans around.
    viable = (end_candidates < batch.lengths.align_as(end_candidates))
    start_candidates = start_candidates.expand_as(viable)
    len_candidates = len_candidates.expand_as(viable)
    # NOTE(j_luo) Use `viable` to get the lengths. `len_candidates` has dummy axes.
    # IDEA(j_luo) Any better way of handling this? Perhaps persistent names?
    len_s = viable.size('len_s')
    len_e = viable.size('len_e')
    bi = get_named_range(batch.batch_size, 'batch').expand_as(viable)
    with NoName(start_candidates, end_candidates, len_candidates, bi, viable):
        viable_starts = start_candidates[viable].rename('viable')
        viable_lens = len_candidates[viable].rename('viable')
        viable_bi = bi[viable].rename('viable')

    # Get the word positions to get the corresponding representations.
    viable_starts = viable_starts.align_to('viable', 'len_w')
    word_pos_offsets = get_named_range(g.max_word_length, 'len_w').align_as(viable_starts)
    word_pos = viable_starts + word_pos_offsets
    word_pos = word_pos.clamp(max=batch.max_length - 1)

    # Get the corresponding representations.
    nh = NameHelper()
    viable_bi = viable_bi.expand_as(word_pos)
    word_pos = nh.flatten(word_pos, ['viable', 'len_w'], 'viable_X_len_w')
    viable_bi = nh.flatten(viable_bi, ['viable', 'len_w'], 'viable_X_len_w')
    word_repr = word_repr.align_to('batch', 'length', 'char_emb')
    if g.input_format == 'text':
        with NoName(word_repr, viable_bi, word_pos, batch.unit_id_seqs):
            extracted_word_repr = word_repr[viable_bi, word_pos].rename('viable_X_len_w', 'char_emb')
            extracted_unit_ids = batch.unit_id_seqs[viable_bi, word_pos].rename('viable_X_len_w')
    else:
        with NoName(word_repr, viable_bi, word_pos):
            extracted_word_repr = word_repr[viable_bi, word_pos].rename('viable_X_len_w', 'char_emb')
        extracted_unit_ids = None
    extracted_word_repr = nh.unflatten(extracted_word_repr, 'viable_X_len_w', ['viable', 'len_w'])

    # Main body: Run DP to find the best matches.
    matches = self._get_matches(extracted_word_repr, unit_repr, viable_lens,
                                extracted_unit_ids, char_log_probs)
    # Revert to the old shape (so that invalid spans are included).
    bi = get_named_range(batch.batch_size, 'batch').expand_as(viable)
    lsi = get_named_range(len_s, 'len_s').expand_as(viable)
    lei = get_named_range(len_e, 'len_e').expand_as(viable)
    vs = matches.ll.size('vocab')
    # IDEA(j_luo) NoName shouldn't make size() calls unavailable. Otherwise size() calls
    # have to be moved outside the context. Also, the names should be preserved.
    with NoName(bi, lsi, lei, viable, matches.ll):
        v_bi = bi[viable]
        v_lsi = lsi[viable]
        v_lei = lei[viable]
        all_ll = get_zeros(batch.batch_size, len_s, len_e, vs)
        all_ll = all_ll.float().fill_(-9999.9)
        all_ll[v_bi, v_lsi, v_lei] = matches.ll
        matches.ll = all_ll.rename('batch', 'len_s', 'len_e', 'vocab')

    new_extracted = Extracted(batch.batch_size, matches, viable, len_candidates)
    return new_extracted

def _get_matches(self, extracted_word_repr: FT, unit_repr: FT, viable_lens: LT,
                 extracted_unit_ids: LT, char_log_probs: FT) -> Matches:
    ns = extracted_word_repr.size('viable')
    len_w = extracted_word_repr.size('len_w')
    nt = len(self.vocab_feat_matrix)
    msl = extracted_word_repr.size('len_w')
    mtl = self.vocab_feat_matrix.size('length')

    # Compute substitution costs all at once: for each viable span, compare it against all units.
    ctx_logits = extracted_word_repr @ unit_repr.t()
    ctx_log_probs = ctx_logits.log_softmax(dim='unit').flatten(['viable', 'len_w'], 'viable_X_len_w')
    with NoName(char_log_probs, extracted_unit_ids):
        global_log_probs = char_log_probs[extracted_unit_ids].rename('viable_X_len_w', 'unit')
    weighted_log_probs = g.context_weight * ctx_log_probs + (1.0 - g.context_weight) * global_log_probs
    costs = -weighted_log_probs

    # Name: viable x len_w x unit
    costs = costs.unflatten('viable_X_len_w', [('viable', ns), ('len_w', len_w)])

    # NOTE(j_luo) Use a dictionary to save every state.
    fs = dict()
    for i in range(msl + 1):
        fs[(i, 0)] = get_zeros(ns, nt).fill_(i * self.ins_del_cost)
    for j in range(mtl + 1):
        fs[(0, j)] = get_zeros(ns, nt).fill_(j * self.ins_del_cost)

    # ------------------------ Main body: DP ----------------------- #

    # Transition.
    with NoName(self.indexed_segments, costs):
        for ls in range(1, msl + 1):
            min_lt = max(ls - 2, 1)
            max_lt = min(ls + 2, mtl + 1)
            for lt in range(min_lt, max_lt):
                transitions = list()
                if (ls - 1, lt) in fs:
                    transitions.append(fs[(ls - 1, lt)] + self.ins_del_cost)
                if (ls, lt - 1) in fs:
                    transitions.append(fs[(ls, lt - 1)] + self.ins_del_cost)
                if (ls - 1, lt - 1) in fs:
                    vocab_inds = self.indexed_segments[:, lt - 1]
                    sub_cost = costs[:, ls - 1, vocab_inds]
                    transitions.append(fs[(ls - 1, lt - 1)] + sub_cost)
                if transitions:
                    all_s = torch.stack(transitions, dim=-1)
                    new_s, _ = all_s.min(dim=-1)
                    fs[(ls, lt)] = new_s

    f_lst = list()
    for i in range(msl + 1):
        for j in range(mtl + 1):
            if (i, j) not in fs:
                fs[(i, j)] = get_zeros(ns, nt).fill_(9999.9)
            f_lst.append(fs[(i, j)])
    f = torch.stack(f_lst, dim=0).view(msl + 1, mtl + 1, -1, len(self.vocab))
    f.rename_('len_w_src', 'len_w_tgt', 'viable', 'vocab')

    # Get the values wanted.
    with NoName(f, viable_lens, self.vocab_length):
        idx_src = viable_lens.unsqueeze(dim=-1)
        idx_tgt = self.vocab_length
        viable_i = get_range(ns, 2, 0)
        vocab_i = get_range(len(self.vocab_length), 2, 1)
        nll = f[idx_src, idx_tgt, viable_i, vocab_i]
        nll.rename_('viable', 'vocab')

    # Get the best spans.
    matches = Matches(-nll, f)
    return matches

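# The `fs` table above is a batched variant of the classic edit-distance DP, restricted to a
# narrow diagonal band and with a learned substitution cost. A scalar reference version with
# unit insertion/deletion cost and 0/1 substitution cost (a sketch, not the batched code):
def _edit_distance_sketch(src: str, tgt: str, ins_del: float = 1.0) -> float:
    msl, mtl = len(src), len(tgt)
    fs = {(i, 0): i * ins_del for i in range(msl + 1)}
    fs.update({(0, j): j * ins_del for j in range(mtl + 1)})
    for i in range(1, msl + 1):
        for j in range(1, mtl + 1):
            sub_cost = float(src[i - 1] != tgt[j - 1])
            fs[(i, j)] = min(fs[(i - 1, j)] + ins_del,       # deletion
                             fs[(i, j - 1)] + ins_del,       # insertion
                             fs[(i - 1, j - 1)] + sub_cost)  # substitution / match
    return fs[(msl, mtl)]
# _edit_distance_sketch('kitten', 'sitting') == 3.0
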
def forward(self, input_: LT) -> FT:
    with NoName(self.char_embedding, input_):
        return self.char_embedding[input_]

def _get_Wh_s(self, h_s: FT) -> FT:
    sl, bs, ds = h_s.size()
    with NoName(h_s):
        Wh_s = h_s.reshape(sl * bs, -1).mm(self.Wa).view(sl, bs, -1)
    return Wh_s

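# Why this is factored out: `Wa` is applied to the source states once per batch and then
# reused at every decoding step (see the `ScopedCache('Wh_s')` wrapper around the decoder
# loop). A plain-tensor sketch of the reuse pattern:
def _precompute_sketch():
    import torch
    sl, bs, d = 5, 3, 8
    h_s, Wa = torch.randn(sl, bs, d), torch.randn(d, d)
    Wh_s = h_s.reshape(sl * bs, d).mm(Wa).view(sl, bs, d)  # computed once
    for _ in range(4):  # reused at every target position
        h_t = torch.randn(bs, d)
        scores = (Wh_s * h_t).sum(dim=-1)  # sl x bs
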