def gumbel_softmax(logits: FT, temperature: float, num_samples: Optional[int] = None) -> Tuple[FT, FT, LT, FT]:
    """Sample from the Gumbel-Softmax distribution and optionally discretize."""
    logits = logits.align_to('batch', 'length', 'label')
    y = gumbel_softmax_sample(logits, temperature, num_samples)
    y = y.align_to('batch', 'length', 'label', ...)
    max_values, max_inds = y.max(dim='label')
    y_one_hot = (max_values.align_as(y) == y).float()
    # Straight-through estimator: the forward pass sees the discretized one-hot
    # sample, while gradients flow through the soft sample `y`.
    y_one_hot = (y_one_hot - y).detach() + y
    bi = get_named_range(logits.size('batch'), 'batch').align_as(max_inds)
    li = get_named_range(logits.size('length'), 'length').align_as(max_inds)
    if num_samples is None:
        with NoName(max_inds, y_one_hot, bi, li):
            probs = y_one_hot[bi, li, max_inds]
            probs.rename_('batch', 'length')
    else:
        si = get_named_range(max_inds.size('sample'), 'sample').align_as(max_inds)
        with NoName(max_inds, y_one_hot, bi, li, si):
            probs = y_one_hot[bi, li, max_inds, si]
            probs.rename_('batch', 'length', 'sample')
    seq_probs = (1e-8 + probs).log().sum(dim='length').exp()
    return y, y_one_hot, max_inds, seq_probs
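
# A minimal, self-contained sketch of the straight-through Gumbel-Softmax trick
# used above, in plain (unnamed) PyTorch. The inline Gumbel perturbation is a
# stand-in for the project's `gumbel_softmax_sample` helper; shapes are illustrative.
def _demo_straight_through_gumbel():
    import torch

    logits = torch.randn(2, 5, 10, requires_grad=True)  # batch x length x label
    temperature = 1.0
    # Perturb logits with Gumbel(0, 1) noise and take a softmax.
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y = ((logits + gumbel) / temperature).softmax(dim=-1)
    # Discretize on the forward pass; gradients still reach `logits` through `y`.
    y_hard = torch.zeros_like(y).scatter_(-1, y.argmax(dim=-1, keepdim=True), 1.0)
    y_st = (y_hard - y).detach() + y
    y_st.sum().backward()
    assert logits.grad is not None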
def _train_one_step_mle(self, batch: OnePairBatch) -> Metrics:
    """Train for one step using maximum likelihood."""
    log_probs, almt_distrs = self.model(batch)

    metrics = Metrics()
    # Cross-entropy loss.
    ce_loss = get_ce_loss(log_probs, batch, agg='all')
    ce_loss = Metric('ce_loss', ce_loss, len(batch))
    metrics += ce_loss

    # Compute alignment regularization loss if needed.
    if g.almt_reg_hyper > 0:
        sl = almt_distrs.size('src_pos')
        pos = get_named_range(sl, 'src_pos').float()
        # Expected source position attended to at each target position.
        mean_pos = (pos.align_as(almt_distrs) * almt_distrs).sum(dim='src_pos')
        mean_pos = mean_pos.align_to('batch', 'tgt_pos')
        mean_pos = torch.cat([get_zeros(len(batch), 1), mean_pos], dim=-1)
        src_lengths = batch.src_seqs.lengths.float().rename(None)
        reg_weight = src_lengths.unsqueeze(dim=-1) - 1.0 - mean_pos[:, :-1]
        reg_weight.clamp_(0.0, 1.0)
        rel_pos = mean_pos[:, 1:] - mean_pos[:, :-1]  # bs x tl
        rel_pos_diff = rel_pos - 1
        margin = rel_pos_diff != 0
        # Penalize deviations from a monotonic one-step advance in alignment.
        almt_reg = margin.float() * (rel_pos_diff ** 2)  # bs x tl
        almt_reg = (almt_reg * reg_weight).sum()
        almt_reg = Metric('almt_reg', almt_reg, len(batch))
        metrics += almt_reg

        loss = ce_loss.mean + g.almt_reg_hyper * almt_reg.mean
    else:
        loss = ce_loss.mean
    loss = Metric('loss', loss * len(batch), len(batch))
    metrics += loss
    return metrics
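
# A small sketch of the alignment regularizer above, in plain PyTorch with
# hypothetical shapes. It omits the length-based `reg_weight` and the margin
# mask for brevity: compute the expected source position at each target step,
# then penalize steps whose forward movement deviates from exactly 1.
def _demo_almt_reg():
    import torch

    bs, tl, sl = 2, 4, 6
    almt = torch.rand(bs, tl, sl).softmax(dim=-1)   # attention distributions
    pos = torch.arange(sl).float()                  # source positions 0..sl-1
    mean_pos = (almt * pos).sum(dim=-1)             # bs x tl, expected positions
    mean_pos = torch.cat([torch.zeros(bs, 1), mean_pos], dim=-1)
    rel_pos_diff = (mean_pos[:, 1:] - mean_pos[:, :-1]) - 1.0
    almt_reg = (rel_pos_diff ** 2).sum()            # squared deviation from a one-step advance
    return almt_reg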
def trace_back(self, *attr_names: str) -> Dict[str, torch.Tensor]:
    """Trace back some attribute by going backwards through the beam search procedure."""
    beam_i = get_named_range(self.beam_size, 'beam').expand_as(self.beam_ids)
    batch_i = get_named_range(self.batch_size, 'batch').expand_as(beam_i)
    beam = self
    ret = defaultdict(list)
    while beam.last_beam is not None:
        with NoName(beam.beam_ids, beam_i, batch_i):
            for attr_name in attr_names:
                attr = getattr(beam, attr_name)
                with NoName(attr):
                    ret[attr_name].append(attr[batch_i, beam_i])
            # Follow the backpointers to the parent hypotheses one step back.
            beam_i = beam.beam_ids[batch_i, beam_i]
        beam = beam.last_beam

    for attr_name in attr_names:
        # NOTE(j_luo) Reverse the list since we are going backwards.
        last_name = 'src_pos' if attr_name == 'almt' else None
        ret[attr_name] = _stack_beam(ret[attr_name][::-1], last_name=last_name)
    return ret
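
# A toy illustration (plain PyTorch) of the backpointer walk above:
# `beam_ids[t][b, k]` names the parent of hypothesis k at step t, so
# re-indexing `beam_i` at each step walks the search history backwards.
def _demo_trace_back():
    import torch

    batch_size, beam_size, num_steps = 1, 3, 4
    tokens = [torch.randint(0, 10, (batch_size, beam_size)) for _ in range(num_steps)]
    beam_ids = [torch.randint(0, beam_size, (batch_size, beam_size)) for _ in range(num_steps)]

    batch_i = torch.arange(batch_size).unsqueeze(-1).expand(batch_size, beam_size)
    beam_i = torch.arange(beam_size).unsqueeze(0).expand(batch_size, beam_size)
    collected = []
    for t in reversed(range(num_steps)):
        collected.append(tokens[t][batch_i, beam_i])
        beam_i = beam_ids[t][batch_i, beam_i]  # follow backpointers one step
    return torch.stack(collected[::-1], dim=-1)  # batch x beam x length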
def _sample(self, label_probs: FT, sampling_probs: FT, source_padding: FT,
            gold_tag_seqs: Optional[FT] = None) -> Tuple[LT, FT]:
    """Return samples based on `label_probs`."""
    # Ignore padded indices.
    label_probs = label_probs.align_to('batch', 'length', 'label')
    sampling_probs = sampling_probs.align_to('batch', 'length', 'label')
    source_padding = source_padding.align_to('batch', 'length')

    # Get packed batches.
    label_distr = Categorical(probs=sampling_probs.rename(None))
    label_samples = label_distr.sample([g.num_samples]).refine_names('sample', 'batch', 'length')
    label_samples = label_samples.align_to('batch', 'sample', 'length')
    # Add the ground truth if needed.
    if gold_tag_seqs is not None:
        gold_tag_seqs = gold_tag_seqs.align_as(label_samples)
        all_other_tag_seqs = torch.full_like(gold_tag_seqs, O)
        label_samples = torch.cat([gold_tag_seqs, all_other_tag_seqs, label_samples], dim='sample')
    batch_idx = get_named_range(label_samples.size('batch'), 'batch').align_as(label_samples).rename(None)
    length_idx = get_named_range(label_samples.size('length'), 'length').align_as(label_samples).rename(None)
    label_sample_probs = label_probs.rename(None)[batch_idx, length_idx, label_samples.rename(None)]
    label_sample_probs = label_sample_probs.refine_names(*label_samples.names)
    label_sample_log_probs = (1e-8 + label_sample_probs).log()
    label_sample_log_probs = ((~source_padding).align_as(label_sample_log_probs).float()
                              * label_sample_log_probs).sum(dim='length')
    return label_samples, label_sample_log_probs
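
# A condensed sketch (plain PyTorch, illustrative shapes) of the sampling
# pattern above: draw label sequences from a Categorical over per-position
# distributions, then look up each sample's probability under a (possibly
# different) set of scoring distributions.
def _demo_sample_labels():
    import torch
    from torch.distributions import Categorical

    bs, length, n_labels, n_samples = 2, 5, 3, 4
    sampling_probs = torch.rand(bs, length, n_labels).softmax(dim=-1)
    label_probs = torch.rand(bs, length, n_labels).softmax(dim=-1)

    samples = Categorical(probs=sampling_probs).sample([n_samples])  # sample x batch x length
    samples = samples.permute(1, 0, 2)                               # batch x sample x length
    # Gather the scoring probability of each sampled label.
    sample_probs = (label_probs.unsqueeze(1)
                    .expand(bs, n_samples, length, n_labels)
                    .gather(-1, samples.unsqueeze(-1))
                    .squeeze(-1))
    sample_log_probs = (sample_probs + 1e-8).log().sum(dim=-1)       # batch x sample
    return samples, sample_log_probs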
def get_next_beam(self, beam: Beam, cand: Candidates) -> Beam:
    """Take one beam search step: extend, rescore and prune candidate hypotheses."""
    nh = NameHelper()

    # Get the new scores. For finished hypotheses, we should keep adding EOT.
    placeholder = torch.full_like(cand.log_probs, -9999.9)
    placeholder[..., EOT_ID] = 0.0
    new_scores = torch.where(beam.finished.align_as(placeholder), placeholder, cand.log_probs)
    accum = new_scores + beam.accum_scores.align_as(cand.log_probs)
    lp = nh.flatten(accum, ['beam', 'unit'], 'BU')
    top_s, top_i = torch.topk(lp, beam.beam_size, dim='BU')
    num_units = accum.size('unit')
    # Split the flattened top-k index into a parent-beam index and a token index.
    beam_i = top_i // num_units
    tokens = top_i % num_units

    batch_i = get_named_range(beam.batch_size, 'batch')
    batch_i = batch_i.align_as(top_i)

    def retrieve(tensor, last_name: str = 'hidden') -> torch.Tensor:
        with NoName(tensor, batch_i, beam_i):
            ret = tensor[batch_i, beam_i]
        new_names = ('batch', 'beam')
        if last_name:
            new_names += (last_name, )
        return ret.refine_names(*new_names)

    next_scores = top_s.rename(BU='beam')
    next_tokens = tokens.rename(BU='beam')
    next_beam_ids = beam_i.rename(BU='beam')
    next_state = cand.state.apply(retrieve)
    next_almt = retrieve(cand.almt, last_name='tgt_pos')
    next_att = retrieve(cand.att, last_name='hidden') if g.input_feeding else None
    last_finished = retrieve(beam.finished, last_name=None)
    this_ended = next_tokens == EOT_ID
    reached_max = (beam.step + 1 == beam.constants.max_lengths)
    next_finished = last_finished | this_ended | reached_max
    next_beam = beam.follow(next_finished,
                            next_scores,
                            next_tokens,
                            next_state,
                            next_beam_ids,
                            next_almt,
                            prev_att=next_att)
    return next_beam
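
# The core beam-step arithmetic above, reduced to plain PyTorch: flatten the
# (beam, vocab) score matrix, take a top-k, and split the flat indices back
# into a parent-beam index and a token index. Sizes are illustrative.
def _demo_beam_step():
    import torch

    batch_size, beam_size, vocab_size = 2, 3, 11
    accum = torch.randn(batch_size, beam_size, vocab_size)
    flat = accum.view(batch_size, beam_size * vocab_size)
    top_s, top_i = torch.topk(flat, beam_size, dim=-1)
    beam_i = top_i // vocab_size   # which parent hypothesis each survivor extends
    tokens = top_i % vocab_size    # which token it appends
    return top_s, beam_i, tokens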
def finish_search(self, lengths: LT):
    """Backtrack through the stored beam ids to assemble the final samples."""
    last_beam_id = get_zeros(lengths.size('batch'), g.beam_size).long().rename('batch', 'beam')
    start_beam_id = get_named_range(g.beam_size, 'beam').align_as(last_beam_id)
    samples = list()
    for i, (hyp, beam_id) in enumerate(zip(reversed(self.hyps), reversed(self.beam_ids))):
        step = len(self.beam_ids) - i
        # Sequences that end exactly at this step start backtracking from the
        # identity beam order; longer ones keep following their backpointers.
        start_backtrack = (step == lengths).align_as(beam_id)
        this_beam_id = torch.where(start_backtrack, start_beam_id, last_beam_id)
        samples.append(hyp.gather('beam', this_beam_id))
        last_beam_id = beam_id.gather('beam', this_beam_id)
    self.samples = torch.stack(samples[::-1], new_name='length')
    hyp_log_probs = torch.stack(self.hyp_log_probs, new_name='length')
    self.sample_log_probs = hyp_log_probs.gather('length', lengths.align_as(hyp_log_probs)).squeeze('length')
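
# A toy sketch (plain PyTorch, unnamed dims) of the length-aware backtracking
# above, for two sequences of different lengths: `torch.where` switches each
# batch element from the zero placeholder to real backpointers at the step
# where that sequence ends.
def _demo_finish_search():
    import torch

    batch_size, beam_size, num_steps = 2, 3, 4
    lengths = torch.tensor([4, 2])
    hyps = [torch.randint(0, 10, (batch_size, beam_size)) for _ in range(num_steps)]
    beam_ids = [torch.randint(0, beam_size, (batch_size, beam_size)) for _ in range(num_steps)]

    last_beam_id = torch.zeros(batch_size, beam_size, dtype=torch.long)
    start_beam_id = torch.arange(beam_size).expand(batch_size, beam_size)
    samples = []
    for i, (hyp, beam_id) in enumerate(zip(reversed(hyps), reversed(beam_ids))):
        step = num_steps - i
        start_backtrack = (step == lengths).unsqueeze(-1).expand_as(beam_id)
        this_beam_id = torch.where(start_backtrack, start_beam_id, last_beam_id)
        samples.append(hyp.gather(-1, this_beam_id))
        last_beam_id = beam_id.gather(-1, this_beam_id)
    return torch.stack(samples[::-1], dim=-1)  # batch x beam x length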
def _extract_one_span(self, batch: ExtractBatch, extracted: Extracted, word_repr: FT,
                      unit_repr: FT, char_log_probs: FT) -> Extracted:
    # Propose all span start/end positions.
    start_candidates = get_named_range(batch.max_length, 'len_s').align_to('batch', 'len_s', 'len_e')
    # Range from `min_word_length` to `max_word_length`.
    len_candidates = get_named_range(g.max_word_length + 1 - g.min_word_length, 'len_e') + g.min_word_length
    len_candidates = len_candidates.align_to('batch', 'len_s', 'len_e')
    # This is inclusive.
    end_candidates = start_candidates + len_candidates - 1

    # Only keep the viable/valid spans around.
    viable = (end_candidates < batch.lengths.align_as(end_candidates))
    start_candidates = start_candidates.expand_as(viable)
    len_candidates = len_candidates.expand_as(viable)
    # NOTE(j_luo) Use `viable` to get the lengths. `len_candidates` has dummy axes.
    # IDEA(j_luo) Any better way of handling this? Perhaps persistent names?
    len_s = viable.size('len_s')
    len_e = viable.size('len_e')
    bi = get_named_range(batch.batch_size, 'batch').expand_as(viable)
    with NoName(start_candidates, end_candidates, len_candidates, bi, viable):
        viable_starts = start_candidates[viable].rename('viable')
        viable_lens = len_candidates[viable].rename('viable')
        viable_bi = bi[viable].rename('viable')

    # Get the word positions to get the corresponding representations.
    viable_starts = viable_starts.align_to('viable', 'len_w')
    word_pos_offsets = get_named_range(g.max_word_length, 'len_w').align_as(viable_starts)
    word_pos = viable_starts + word_pos_offsets
    word_pos = word_pos.clamp(max=batch.max_length - 1)

    # Get the corresponding representations.
    nh = NameHelper()
    viable_bi = viable_bi.expand_as(word_pos)
    word_pos = nh.flatten(word_pos, ['viable', 'len_w'], 'viable_X_len_w')
    viable_bi = nh.flatten(viable_bi, ['viable', 'len_w'], 'viable_X_len_w')
    word_repr = word_repr.align_to('batch', 'length', 'char_emb')
    if g.input_format == 'text':
        with NoName(word_repr, viable_bi, word_pos, batch.unit_id_seqs):
            extracted_word_repr = word_repr[viable_bi, word_pos].rename('viable_X_len_w', 'char_emb')
            extracted_unit_ids = batch.unit_id_seqs[viable_bi, word_pos].rename('viable_X_len_w')
    else:
        with NoName(word_repr, viable_bi, word_pos):
            extracted_word_repr = word_repr[viable_bi, word_pos].rename('viable_X_len_w', 'char_emb')
        extracted_unit_ids = None
    extracted_word_repr = nh.unflatten(extracted_word_repr, 'viable_X_len_w', ['viable', 'len_w'])

    # Main body: Run DP to find the best matches.
    matches = self._get_matches(extracted_word_repr, unit_repr, viable_lens,
                                extracted_unit_ids, char_log_probs)
    # Revert to the old shape (so that invalid spans are included).
    bi = get_named_range(batch.batch_size, 'batch').expand_as(viable)
    lsi = get_named_range(len_s, 'len_s').expand_as(viable)
    lei = get_named_range(len_e, 'len_e').expand_as(viable)
    vs = matches.ll.size('vocab')
    # IDEA(j_luo) NoName shouldn't make size() calls unavailable. Otherwise size() calls
    # have to be moved outside the context. Also the names should be preserved as well.
    with NoName(bi, lsi, lei, viable, matches.ll):
        v_bi = bi[viable]
        v_lsi = lsi[viable]
        v_lei = lei[viable]
        all_ll = get_zeros(batch.batch_size, len_s, len_e, vs)
        all_ll = all_ll.float().fill_(-9999.9)
        all_ll[v_bi, v_lsi, v_lei] = matches.ll
        matches.ll = all_ll.rename('batch', 'len_s', 'len_e', 'vocab')

    new_extracted = Extracted(batch.batch_size, matches, viable, len_candidates)
    return new_extracted
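
# A compact sketch (plain PyTorch) of the span-proposal step above: enumerate
# (start, length) pairs via broadcasting, derive inclusive end positions, and
# mask out spans that overrun each sequence. The bounds here are illustrative
# stand-ins for `g.min_word_length` / `g.max_word_length`.
def _demo_propose_spans():
    import torch

    batch_size, max_length = 2, 8
    min_len, max_len = 2, 4
    seq_lengths = torch.tensor([8, 5])

    starts = torch.arange(max_length).view(1, -1, 1)             # batch x len_s x len_e
    lens = (torch.arange(max_len + 1 - min_len) + min_len).view(1, 1, -1)
    ends = starts + lens - 1                                     # inclusive end positions
    viable = ends < seq_lengths.view(-1, 1, 1)                   # batch x len_s x len_e
    return starts.expand_as(viable)[viable], lens.expand_as(viable)[viable]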
def forward(self, batch: Union[ContinuousIpaBatch, IpaBatch]) -> DecipherModelReturn:
    # Get the samples of label sequences first.
    out = self.emb_for_label(batch.feat_matrix, batch.source_padding)

    positions = get_named_range(batch.feat_matrix.size('length'), name='length')
    pos_emb = self.positional_embedding(positions).align_as(out)
    out = out + pos_emb
    out = out.align_to('length', 'batch', 'char_emb')
    with NoName(out, batch.source_padding):
        for i, layer in enumerate(self.self_attn_layers):
            out = layer(out, src_key_padding_mask=batch.source_padding)
    state = out.refine_names('length', 'batch', ...)
    logits = self.label_predictor(state)
    label_log_probs = logits.log_softmax(dim='label')
    label_probs = label_log_probs.exp()

    # NOTE(j_luo) O is equivalent to None.
    # Force padded positions to put all probability mass on the O label.
    mask = expand_as(batch.source_padding, label_probs)
    source = expand_as(get_tensor([0.0, 0.0, 1.0]).refine_names('label').float(), label_probs)
    label_probs = label_probs.rename(None).masked_scatter(mask.rename(None), source.rename(None))
    label_probs = label_probs.refine_names('length', 'batch', 'label')

    if not self.training or (g.supervised and not g.train_phi):
        probs = DecipherModelProbReturn(label_log_probs, None)
        return DecipherModelReturn(state, probs, None, None, None, None, None)

    # ------------------ More info during training ----------------- #

    # Get the lm score.
    gold_tag_seqs = batch.gold_tag_seqs if g.supervised and g.train_phi else None
    samples, sample_log_probs = self.searcher.search(batch.lengths, label_log_probs,
                                                     gold_tag_seqs=gold_tag_seqs)
    probs = DecipherModelProbReturn(label_log_probs, sample_log_probs)

    packed_words, scores = self._get_scores(samples, batch.segments, batch.lengths,
                                            batch.feat_matrix, batch.source_padding)

    if g.supervised and g.train_phi:
        return DecipherModelReturn(state, probs, packed_words, None, scores, None, None)

    # ------------------- Contrastive estimation ------------------- #
    ptb_segments = list()
    duplicates = list()
    for segment in batch.segments:
        _ptb_segments, _duplicates = segment.perturb_n_times(g.n_times)
        # NOTE(j_luo) Ignore the first one.
        ptb_segments.extend(_ptb_segments[1:])
        duplicates.extend(_duplicates[1:])
    ptb_feat_matrix = [segment.feat_matrix for segment in ptb_segments]
    ptb_feat_matrix = torch.nn.utils.rnn.pad_sequence(ptb_feat_matrix, batch_first=True)
    ptb_feat_matrix.rename_('batch', 'length', 'feat_group')

    samples = samples.align_to('batch', ...)
    with NoName(samples, batch.lengths, batch.source_padding):
        ptb_samples = torch.repeat_interleave(samples, g.n_times * 2, dim=0)
        ptb_lengths = torch.repeat_interleave(batch.lengths, g.n_times * 2, dim=0)
        ptb_source_padding = torch.repeat_interleave(batch.source_padding, g.n_times * 2, dim=0)
    ptb_samples.rename_(*samples.names)
    ptb_lengths.rename_('batch')
    ptb_source_padding.rename_('batch', 'length')

    ptb_packed_words, ptb_scores = self._get_scores(ptb_samples, ptb_segments, ptb_lengths,
                                                    ptb_feat_matrix, ptb_source_padding)

    ret = DecipherModelReturn(state, probs, packed_words, ptb_packed_words, scores,
                              ptb_scores, duplicates)
    return ret
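
# A minimal sketch (plain PyTorch, illustrative shapes) of the padding-masking
# step above: overwrite the label distribution at padded positions with a
# fixed one-hot distribution (probability 1 on the last label, mirroring the
# O tag). `torch.where` is used here for clarity; it has the same effect as
# the masked_scatter call in the method.
def _demo_mask_padded_labels():
    import torch

    length, bs, n_labels = 5, 2, 3
    label_probs = torch.rand(length, bs, n_labels).softmax(dim=-1)
    source_padding = torch.zeros(length, bs, dtype=torch.bool)
    source_padding[3:, 1] = True  # pretend the second sequence is shorter

    mask = source_padding.unsqueeze(-1).expand_as(label_probs)
    source = torch.tensor([0.0, 0.0, 1.0]).expand_as(label_probs)
    return torch.where(mask, source, label_probs)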