def record2transition_dict(record: Dict[str, np.ndarray],
                           token_dictionary: Dictionary,
                           expansion_dictionary: Dictionary,
                           null_expansion: str) -> Transition:
    """Decode one binarized record back into a symbolic ``Transition``.

    The id arrays stored under ``KEY_PREV_LEVEL_TOKENS``,
    ``KEY_NEXT_LEVEL_TOKENS`` and ``KEY_NEXT_LEVEL_EXPANS`` are turned back
    into symbol lists with the matching dictionary (``Dictionary.string``
    output split on single spaces).

    A loss mask is derived from the decoded next-level tokens: 0 where the
    token equals ``token_dictionary.pad_word``, 1 everywhere else. That mask
    is handed to ``unmask_tokens`` (together with the previous-level tokens)
    and to ``unmask_expansions`` (together with ``null_expansion``). What
    those helpers substitute is defined elsewhere.

    Args:
        record: mapping from ``KEY_*`` field names to id arrays; each value
            must support ``tolist()``.
        token_dictionary: vocabulary used to decode token ids.
        expansion_dictionary: vocabulary used to decode expansion ids.
        null_expansion: forwarded unchanged to ``unmask_expansions``.

    Returns:
        A ``Transition`` holding the previous-level tokens, the loss mask,
        the unmasked next-level tokens and expansions, and the head
        positions (``record[KEY_HEAD_POSITIONS]`` as a plain list).
    """
    def decode(dictionary, key):
        # Dictionary.string joins symbols with single spaces; undo that.
        return dictionary.string(record[key].tolist()).split(' ')

    prev_tokens = decode(token_dictionary, KEY_PREV_LEVEL_TOKENS)
    raw_next_tokens = decode(token_dictionary, KEY_NEXT_LEVEL_TOKENS)

    pad_symbol = token_dictionary.pad_word
    # 0 marks a padded (masked) slot, 1 a real token.
    loss_mask = [int(tok != pad_symbol) for tok in raw_next_tokens]

    next_tokens = unmask_tokens(raw_next_tokens, prev_tokens, loss_mask)
    next_expans = unmask_expansions(
        decode(expansion_dictionary, KEY_NEXT_LEVEL_EXPANS),
        loss_mask,
        null_expansion,
    )

    return Transition(
        previous_level_tokens=prev_tokens,
        loss_mask=loss_mask,
        next_level_tokens=next_tokens,
        next_level_expansions=next_expans,
        heads=record[KEY_HEAD_POSITIONS].tolist(),
    )
def format_ascii(iteration: Iteration,
                 iteration_number: int,
                 token_dictionary: Dictionary,
                 expansion_dictionary: Dictionary,
                 expansion: ExpansionStrategy,
                 no_token: str = '-') -> str:
    """Render one decoding iteration as a four-line plain-text summary.

    The output looks like::

        iteration <iteration_number>
         PLT: <previous-level tokens>
         NLT: <next-level tokens>
         NLE: <next-level expansions>

    The lines are joined by newlines, with no trailing newline.
    ``iteration_number`` is printed as given (no offset). On the NLT and NLE
    lines, any position whose ``iteration.new_token_mask`` entry is falsy is
    replaced by ``no_token``. The PLT line is the raw decoded string of
    ``iteration.plt`` and is not masked.

    Args:
        iteration: object exposing ``plt``, ``nlt``, ``nle`` id sequences and
            a ``new_token_mask`` aligned with ``nlt`` / ``nle``.
        iteration_number: value shown in the header line.
        token_dictionary: vocabulary used to decode ``plt`` and ``nlt``.
        expansion_dictionary: vocabulary used to decode ``nle``.
        expansion: strategy whose ``pretty_format`` renders each expansion.
        no_token: placeholder shown for positions that are not new.

    Returns:
        The formatted multi-line string.
    """
    tokens = token_dictionary.string(iteration.nlt).split(' ')
    tokens = [
        tok if is_new else no_token
        for tok, is_new in zip(tokens, iteration.new_token_mask)
    ]

    expansions = [
        expansion.pretty_format(exp)
        for exp in expansion_dictionary.string(iteration.nle).split(' ')
    ]
    expansions = [
        exp if is_new else no_token
        for exp, is_new in zip(expansions, iteration.new_token_mask)
    ]

    # Plain literals instead of placeholder-free f-strings (ruff F541); the
    # lines are joined once rather than grown with repeated +=.
    lines = [
        f'iteration {iteration_number}',
        ' PLT: ' + token_dictionary.string(iteration.plt),
        ' NLT: ' + ' '.join(tokens),
        ' NLE: ' + ' '.join(expansions),
    ]
    return '\n'.join(lines)
def format_latex(iteration: Iteration,
                 iteration_number: int,
                 token_dictionary: Dictionary,
                 expansion_dictionary: Dictionary,
                 expansion: ExpansionStrategy,
                 no_token: str = '-') -> str:
    """Render one decoding iteration as a LaTeX ``tabularx`` table.

    Column layout:
        * one ``p{6mm}`` column for the row label;
        * one centred ``c`` column per decoded next-level token.

    Rows, in order:
        * a header spanning ``len(nlt)`` columns that reads
          ``Iteration <iteration_number + 1>``, so the header is one-based;
        * ``\\hline``;
        * the PLT, NLT and NLE rows.

    Every cell passes through ``maybe_tt``. In the NLT and NLE rows, any
    position whose ``iteration.new_token_mask`` entry is falsy shows
    ``no_token`` instead.

    Args:
        iteration: object exposing ``plt``, ``nlt``, ``nle`` id sequences and
            a ``new_token_mask`` aligned with ``nlt`` / ``nle``.
        iteration_number: iteration index; the header shows it plus one.
        token_dictionary: vocabulary used to decode ``plt`` and ``nlt``.
        expansion_dictionary: vocabulary used to decode ``nle``.
        expansion: strategy whose ``pretty_format`` renders each expansion.
        no_token: placeholder shown for positions that are not new.

    Returns:
        The LaTeX source of the table, without a trailing newline.
    """
    def hide_old(items):
        # Keep only entries flagged as new in this iteration.
        return [
            item if is_new else no_token
            for item, is_new in zip(items, iteration.new_token_mask)
        ]

    def table_row(label, cells):
        # "<label> & c1 & c2 ...\\" followed by a newline.
        return label + ' & ' + ' & '.join(maybe_tt(c) for c in cells) + '\\\\\n'

    prev_tokens = token_dictionary.string(iteration.plt).split(' ')
    next_tokens = hide_old(token_dictionary.string(iteration.nlt).split(' '))
    next_expans = hide_old([
        expansion.pretty_format(exp)
        for exp in expansion_dictionary.string(iteration.nle).split(' ')
    ])

    width = len(next_tokens)
    column_spec = ' '.join('c' for _ in range(width))

    parts = [
        '\\begin{tabularx}{\\linewidth}{p{6mm} ' + column_spec + '}\n',
        '\\multicolumn{' + str(width) + '}{l}{Iteration '
        + str(iteration_number + 1) + '}\\\\\n',
        '\\hline\n',
        table_row('PLT:', prev_tokens),
        table_row('NLT:', next_tokens),
        table_row('NLE:', next_expans),
        '\\end{tabularx}',
    ]
    return ''.join(parts)