def _get_model(self, dl=None):
    """Construct the model selected by the global config `g`.

    Builds either an RL agent (VPG or A2C) or a supervised `OnePairModel`,
    optionally restores weights from `g.saved_model_path`, and moves the
    model to GPU when one is available.  `dl` is accepted for interface
    compatibility but not used here.
    """
    # Phonological-feature kwargs are passed through regardless of model kind.
    if g.use_phono_features:
        pfm = get_tensor(self.src_abc.pfm)
        sp_ids = get_tensor(self.src_abc.special_ids)
    else:
        pfm = sp_ids = None
    phono_kwargs = dict(phono_feat_mat=pfm, special_ids=sp_ids)

    if g.use_rl:
        agent_cls = A2C if g.agent != 'vpg' else VanillaPolicyGradient
        model = agent_cls(len(self.tgt_abc), self.env, self.env.end, **phono_kwargs)
    else:
        model = OnePairModel(len(self.src_abc), len(self.tgt_abc), **phono_kwargs)

    if g.saved_model_path is not None:
        # Load on CPU first; the cuda() call below relocates if needed.
        state = torch.load(g.saved_model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state)
        logging.imp(f'Loaded from {g.saved_model_path}.')
    if has_gpus():
        model.cuda()
    logging.info(model)
    return model
def evaluate(self, states, steps: Optional[Union[int, LT]] = None) -> List[float]:
    """Expand and evaluate the leaf nodes in `states`.

    Terminal states get a value of 0.0 directly; all remaining states are
    batched, scored by the agent's policy (and optionally value) network,
    and expanded in the environment.  Returns one value per input state,
    in the same order.

    Args:
        states: search states to evaluate (project state objects with
            `stopped`/`done`/`is_leaf()` — assumed MCTS nodes; confirm).
        steps: optional per-state step counts (int or tensor) used when a
            finite horizon is enabled; a tensor is sub-indexed to match the
            non-terminal states.

    Returns:
        A list of floats used for MCTS backup.
    """
    values = [None] * len(states)
    outstanding_idx = list()
    outstanding_states = list()
    # Deal with end states first.
    for i, state in enumerate(states):
        if state.stopped or state.done:
            # NOTE(j_luo) This value is used for backup. If already reaching the end state, the final
            # reward is either accounted for by the step reward, or by the value network. Therefore,
            # we need to set it to 0.0 here.
            values[i] = 0.0
        else:
            outstanding_idx.append(i)
            outstanding_states.append(state)
    # Collect states that need evaluation.
    if outstanding_states:
        almts1 = almts2 = None
        if g.use_alignment:
            # Stacking returns id sequences plus two alignment arrays when alignments are used.
            id_seqs, almts1, almts2 = parallel_stack_ids(
                outstanding_states, g.num_workers, True, self.env.max_end_length)
            almts1 = get_tensor(almts1).rename('batch', 'word', 'pos')
            almts2 = get_tensor(almts2).rename('batch', 'word', 'pos')
        else:
            id_seqs = parallel_stack_ids(outstanding_states, g.num_workers, False,
                                         self.env.max_end_length)
        id_seqs = get_tensor(id_seqs).rename('batch', 'word', 'pos')
        # A tensor of steps must be restricted to the states actually being evaluated.
        if steps is not None and not isinstance(steps, int):
            steps = steps[outstanding_idx]
        # TODO(j_luo) Scoped might be wrong here.
        # with ScopedCache('state_repr'):
        # NOTE(j_luo) Don't forget to call exp().
        priors = self.agent.get_policy(id_seqs, almts=(almts1, almts2)).exp()
        with NoName(priors):
            # NOTE(review): column 1 is split out as a "special" prior and the rest
            # are "meta" priors — presumably tied to the action layout in
            # `self.env.evaluate`; confirm against the agent's policy head.
            meta_priors = priors[:, [0, 2, 3, 4, 5, 6]].cpu().numpy()
            special_priors = priors[:, 1].cpu().numpy()
        if g.use_value_guidance:
            agent_values = self.agent.get_values(id_seqs, steps=steps).cpu().numpy()
        else:
            # Without value guidance, leaf values default to zero.
            agent_values = np.zeros([len(id_seqs)], dtype='float32')
        for i, state, mp, sp, v in zip(outstanding_idx, outstanding_states,
                                       meta_priors, special_priors, agent_values):
            # NOTE(j_luo) Values should be returned even if states are duplicates or have been visited.
            values[i] = v
            # NOTE(j_luo) Skip duplicate states (due to exploration collapse) or visited states
            # (due to rollout truncation).
            if not state.is_leaf():
                continue
            # print(mp[1, 111])
            self.env.evaluate(state, mp, sp)
    return values
