import logging
import os

import torch
import torchtext

# Project-specific helpers (AnalysableSeq2seq, SourceField, TargetField,
# AttentionField, SupervisedTrainer, run_model_on_test_data) are assumed to be
# importable from the surrounding codebase.


def run_and_get_hidden_activations(checkpoint_path, test_data_path, attention_method, use_attention_loss,
                                   ignore_output_eos, max_len=50, save_path=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

    IGNORE_INDEX = -1
    output_eos_used = not ignore_output_eos

    # load model
    logging.info("loading checkpoint from {}".format(os.path.join(checkpoint_path)))
    checkpoint = AnalysableSeq2seq.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

    # Prepare dataset and loss
    src = SourceField()
    tgt = TargetField(output_eos_used)

    tabular_data_fields = [('src', src), ('tgt', tgt)]

    if use_attention_loss or attention_method == 'hard':
        attn = AttentionField(use_vocab=False, ignore_index=IGNORE_INDEX)
        tabular_data_fields.append(('attn', attn))

    # Reuse the vocabularies stored in the checkpoint instead of rebuilding them
    # from data, so token ids match those the model was trained with.
    src.vocab = input_vocab
    tgt.vocab = output_vocab
    tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
    tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate test set
    test = torchtext.data.TabularDataset(
        path=test_data_path, format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter
    )

    # When attentive guidance is used, check that the first example in the data
    # set is well-formed; we assume the remaining examples are then also correct.
    if use_attention_loss or attention_method == 'hard':
        if len(test) > 0:
            if 'attn' not in vars(test[0]):
                raise Exception("AttentionField not found in test data")
            tgt_len = len(vars(test[0])['tgt']) - 1 # -1 for SOS
            attn_len = len(vars(test[0])['attn']) - 1 # -1 for prepended ignore_index
            if attn_len != tgt_len:
                raise Exception("Length of output sequence does not equal length of attention sequence in test data.")

    data_func = SupervisedTrainer.get_batch_data

    # Run the model over the filtered test set and collect its hidden activations.
    activations_dataset = run_model_on_test_data(model=seq2seq, data=test, get_batch_data=data_func)

    if save_path is not None:
        activations_dataset.save(save_path)
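
# A minimal usage sketch for the function above. All paths and the attention
# method value are hypothetical placeholders, not taken from the source.
if __name__ == '__main__':
    run_and_get_hidden_activations(
        checkpoint_path='checkpoints/example_model',   # hypothetical path
        test_data_path='data/test.tsv',                # hypothetical path
        attention_method='mlp',                        # assumed method name
        use_attention_loss=False,
        ignore_output_eos=False,
        max_len=50,
        save_path='activations/test_run',              # hypothetical path
    )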
Example #2
# This fragment assumes earlier imports (torch, torchtext, and the project's
# Checkpoint, SourceField and TargetField) and that 'opt' holds parsed
# command-line arguments.
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

############################################################################
# Prepare dataset and loss
src = SourceField()
# 'output_eos_used' is not defined in this fragment; mirroring Example #1, it is
# assumed to come from a command-line flag:
output_eos_used = not opt.ignore_output_eos
tgt = TargetField(output_eos_used)

tabular_data_fields = [('src', src), ('tgt', tgt)]

src.vocab = input_vocab
tgt.vocab = output_vocab
tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]
max_len = opt.max_len


def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len


# generate test set
test = torchtext.data.TabularDataset(
    path=opt.test_data, format='tsv',
    fields=tabular_data_fields,
    filter_pred=len_filter
)

# Prepare loss and metrics
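
# The fragment is truncated here. As a minimal, assumed sketch (plain PyTorch
# rather than the project's own loss classes), a padding-aware loss could be
# prepared as follows:
weight = torch.ones(len(output_vocab))
pad_id = output_vocab.stoi[tgt.pad_token]
loss = torch.nn.NLLLoss(weight=weight, ignore_index=pad_id)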