コード例 #1
0
def record2transition_dict(record: Dict[str, np.ndarray],
                           token_dictionary: Dictionary,
                           expansion_dictionary: Dictionary,
                           null_expansion: str) -> Transition:
    """Decode one binarized record into a ``Transition``.

    The arrays stored under ``KEY_PREV_LEVEL_TOKENS``,
    ``KEY_NEXT_LEVEL_TOKENS`` and ``KEY_NEXT_LEVEL_EXPANS`` are converted to
    lists and decoded with ``Dictionary.string``. Each decoded string is then
    split on single spaces.

    Each next-level position whose decoded token equals
    ``token_dictionary.pad_word`` gets a 0 in the loss mask; every other
    position gets a 1. That mask is passed to ``unmask_tokens`` (together
    with the previous-level tokens) and to ``unmask_expansions`` (together
    with ``null_expansion``).

    Args:
        record: mapping from this module's ``KEY_*`` constants to arrays
            exposing ``tolist()``.
        token_dictionary: dictionary used to decode token ids.
        expansion_dictionary: dictionary used to decode expansion ids.
        null_expansion: value forwarded to ``unmask_expansions``.

    Returns:
        A ``Transition`` built from the decoded previous-level tokens, the
        loss mask, the unmasked next-level tokens and expansions, and the
        head positions as a plain list.
    """
    def decode(dictionary, key):
        # ids -> space-separated symbol string -> list of symbols
        return dictionary.string(record[key].tolist()).split(' ')

    previous = decode(token_dictionary, KEY_PREV_LEVEL_TOKENS)
    masked_next = decode(token_dictionary, KEY_NEXT_LEVEL_TOKENS)

    # Padding positions are excluded from the loss (mask value 0).
    pad = token_dictionary.pad_word
    mask = [0 if tok == pad else 1 for tok in masked_next]

    restored_tokens = unmask_tokens(masked_next, previous, mask)

    restored_expansions = unmask_expansions(
        decode(expansion_dictionary, KEY_NEXT_LEVEL_EXPANS),
        mask,
        null_expansion,
    )

    return Transition(
        previous_level_tokens=previous,
        loss_mask=mask,
        next_level_tokens=restored_tokens,
        next_level_expansions=restored_expansions,
        heads=record[KEY_HEAD_POSITIONS].tolist(),
    )
コード例 #2
0
def format_ascii(iteration: Iteration,
                 iteration_number: int,
                 token_dictionary: Dictionary,
                 expansion_dictionary: Dictionary,
                 expansion: ExpansionStrategy,
                 no_token: str = '-') -> str:
    """Render one decoding iteration as four newline-separated text lines.

    Layout (no trailing newline)::

        iteration <iteration_number>
          PLT: <decoded iteration.plt>
          NLT: <decoded iteration.nlt, old positions replaced by no_token>
          NLE: <formatted iteration.nle, old positions replaced by no_token>

    A position counts as "new" when its entry in
    ``iteration.new_token_mask`` is truthy. Every NLT/NLE position that is
    not new is shown as ``no_token``. Each expansion symbol is rendered with
    ``expansion.pretty_format``.

    Args:
        iteration: object providing ``plt``, ``nlt``, ``nle`` and
            ``new_token_mask``.
        iteration_number: value printed verbatim in the header line.
        token_dictionary: decodes ``iteration.plt`` and ``iteration.nlt``.
        expansion_dictionary: decodes ``iteration.nle``.
        expansion: strategy whose ``pretty_format`` renders each expansion.
        no_token: placeholder shown for positions that are not new.

    Returns:
        The formatted multi-line string.
    """
    tokens = token_dictionary.string(iteration.nlt).split(' ')
    tokens = [
        t if is_new_token else no_token
        for t, is_new_token in zip(tokens, iteration.new_token_mask)
    ]

    expansions = [
        expansion.pretty_format(e)
        for e in expansion_dictionary.string(iteration.nle).split(' ')
    ]
    expansions = [
        e if is_new_token else no_token
        for e, is_new_token in zip(expansions, iteration.new_token_mask)
    ]

    # NOTE(review): format_latex prints ``iteration_number + 1`` in its
    # header, while this prints the value unchanged. Confirm which numbering
    # convention callers expect.
    return '\n'.join([
        f'iteration {iteration_number}',
        '  PLT: ' + token_dictionary.string(iteration.plt),
        '  NLT: ' + ' '.join(tokens),
        '  NLE: ' + ' '.join(expansions),
    ])
コード例 #3
0
def format_latex(iteration: Iteration,
                 iteration_number: int,
                 token_dictionary: Dictionary,
                 expansion_dictionary: Dictionary,
                 expansion: ExpansionStrategy,
                 no_token: str = '-') -> str:
    """Render one decoding iteration as a LaTeX ``tabularx`` environment.

    The column spec is one ``p{6mm}`` label column followed by one ``c``
    column per decoded next-level token. The body contains, in order:

    * a multicolumn header reading ``Iteration <iteration_number + 1>``;
    * a horizontal rule;
    * the PLT, NLT and NLE rows, each starting with its label cell.

    In the NLT and NLE rows, positions whose ``iteration.new_token_mask``
    entry is falsy are replaced by ``no_token``. Expansions are rendered with
    ``expansion.pretty_format``. Every cell passes through ``maybe_tt``. The
    including document must load the ``tabularx`` package.

    NOTE(review): the column count comes from the NLT row only. If the PLT
    row can be longer, LaTeX will report extra alignment tabs; presumably it
    never is, but confirm.

    Args:
        iteration: object providing ``plt``, ``nlt``, ``nle`` and
            ``new_token_mask``.
        iteration_number: number shown (plus one) in the header row.
        token_dictionary: decodes ``iteration.plt`` and ``iteration.nlt``.
        expansion_dictionary: decodes ``iteration.nle``.
        expansion: strategy whose ``pretty_format`` renders each expansion.
        no_token: placeholder shown for positions that are not new.

    Returns:
        The LaTeX source, without a trailing newline.
    """
    def hide_old(items):
        # Keep entries flagged as new; substitute no_token for the rest.
        return [item if fresh else no_token
                for item, fresh in zip(items, iteration.new_token_mask)]

    def row(label, cells):
        # One table row: label cell, '&'-separated cells, LaTeX line break.
        return label + ' & '.join(maybe_tt(cell) for cell in cells) + '\\\\'

    plt = token_dictionary.string(iteration.plt).split(' ')
    nlt = hide_old(token_dictionary.string(iteration.nlt).split(' '))
    nle = hide_old([
        expansion.pretty_format(symbol)
        for symbol in expansion_dictionary.string(iteration.nle).split(' ')
    ])

    width = len(nlt)

    lines = [
        # ' '.join('c' * k) yields 'c c ... c' with k entries.
        '\\begin{tabularx}{\\linewidth}{p{6mm} ' + ' '.join('c' * width) + '}',
        '\\multicolumn{' + str(width) + '}{l}{Iteration '
        + str(iteration_number + 1) + '}\\\\',
        '\\hline',
        row('PLT: & ', plt),
        row('NLT: & ', nlt),
        row('NLE: & ', nle),
        '\\end{tabularx}',
    ]
    return '\n'.join(lines)