def collate_fn(batch):
    """Collate dataset items into one batch, sorted (by `sort_all`) and padded.

    Returns a `Map` with words, forms, char_seqs, the padded/trimmed id_seqs
    tensor, per-item lengths, and the batch's language.
    """
    fields = [_get_item(key, batch) for key in ('word', 'form', 'char_seq', 'id_seq')]
    lengths, words, forms, char_seqs, id_seqs = sort_all(*fields)
    lengths = get_tensor(lengths, dtype='l')
    # Trim the id_seqs.
    longest = max(lengths).item()
    dense_ids = pad_to_dense(id_seqs, dtype='l')
    id_seqs = get_tensor(dense_ids[:, :longest])
    return Map(words=words,
               forms=forms,
               char_seqs=char_seqs,
               id_seqs=id_seqs,
               lengths=lengths,
               lang=batch[0].lang)
def analyze(self, log_probs, almt_distr, words, lost_lengths):
    """Score `words` against character log-probs and compute an alignment regularizer.

    Args:
        log_probs: per-position character log-probabilities, shape (tl, nc, bs)
            per the unpack below.
        almt_distr: alignment distribution whose last dim ranges over source
            positions (shape ends in sl).
        words: the word sample used to build the effective weight matrix.
        lost_lengths: per-batch source lengths (tensor).

    Returns:
        Map with `reg_loss` (scalar monotonic-alignment penalty) and
        `valid_log_probs` (bs x V word-level scores).
    """
    self.clear_cache()
    self._sample(words)
    assert self._eff_max_length == len(log_probs)
    tl, nc, bs = log_probs.shape
    charset = get_charset(self.lang)
    assert nc == len(charset)
    # V x bs, or c_s x c_t -> bs x V
    # Sparse indicator matrix sums each word's character log-probs in one matmul.
    valid_log_probs = self._eff_weight.matmul(log_probs.view(-1, bs)).t()
    sl = almt_distr.shape[-1]
    pos = get_tensor(torch.arange(sl).float(), requires_grad=False)
    # Expected aligned source position per target position.
    mean_pos = (pos * almt_distr).sum(dim=-1)  # bs x tl
    # Prepend a virtual position -1 so first-step deltas are well-defined.
    mean_pos = torch.cat([get_zeros(bs, 1, requires_grad=False).fill_(-1.0), mean_pos], dim=-1)
    # Down-weight positions already past the end of the (lost) source word.
    reg_weight = lost_lengths.float().view(-1, 1) - 1.0 - mean_pos[:, :-1]
    reg_weight.clamp_(0.0, 1.0)
    rel_pos = mean_pos[:, 1:] - mean_pos[:, :-1]  # bs x tl
    # Penalize deviation from advancing exactly one source position per step.
    rel_pos_diff = rel_pos - 1
    margin = rel_pos_diff != 0
    reg_loss = margin.float() * (rel_pos_diff ** 2)  # bs x tl
    reg_loss = (reg_loss * reg_weight).sum()
    return Map(reg_loss=reg_loss, valid_log_probs=valid_log_probs)
def __init__(self, lost_lang, known_lang, momentum, num_cognates):
    """Initialize a zero flow matrix between the two vocabularies.

    NOTE(review): `momentum` and `num_cognates` are not referenced in this
    body — presumably consumed by the superclass or kept for interface
    parity; confirm.
    """
    super().__init__()
    lost_vocab = get_words(lost_lang)
    known_vocab = get_words(known_lang)
    zero_flow = get_tensor(np.zeros([len(lost_vocab), len(known_vocab)]))
    self.flow = MagicTensor(zero_flow, lost_vocab, known_vocab)
    self._warmed_up = False
