import logging
import os

import torch
import torchtext

# Project-local imports (assumed module paths; adjust to your checkout). The
# field and trainer classes follow the i-machine-think `machine` library
# layout; AnalysableSeq2seq and run_model_on_test_data are defined elsewhere
# in the same project.
# from machine.dataset import SourceField, TargetField, AttentionField
# from machine.trainer import SupervisedTrainer


def run_and_get_hidden_activations(checkpoint_path, test_data_path, attention_method, use_attention_loss,
                                   ignore_output_eos, max_len=50, save_path=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

    IGNORE_INDEX = -1
    output_eos_used = not ignore_output_eos

    # load model
    logging.info("loading checkpoint from {}".format(os.path.join(checkpoint_path)))
    checkpoint = AnalysableSeq2seq.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

    # Prepare dataset and loss
    src = SourceField()
    tgt = TargetField(output_eos_used)

    tabular_data_fields = [('src', src), ('tgt', tgt)]

    if use_attention_loss or attention_method == 'hard':
        attn = AttentionField(use_vocab=False, ignore_index=IGNORE_INDEX)
        tabular_data_fields.append(('attn', attn))

    src.vocab = input_vocab
    tgt.vocab = output_vocab
    tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
    tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate test set
    test = torchtext.data.TabularDataset(
        path=test_data_path, format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter
    )
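
    # Each TSV line holds one example: a source column, a target column and,
    # when attentive guidance is used, a third column of attention indices
    # aligned one-to-one with the target tokens (validated below).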

    # When attentive guidance is used, check that the data is well-formed for
    # the first example in the data set; we assume the remaining examples
    # follow the same format.
    if use_attention_loss or attention_method == 'hard':
        if len(test) > 0:
            if 'attn' not in vars(test[0]):
                raise Exception("AttentionField not found in test data")
            tgt_len = len(vars(test[0])['tgt']) - 1 # -1 for SOS
            attn_len = len(vars(test[0])['attn']) - 1 # -1 for prepended ignore_index
            if attn_len != tgt_len:
                raise Exception("Length of output sequence does not equal length of attention sequence in test data.")

    data_func = SupervisedTrainer.get_batch_data

    activations_dataset = run_model_on_test_data(model=seq2seq, data=test, get_batch_data=data_func)

    if save_path is not None:
        activations_dataset.save(save_path)
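
# A minimal usage sketch. The paths, the 'mlp' attention method, and the
# save-file name are all hypothetical; only the function signature above is
# taken from the example itself.
if __name__ == '__main__':
    run_and_get_hidden_activations(
        checkpoint_path='checkpoints/best_model',
        test_data_path='data/test.tsv',
        attention_method='mlp',
        use_attention_loss=False,
        ignore_output_eos=False,
        save_path='activations.pt')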
Example #2
#################################################################################
# load model

logging.info("loading checkpoint from {}".format(os.path.join(opt.checkpoint_path)))
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

############################################################################
# Prepare dataset and loss
src = SourceField()
tgt = TargetField()
src.vocab = input_vocab
tgt.vocab = output_vocab
max_len = opt.max_len

def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len

# generate test set
test = torchtext.data.TabularDataset(
    path=opt.test_data, format='tsv',
    fields=[('src', src), ('tgt', tgt)],
    filter_pred=len_filter
)

# Prepare loss
weight = torch.ones(len(output_vocab))
pad = output_vocab.stoi[tgt.pad_token]
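
# A plausible completion of the truncated loss setup above, assuming the
# Perplexity loss that ships with IBM's pytorch-seq2seq (seq2seq.loss).
from seq2seq.loss import Perplexity

loss = Perplexity(weight, pad)
if torch.cuda.is_available():
    loss.cuda()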

#################################################################################
# Resume from a checkpoint or build the vocabularies from the training data.
# The enclosing `if opt.resume:` guard is assumed: the snippet began mid-block,
# and the `else:` branch below pairs with it.
if opt.resume:
    if opt.load_checkpoint is None:
        raise Exception(
            'load_checkpoint must be specified when --resume is specified')
    else:
        checkpoint_path = os.path.join(opt.expt_dir,
                                       Checkpoint.CHECKPOINT_DIR_NAME,
                                       opt.load_checkpoint)
        logging.info("loading checkpoint from {}".format(checkpoint_path))
        checkpoint = Checkpoint.load(checkpoint_path)
        seq2seq = checkpoint.model
        src.vocab = checkpoint.input_vocab
        tgt.vocab = checkpoint.output_vocab
else:
    src.build_vocab(train,
                    max_size=params['src_vocab_size'],
                    specials=replace_tokens)
    tgt.build_vocab(train, max_size=params['tgt_vocab_size'])

logging.info('Indices of special replace tokens:\n')
for rep in replace_tokens:
    logging.info("%s, %d; " % (rep, src.vocab.stoi[rep]))
logging.info('\n')

# Prepare loss
weight = torch.ones(len(tgt.vocab))
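
# A plausible completion (assumed, mirroring the loss setup in the previous
# example): derive the padding index from the target vocabulary and build the
# same Perplexity loss around the class weights.
pad = tgt.vocab.stoi[tgt.pad_token]
loss = Perplexity(weight, pad)
if torch.cuda.is_available():
    loss.cuda()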

#################################################################################
# Load the train/dev datasets. The opening of the `train` call is
# reconstructed: `opt.train_path` and the `beh` field (an extra behaviour
# column defined earlier) are assumed by analogy with the `dev` call below.
    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format='tsv',
                                          fields=[('src', src), ('tgt', tgt),
                                                  ('beh', beh)],
                                          filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='tsv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=len_filter)

    if not os.path.exists(opt.ckpt_dir):
        os.makedirs(opt.ckpt_dir)

    if opt.resume:
        latest_checkpoint_path = Checkpoint.get_latest_checkpoint(opt.ckpt_dir)
        resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
        src.vocab = resume_checkpoint.input_vocab
        tgt.vocab = resume_checkpoint.output_vocab
    else:
        print('Building vocab')
        if hidden_size == 300:
            vectors = 'glove.42B.300d'
        elif hidden_size == 100:
            vectors = 'glove.6B.100d'
        else:
            vectors = None

        tgt.build_vocab(train, max_size=vocab_size, vectors=vectors)
        src.vocab = tgt.vocab
        input_vocab = src.vocab
        output_vocab = tgt.vocab
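
# Design note: assigning `src.vocab = tgt.vocab` gives encoder and decoder one
# shared vocabulary, so both sides index the same pretrained GloVe embedding
# table; `input_vocab` and `output_vocab` then simply alias that shared vocab.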