def _get_batch_records(self, dl: OnePairDataLoader, batch: OnePairBatch, K: int) -> List[Dict[str, Any]]:
    """Build per-example evaluation records for one batch.

    For each source word, records the gold target, beam predictions with
    scores, the K vocabulary entries with lowest expected edit distance,
    the (normalized) edit distance of the top beam hypothesis to the gold
    target, and the model perplexity.

    Args:
        dl: data loader providing the full target vocabulary.
        batch: the batch being evaluated.
        K: number of closest vocabulary candidates to record.

    Returns:
        One dict per batch element.
    """
    hyps = self.model.predict(batch)
    # NOTE(j_luo) EOT's have been removed from translations since they don't matter
    # in edit distance computation.
    preds, pred_lengths, _ = hyps.translate(self.tgt_abc)
    # HACK(j_luo) Pretty ugly here.
    dists = compute_edit_dist(g.comp_mode,
                              pred_ids=hyps.tokens,
                              lengths=dl.tgt_seqs.lengths,
                              gold_ids=dl.tgt_seqs.ids.t(),
                              forms=dl.tgt_vocabulary.forms,
                              units=dl.tgt_seqs.units,
                              predictions=preds,
                              pred_lengths=pred_lengths,
                              pfm=self.tgt_abc.pfm)
    # Expected edit distance over the beam, weighted by beam probabilities.
    weights = get_beam_probs(hyps.scores)
    w_dists = weights.align_to(..., 'tgt_vocab') * dists
    expected_dists = w_dists.sum(dim='beam')
    # topk of the negated distances = K smallest expected distances.
    top_s, top_i = torch.topk(-expected_dists, K, dim='tgt_vocab')
    top_s = -top_s
    records = list()
    tgt_vocab = dl.tgt_vocabulary
    # In order to record the edit distance between the top prediction and the ground truth,
    # we need to find the index of the ground truth in the vocabulary, not in the dataset.
    tgt_ids = get_tensor(
        [tgt_vocab.get_id_by_form(form) for form in batch.tgt_seqs.forms])
    dists_tx = Tx(dists, ['batch', 'beam', 'tgt_vocab'])
    tgt_ids_tx = Tx(tgt_ids, ['batch'])
    # Beam 0 is the top hypothesis; select its distance to the gold vocab entry.
    top_dists_tx = dists_tx.select('beam', 0)
    top_dists = top_dists_tx.each_select({'tgt_vocab': tgt_ids_tx}).data
    # NOTE(review): `- 1` presumably discounts the EOT token in the gold length; confirm.
    normalized_top_dists = top_dists.float() / (batch.tgt_seqs.lengths - 1)
    top_dists = top_dists.cpu().numpy()
    normalized_top_dists = normalized_top_dists.cpu().numpy()
    # We also report the perplexity scores.
    log_probs, _ = self.model(batch)
    ppxs = get_ce_loss(log_probs, batch, agg='batch_mean')
    ppxs = ppxs.cpu().numpy()
    for pss, pis, src, gold, pbis, pbss, top_dist, n_top_dist, ppx in zip(
            top_s, top_i, batch.src_seqs.forms, batch.tgt_seqs.forms, preds, weights,
            top_dists, normalized_top_dists, ppxs):
        record = {
            'source': src,
            'gold_target': gold,
            'edit_dist': top_dist,
            'normalized_edit_dist': n_top_dist,
            'ppx': ppx
        }
        # Raw beam hypotheses with their (soft) weights.
        for i, (pbs, pbi) in enumerate(zip(pbss, pbis), 1):
            record[f'pred_target_beam@{i}'] = pbi
            record[f'pred_target_beam@{i}_score'] = f'{pbs.item():.3f}'
        # Closest vocabulary entries by expected edit distance.
        for i, (ps, pi) in enumerate(zip(pss, pis), 1):
            pred_closest = tgt_vocab[pi]['form']
            record[f'pred_target@{i}'] = pred_closest
            record[f'pred_target@{i}_score'] = f'{ps.item():.3f}'
        records.append(record)
    return records
def _prepare_weight(self):
    """Build the full sparse character-indicator matrix for this language.

    Row `r` corresponds to word `r`; column `len(charset) * i + cid` marks
    that the word's i-th character has id `cid`.  The matrix has shape
    (num_words, max_length * charset_size) and is later multiplied with
    flattened per-position character log-probs.  Also fills `_word2rows`,
    mapping each word to the indices of its nonzero entries (used by
    `_sample` to slice out a sub-matrix).
    """
    rows = list()
    cols = list()
    words = get_words(self.lang)
    charset = get_charset(self.lang)
    self._word2rows = defaultdict(list)
    for row, word in enumerate(words):
        for i, c in enumerate(word.char_seq):
            cid = charset.char2id(c)
            # Remember which COO entries belong to this word.
            self._word2rows[word].append(len(rows))
            rows.append(row)
            cols.append(len(charset) * i + cid)
    data = np.ones(len(rows))
    # NOTE This is ugly, but it avoids this issue in 0.4.1:
    # https://github.com/pytorch/pytorch/issues/8856.
    weight = torch.sparse.FloatTensor(
        get_tensor([rows, cols], dtype='l', use_cuda=False),
        get_tensor(data, dtype='f', use_cuda=False),
        (len(words), self._max_length * len(charset)))
    self._weight = get_tensor(weight)
def _sample(self, words):
    """Slice the full sparse weight matrix down to the sampled `words`.

    Populates:
        _eff_id2word / _eff_word2id: index mappings for the sampled words.
        _eff_max_length: length of the longest sampled word.
        _eff_weight: sparse indicator matrix of shape
            (len(words), _eff_max_length * |charset|), built from the entries
        of `self._weight` (see `_prepare_weight`) that belong to `words`.
    """
    all_words = get_words(self.lang)
    word_indices = list()
    # Maps a word's row in the full matrix to its row in the effective one.
    old_to_new = np.zeros([len(all_words)], dtype='int64')
    self._eff_id2word = list()
    self._eff_word2id = dict()
    self._eff_max_length = max(map(len, words))
    for w in words:
        word_indices.extend(self._word2rows[w])
        old_to_new[w.idx] = len(self._eff_id2word)
        self._eff_id2word.append(w)
        self._eff_word2id[w] = len(self._eff_word2id)
    old_to_new = get_tensor(old_to_new)
    # FIX: dropped the dead local `indices = get_tensor(word_indices, dtype='l')`;
    # every use below indexes with the plain `word_indices` list.
    old_rows, cols = self._weight._indices()[:, word_indices].unbind(dim=0)
    # Renumber rows into the effective (sampled) coordinate system.
    rows = old_to_new[old_rows]
    data = self._weight._values()[word_indices]
    charset = get_charset(self.lang)
    weight = torch.sparse.FloatTensor(
        torch.stack([rows, cols], dim=0), data,
        (len(words), self._eff_max_length * len(charset)))
    self._eff_weight = get_tensor(weight)
def _train_one_step_mrt(self, batch: OnePairBatch, abc: Alphabet) -> Metrics: """Train for one step using minimum risk training (MRT).""" # Get scores for top predictions and the target sequence. assert g.comp_mode == 'str' assert g.dropout == 0.0, 'Might have some issue due to the fact that ground truth is fed through the normal forward call whereas hyps are not, resulting in discrepant dropout applications.' hyps = self.model.predict(batch) log_probs, _ = self.model(batch) tgt_scores = -get_ce_loss(log_probs, batch, agg='batch') # Mark which ones are duplicates. preds, _, _ = hyps.translate(abc) duplicates = list() for beam_preds, tgt_form in zip(preds, batch.tgt_seqs.forms): duplicates.append([False] + [p == tgt_form for p in beam_preds]) # If no duplicates are found, then we discard the last prediction. if not any(duplicates[-1]): duplicates[-1][-1] = True if sum(duplicates[-1]) != 1: raise RuntimeError(f'Should have exactly one duplicate.') duplicates = get_tensor(duplicates) # Assemble all scores together. scores = torch.cat([tgt_scores.align_as(hyps.scores), hyps.scores], dim='beam') probs = get_beam_probs(scores, duplicates=duplicates) target = np.tile(batch.tgt_seqs.forms.reshape(-1, 1), [1, g.beam_size + 1]) preds = np.concatenate([target[:, 0:1], preds], axis=-1) dists = edit_dist_batch(preds.reshape(-1), target.reshape(-1), 'ed') lengths = batch.tgt_seqs.lengths.align_to('batch', 'beam') dists = get_tensor(dists.reshape(-1, g.beam_size + 1)).float() # / lengths # risk = (probs * (dists ** 2)).sum(dim='beam') risk = (probs * dists).sum(dim='beam') risk = Metric('risk', risk.sum(), len(batch)) return Metrics(risk)
def search_by_probs(self, lengths: LT, label_log_probs: FT) -> Tuple[LT, FT]:
    """Exhaustively score every BIO label sequence up to the batch max length.

    Enumerates all 3 ** max_length sequences over {B, I, O} (only tractable
    for short inputs), gathers each position's label log-prob, masks out
    positions beyond each item's true length, and sums over positions.

    Args:
        lengths: per-batch-item sequence lengths.
        label_log_probs: named tensor of per-position label log-probs
            (has 'batch' and 'label' dims per the calls below).

    Returns:
        (samples, sample_log_probs): candidate label sequences and their
        total log-probability per batch item.
    """
    max_length = lengths.max().item()
    samples = get_tensor(
        torch.LongTensor(list(product([B, I, O], repeat=max_length))))
    samples.rename_('sample', 'length')
    bs = label_log_probs.size('batch')
    samples = samples.align_to('batch', 'sample', 'length').expand(bs, -1, -1)
    sample_log_probs = label_log_probs.gather('label', samples)
    with NoName(lengths):
        length_mask = get_length_mask(lengths, max_length).rename(
            'batch', 'length')
    # FIX: `align_to` takes dimension *names*; aligning to another tensor is
    # `align_as`. The original `align_to(sample_log_probs)` would fail at runtime.
    length_mask = length_mask.align_as(sample_log_probs)
    sample_log_probs = (sample_log_probs * length_mask.float()).sum(dim='length')
    return samples, sample_log_probs
type=str, help='Path to save the output. No suffix should be included.') args = parser.parse_args() if '.' in args.out_name: raise ValueError(f'No suffix should be included.') initiator = setup() initiator.run(saved_g_path=args.saved_g_path) _, _, abc, _ = OneToManyManager.prepare_raw_data() assert g.share_src_tgt_abc sd = torch.load(args.saved_model_path) emb_params = get_emb_params(len(abc), phono_feat_mat=get_tensor(abc.pfm), special_ids=get_tensor(abc.special_ids)) emb = PhonoEmbedding.from_params(emb_params) prefix = 'encoder.embedding' emb.load_state_dict({ 'weight': sd[f'{prefix}.weight'], 'special_weight': sd[f'{prefix}.special_weight'], 'special_mask': sd[f'{prefix}.special_mask'], 'pfm': sd[f'{prefix}.pfm'] }) emb.cuda() char_emb = emb.char_embedding.detach().cpu().numpy() size = char_emb.shape[-1] cols = [f'vec_{i}' for i in range(size)] df = pd.DataFrame(char_emb, columns=cols)
def collect_episodes(self,
                     init_state: VocabState,
                     tracker: Optional[Tracker] = None,
                     num_episodes: int = 0,
                     is_eval: bool = False,
                     no_simulation: bool = False) -> List[Trajectory]:
    """Collect trajectories by repeatedly playing rollouts from `init_state`.

    For each episode, runs up to `g.max_rollout_length` play steps; each step
    either takes one policy action directly (`is_eval and no_simulation`) or
    runs batched MCTS simulations before playing.  Gradients through the
    agent are disabled throughout.

    Args:
        init_state: root state every episode starts from.
        tracker: optional progress tracker ('mcts'/'rollout'/'episode' counters).
        num_episodes: episode count; falls back to `g.num_episodes` when 0.
        is_eval: evaluation mode — no exploration noise, greedy/AC play strategy.
        no_simulation: in eval mode, skip MCTS and follow the policy directly.

    Returns:
        The list of collected `Trajectory` objects.
    """
    trajectories = list()
    self.agent.eval()
    if is_eval:
        self.eval()
    else:
        self.train()
    num_episodes = num_episodes or g.num_episodes
    # if no_simulation:
    #     breakpoint()  # BREAKPOINT(j_luo)
    with self.agent.policy_grad(False), self.agent.value_grad(False):
        for ei in range(num_episodes):
            root = init_state
            self.reset()
            # With a finite horizon, track the step index so the value net is step-aware.
            steps = 0 if g.use_finite_horizon else None
            self.evaluate([root], steps=steps)
            # Episodes have max rollout length.
            played_path = None
            for ri in range(g.max_rollout_length):
                if not is_eval:
                    self.add_noise(root)
                if is_eval and no_simulation:
                    # Follow the policy one step at a time, no search.
                    new_state = self.select_one_pi_step(root)
                    steps = steps + 1 if g.use_finite_horizon else None
                    values = self.evaluate([new_state], steps=steps)
                else:
                    # Run many simulations before take one action. Simulations take place
                    # in batches. Each batch would be evaluated and expanded after batched
                    # selection.
                    num_batches = g.num_mcts_sims // g.expansion_batch_size
                    for _ in range(num_batches):
                        paths, steps = self.select(root, g.expansion_batch_size, ri,
                                                   g.max_rollout_length, played_path)
                        steps = get_tensor(steps) if g.use_finite_horizon else None
                        new_states = [path.get_last_node() for path in paths]
                        values = self.evaluate(new_states, steps=steps)
                        self.backup(paths, values)
                        if tracker is not None:
                            tracker.update('mcts', incr=g.expansion_batch_size)
                    # Periodically log root statistics for the first rollout step.
                    if ri == 0 and ei % g.episode_check_interval == 0:
                        k = min(20, root.num_actions)
                        logging.debug(
                            pad_for_log(str(get_tensor(root.action_counts).topk(k))))
                        logging.debug(
                            pad_for_log(str(get_tensor(root.q).topk(k))))
                        logging.debug(
                            pad_for_log(str(get_tensor(root.max_values).topk(k))))
                # Choose the play strategy: sampling from AC or argmax in eval mode.
                ps = self.play_strategy
                if is_eval:
                    if no_simulation:
                        ps = PyPS_SAMPLE_AC
                    else:
                        ps = PyPS_MAX
                new_path = self.play(root, ri, ps, g.exponent)
                if played_path is None:
                    played_path = new_path
                else:
                    played_path.merge(new_path)
                root = played_path.get_last_node()
                # print('3')
                if tracker is not None:
                    tracker.update('rollout')
                if root.stopped or root.done:
                    break
            # self.show_stats()
            trajectory = Trajectory(played_path, self.env.max_end_length)
            if ei % g.episode_check_interval == 0:
                logging.debug(pad_for_log(str(trajectory)))
            trajectories.append(trajectory)
            if tracker is not None:
                tracker.update('episode')
    # if no_simulation:
    #     breakpoint()  # BREAKPOINT(j_luo)
    return trajectories
def __init__(self):
    """Set up data, model, evaluator and (optionally) trainer for one-to-many training.

    Loads raw data, reports target units unseen in training, registers one
    data-loader setting per (split, target language), builds the
    `OneToManyModel` (optionally restoring saved weights), and wires up the
    `Evaluator` and — unless `g.evaluate_only` — the `Trainer` with its
    optimizer.
    """
    all_tgt, self.cog_reg, self.src_abc, self.tgt_abc = self.prepare_raw_data()
    # Get stats for unseen units.
    stats = self.tgt_abc.stats
    _, test_tgt_path = get_paths(g.data_path, g.src_lang, g.tgt_lang)
    # A unit is "unseen" when all of its occurrences come from the test file.
    mask = (stats.sum() == stats.loc[test_tgt_path])
    unseen = mask[mask].index.tolist()
    total = len(stats.loc[test_tgt_path].dropna())
    logging.info(
        f'Unseen units ({len(unseen)}/{total}) for {g.tgt_lang} are: {unseen}.')
    # Get language-to-id mappings. Used only for the targets (i.e., decoder side).
    self.lang2id = lang2id = {tgt: i for i, tgt in enumerate(all_tgt)}
    # Get all data loaders.
    self.dl_reg = DataLoaderRegistry()

    def create_setting(name: str, tgt_lang: str, split: Split, for_training: bool,
                       keep_ratio: Optional[float] = None,
                       tgt_sot: bool = False) -> Setting:
        # Convenience wrapper fixing the task ('one_pair') and source language.
        return Setting(name, 'one_pair', split, g.src_lang, tgt_lang, for_training,
                       keep_ratio=keep_ratio, tgt_sot=tgt_sot)

    test_setting = create_setting(f'test@{g.tgt_lang}', g.tgt_lang, Split('all'),
                                  False, keep_ratio=g.test_keep_ratio)
    settings: List[Setting] = [test_setting]
    # Get the training languages.
    for train_tgt_lang in g.train_tgt_langs:
        if g.input_format == 'ielex':
            train_split = Split('train', [1, 2, 3, 4])  # Use the first four folds for training.
            dev_split = Split('dev', [5])  # Use the last fold for dev.
        else:
            train_split = Split('train')
            dev_split = Split('dev')
        train_setting = create_setting(f'train@{train_tgt_lang}', train_tgt_lang,
                                       train_split, True, keep_ratio=g.keep_ratio)
        # The '_e' variant is the training data re-used for evaluation (no training).
        train_e_setting = create_setting(f'train@{train_tgt_lang}_e', train_tgt_lang,
                                         train_split, False, keep_ratio=g.keep_ratio)
        dev_setting = create_setting(f'dev@{train_tgt_lang}', train_tgt_lang,
                                     dev_split, False)
        test_setting = create_setting(f'test@{train_tgt_lang}', train_tgt_lang,
                                      Split('test'), False)
        settings.extend([train_setting, train_e_setting, dev_setting, test_setting])
    for setting in settings:
        self.dl_reg.register_data_loader(setting, self.cog_reg, lang2id=lang2id)
    # Optional phonological-feature inputs for the model.
    phono_feat_mat = special_ids = None
    if g.use_phono_features:
        phono_feat_mat = get_tensor(self.src_abc.pfm)
        special_ids = get_tensor(self.src_abc.special_ids)
    # +1 accounts for the held-out test target language on top of the training ones.
    self.model = OneToManyModel(len(self.src_abc), len(self.tgt_abc),
                                len(g.train_tgt_langs) + 1, lang2id[g.tgt_lang],
                                lang2id=lang2id,
                                phono_feat_mat=phono_feat_mat,
                                special_ids=special_ids)
    if g.saved_model_path is not None:
        self.model.load_state_dict(
            torch.load(g.saved_model_path, map_location=torch.device('cpu')))
        logging.imp(f'Loaded from {g.saved_model_path}.')
    if has_gpus():
        self.model.cuda()
    logging.info(self.model)
    metric_writer = MetricWriter(g.log_dir, flush_secs=5)
    # NOTE(j_luo) Evaluate on every loader that is not for training.
    eval_dls = self.dl_reg.get_loaders_by_name(
        lambda name: 'train' not in name or '_e' in name)
    self.evaluator = Evaluator(self.model, eval_dls, self.tgt_abc,
                               metric_writer=metric_writer)
    if not g.evaluate_only:
        train_names = [f'train@{train_tgt_lang}'
                       for train_tgt_lang in g.train_tgt_langs]
        train_settings = [self.dl_reg.get_setting_by_name(name)
                          for name in train_names]
        self.trainer = Trainer(self.model, train_settings,
                               [1.0] * len(train_settings), 'step',
                               stage_tnames=['step'],
                               evaluator=self.evaluator,
                               check_interval=g.check_interval,
                               eval_interval=g.eval_interval,
                               save_interval=g.save_interval,
                               metric_writer=metric_writer)
        if g.saved_model_path is None:
            # self.trainer.init_params('uniform', -0.1, 0.1)
            self.trainer.init_params('xavier_uniform')
        optim_cls = Adam if g.optim_cls == 'adam' else SGD
        self.trainer.set_optimizer(optim_cls, lr=g.learning_rate)
def compute_edit_dist(comp_mode: str,
                      pred_ids: Optional[LT] = None,
                      lengths: Optional[LT] = None,
                      gold_ids: Optional[LT] = None,
                      forms: Optional[np.ndarray] = None,
                      predictions=None,
                      pred_lengths: Optional[np.ndarray] = None,
                      units: Optional[np.ndarray] = None,
                      pfm=None) -> FT:
    """Compute edit distances between beam predictions and every gold target.

    Two regimes, keyed by `comp_mode`:
      * 'ids_gpu': batched DP (`EditDist`) on GPU over id tensors, optionally
        with a phonological-feature substitution penalty derived from `pfm`.
      * 'units' / 'ids' / 'str': CPU `edit_dist_all` over unit lists, id
        prefixes, or surface forms respectively.

    Returns:
        Tensor of distances with shape (pred_batch, beam_size, num_golds).
    """
    if comp_mode == 'ids_gpu':
        # Prepare tensorx.
        pred_lengths = Tx(get_tensor(pred_lengths), ['pred_batch', 'beam'])
        pred_tokens = Tx(pred_ids, ['pred_batch', 'beam', 'l'])
        # NOTE(j_luo) -1 for removing EOT's.
        tgt_lengths = Tx(lengths, ['tgt_batch']) - 1
        tgt_tokens = Tx(gold_ids, ['tgt_batch', 'l'])
        # Align them to the same names.
        new_names = ['pred_batch', 'beam', 'tgt_batch']
        pred_lengths = pred_lengths.align_to(*new_names)
        pred_tokens = pred_tokens.align_to(*(new_names + ['l']))
        tgt_lengths = tgt_lengths.align_to(*new_names)
        tgt_tokens = tgt_tokens.align_to(*(new_names + ['l']))
        # Expand them to have the same size.
        pred_bs = pred_tokens.size('pred_batch')
        tgt_bs = tgt_tokens.size('tgt_batch')
        pred_lengths = pred_lengths.expand({'tgt_batch': tgt_bs})
        pred_tokens = pred_tokens.expand({'tgt_batch': tgt_bs})
        tgt_tokens = tgt_tokens.expand({'pred_batch': pred_bs, 'beam': g.beam_size})
        tgt_lengths = tgt_lengths.expand({'pred_batch': pred_bs, 'beam': g.beam_size})

        # Flatten names, preparing for DP.
        def flatten(tx):
            return tx.flatten(['pred_batch', 'beam', 'tgt_batch'], 'batch')

        pred_lengths = flatten(pred_lengths)
        pred_tokens = flatten(pred_tokens)
        tgt_lengths = flatten(tgt_lengths)
        tgt_tokens = flatten(tgt_tokens)
        penalty = None
        if g.use_phono_edit_dist:
            # Substitution penalty = (# differing phonological features) / scale.
            x = pfm.rename('src_unit', 'phono_feat')
            y = pfm.rename('tgt_unit', 'phono_feat')
            names = ('src_unit', 'tgt_unit', 'phono_feat')
            diff = x.align_to(*names) - y.align_to(*names)
            penalty = (diff != 0).sum('phono_feat').cuda().float() / g.phono_edit_dist_scale
        dp = EditDist(pred_tokens, tgt_tokens, pred_lengths, tgt_lengths, penalty=penalty)
        dp.run()
        dists = dp.get_results().data
        dists = dists.view(pred_bs, g.beam_size, tgt_bs)
    else:
        eval_all = lambda seqs_0, seqs_1: edit_dist_all(seqs_0, seqs_1, mode='ed')
        flat_preds = predictions.reshape(-1)
        if comp_mode == 'units':
            # NOTE(j_luo) Remove EOT's.
            flat_golds = [u[:-1] for u in units]
        elif comp_mode == 'ids':
            tgt_ids = gold_ids.cpu().numpy()
            # NOTE(j_luo) -1 for EOT's.
            tgt_lengths = lengths - 1
            flat_golds = [ids[:l] for ids, l in zip(tgt_ids, tgt_lengths)]
        elif comp_mode == 'str':
            flat_golds = forms
        dists = get_tensor(eval_all(flat_preds, flat_golds)).view(-1, g.beam_size,
                                                                  len(flat_golds))
    return dists
def compute_expected_edits(known_charset, log_probs, wordlist, valid_log_probs,
                           num_samples=10, alpha=1e1, edit=False):
    """Compute (expected) edit costs between decoded tokens and a word list.

    With `num_samples > 0`, samples character sequences from the sharpened
    distribution (`alpha`) and scores their edit distances to each word in
    `wordlist`, processed in chunks of 1000. With `num_samples == 0`, uses the
    argmax word under `valid_log_probs` instead.  With `edit=False`, returns
    plain negative word log-probs rather than edit distances.

    Returns:
        Tensor of shape (bs, len(wordlist)).
    """
    logging.debug('Computing expected edits')
    log_probs = log_probs.transpose(0, 2).transpose(1, 2)  # size: bs x tl x C
    # Sharpen the distribution before sampling.
    log_probs = torch.log_softmax(log_probs * alpha, dim=-1)
    probs = log_probs.exp()
    bs, tl, nc = probs.shape
    # get samples
    if num_samples > 0:
        samples = torch.multinomial(probs.reshape(bs * tl, nc), num_samples,
                                    replacement=True)
        samples = samples.view(bs, tl, num_samples)
        # get tokens
        tokens = known_charset.get_tokens(samples.transpose(1, 2))  # size: bs x num_samples
        # get probs
        sample_log_probs = log_probs[torch.arange(bs).long().view(-1, 1, 1),
                                     torch.arange(tl).long().view(1, -1, 1),
                                     samples]  # bs x tl x ns
        # NOTE(review): `+ 1` presumably accounts for an end-of-token symbol; confirm.
        lengths = get_tensor(np.vectorize(len)(tokens) + 1, dtype='f')  # bs x num_samples
        # Mask out positions beyond each sampled token's length before summing.
        mask = get_tensor(torch.arange(tl)).float().view(1, -1, 1).expand(
            bs, tl, num_samples) < lengths.unsqueeze(dim=1)
        sample_log_probs = (mask.float() * sample_log_probs).sum(dim=1)  # bs x num_samples
    else:
        # This means we are taking the argmax according to token-level probs,
        # not character-level probs.
        # Take argmax
        _, idx = valid_log_probs.max(dim=-1)
        tokens = wordlist[idx.cpu().numpy()].reshape(bs, 1)
        num_samples = 1
        sample_log_probs = get_tensor(np.ones([bs, 1]))
    # use chunks to get all edits
    chunk_size = 1000
    num_chunks = len(wordlist) // chunk_size + (len(wordlist) % chunk_size > 0)
    expected_edits = list()
    for i in range(num_chunks):
        logging.debug('Computing chunk %d/%d' % (i + 1, num_chunks))
        start = i * chunk_size
        end = min(start + chunk_size, len(wordlist))
        valid_log_prob_chunk = valid_log_probs[:, start:end]
        if edit:
            # get dists
            dists = compute_dists(tokens, wordlist[start:end])  # bs x c_s x (1 + ns)
            dists = get_tensor(dists, 'f')
            # remove accidental hits
            duplicates = compute_duplicates(tokens, wordlist[start:end])  # bs x c_s x (1 + ns)
            duplicates = get_tensor(duplicates, 'f')
            edit_chunk = dists * duplicates
            # compute expected edits
            ex_sample_log_probs = sample_log_probs.view(bs, 1, num_samples).expand(
                -1, valid_log_prob_chunk.shape[-1], -1)
            all_sample_log_probs = torch.cat(
                [valid_log_prob_chunk.unsqueeze(dim=-1), ex_sample_log_probs], dim=-1)
            # make it less sharp
            # Duplicates are effectively removed by pushing their logits to -inf.
            all_sample_log_probs = all_sample_log_probs + (1.0 - duplicates) * (-999.)
            logits = all_sample_log_probs  # * alpha
            sm_log_probs = torch.log_softmax(logits, dim=-1)  # NOTE sm stands for softmax
            sm_probs = sm_log_probs.exp()
            expected_edits.append((edit_chunk * sm_probs).sum(dim=-1))
            # expected_edits.append(dists[..., 1:].sum(dim=-1))
        else:
            expected_edits.append(-valid_log_prob_chunk.tensor)
    return torch.cat(expected_edits, dim=1)
def sample_gumbel(shape: Sequence[int], eps: float = 1e-20) -> FT:
    """Draw Gumbel(0, 1) noise of the given shape via the inverse-CDF trick.

    `eps` guards both log calls against log(0).
    """
    uniform = get_tensor(torch.rand(shape))
    inner = (uniform + eps).log()
    return -(-inner + eps).log()
def tensor(self) -> LT:
    """Return the state as a long tensor with named dims ('word', 'pos')."""
    arr = self.vocab_array
    return get_tensor(arr).rename('word', 'pos')