def run_and_get_hidden_activations(checkpoint_path, test_data_path, attention_method, use_attention_loss,
                                   ignore_output_eos, max_len=50, save_path=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

    IGNORE_INDEX = -1
    output_eos_used = not ignore_output_eos

    # load model
    logging.info("loading checkpoint from {}".format(os.path.join(checkpoint_path)))
    checkpoint = AnalysableSeq2seq.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

    # Prepare dataset and loss
    src = SourceField()
    tgt = TargetField(output_eos_used)

    tabular_data_fields = [('src', src), ('tgt', tgt)]

    if use_attention_loss or attention_method == 'hard':
        attn = AttentionField(use_vocab=False, ignore_index=IGNORE_INDEX)
        tabular_data_fields.append(('attn', attn))

    src.vocab = input_vocab
    tgt.vocab = output_vocab
    tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
    tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate test set
    test = torchtext.data.TabularDataset(
        path=test_data_path, format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter
    )

    # When attentive guidance is used, check whether the attention data is well-formed
    # for the first example in the data set; we assume the remaining examples are then also correct.
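    # A rough, hypothetical illustration of a matching tab-separated test line
    # (tokens are made up): "t1 t2 t3<TAB>o1 o2 o3<TAB>0 1 2", i.e. the attn
    # column provides one alignment index per output position.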
    if use_attention_loss or attention_method == 'hard':
        if len(test) > 0:
            if 'attn' not in vars(test[0]):
                raise Exception("AttentionField not found in test data")
            tgt_len = len(vars(test[0])['tgt']) - 1  # -1 for SOS
            attn_len = len(vars(test[0])['attn']) - 1  # -1 for prepended ignore_index
            if attn_len != tgt_len:
                raise Exception("Length of output sequence does not equal length of attention sequence in test data.")

    data_func = SupervisedTrainer.get_batch_data

    activations_dataset = run_model_on_test_data(model=seq2seq, data=test, get_batch_data=data_func)

    if save_path is not None:
        activations_dataset.save(save_path)
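
# A hypothetical invocation sketch (paths and flag values below are placeholders,
# not taken from the original source):
#
#   run_and_get_hidden_activations(
#       checkpoint_path='checkpoints/best_model',
#       test_data_path='data/test.tsv',
#       attention_method='hard',      # 'hard' (or use_attention_loss=True) requires an attn column
#       use_attention_loss=False,
#       ignore_output_eos=False,
#       save_path='activations_dataset')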


# Example 2
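
# NOTE: `revers_vocab` is used by the examples below but defined elsewhere in the
# original source; a minimal sketch of what it presumably does (mapping token
# indices back to strings via vocab.itos) is:
def revers_vocab(vocab, indices, sep):
    # hypothetical reconstruction: join the itos entries of the given indices
    return sep.join(vocab.itos[int(i)] for i in indices)
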
def build_time_ds():
    """" Wrapper class of torchtext.data.Field that forces batch_first to be True
    and prepend <sos> and append <eos> to sequences in preprocessing step. """
    tokenize = lambda s: ['%', '&'] + ['BASEBALL']

    TEXT_TARGET = TargetField(batch_first=True,
                              sequential=True,
                              use_vocab=True,
                              lower=False,
                              init_token=TargetField.SYM_SOS,
                              eos_token=TargetField.SYM_EOS,
                              tokenize=tokenize,
                              preprocessing=lambda x: x)  # fix_length=10
    TEXT = SourceField(batch_first=True,
                       sequential=True,
                       use_vocab=True,
                       lower=False,
                       tokenize=tokenize,
                       preprocessing=lambda x: x)  # fix_length=10
    LABEL = data.Field(batch_first=True,
                       sequential=False,
                       use_vocab=False,
                       tensor_type=torch.FloatTensor)

    fields = [('sent_0', TEXT), ('sent_1', TEXT), ('sent_x', TEXT),
              ('is_x_0', LABEL), ('sent_0_target', TEXT_TARGET)]
    ds_train = data.Dataset(TimeStyleDataset(1e3, 1, label_smoothing=True),
                            fields)
    ds_eval = data.Dataset(TimeStyleDataset(1e3, 2), fields)

    print('printing dataset directly, before tokenizing:')
    print('sent_0', ds_train[2].sent_0)  # not processed
    print('is_x_0', ds_train[2].is_x_0)  # not processed

    print('\nbuilding vocab:')
    # TEXT.build_vocab(ds, max_size=80000)
    TEXT_TARGET.build_vocab(ds_train, max_size=80000)
    TEXT.vocab = TEXT_TARGET.vocab  # same except for the added <sos>/<eos> tokens

    print('vocab TEXT: len', len(TEXT.vocab), 'common',
          TEXT.vocab.freqs.most_common()[:50])
    print('vocab TEXT_TARGET:', len(TEXT_TARGET.vocab), 'uncommon',
          TEXT_TARGET.vocab.freqs.most_common()[-10::])
    print('vocab ', TEXT_TARGET.SYM_SOS, TEXT_TARGET.sos_id,
          TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_SOS])
    print('vocab ', TEXT_TARGET.SYM_EOS, TEXT_TARGET.eos_id,
          TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_EOS])
    print('vocab ', 'out-of-vocab',
          TEXT_TARGET.vocab.stoi['out-of-vocab'])

    device = None if torch.cuda.is_available() else -1  # legacy torchtext convention: None = current GPU, -1 = CPU
    # READ:  https://github.com/mjc92/TorchTextTutorial/blob/master/01.%20Getting%20started.ipynb
    sort_within_batch = True
    train_iter = iter(
        data.BucketIterator(dataset=ds_train,
                            device=device,
                            batch_size=32,
                            sort_within_batch=sort_within_batch,
                            sort_key=lambda x: len(x.sent_0)))
    eval_iter = iter(
        data.BucketIterator(dataset=ds_eval,
                            device=device,
                            batch_size=32,
                            sort_within_batch=sort_within_batch,
                            sort_key=lambda x: len(x.sent_0)))
    # performance note: the first next() takes ~3.5 s; subsequent calls are fast (~10,000 in 1 s)

    for i in range(1):
        b = next(train_iter)
        # usage
        print('\nb.is_x_0', b.is_x_0[0], b.is_x_0.type())
        # print ('b.src is values+len tuple',b.src[0].shape,b.src[1].shape )
        print('b.sent_0_target', b.sent_0_target.shape, b.sent_0_target[0],
              revers_vocab(TEXT_TARGET.vocab, b.sent_0_target[0], ''))
        print('b_sent0', b.sent_0[0].shape, b.sent_0[1].shape, b.sent_0[0][0],
              revers_vocab(TEXT.vocab, b.sent_0[0][0], ''))
        print('b_sent1', b.sent_1[0].shape, b.sent_1[1].shape, b.sent_1[0][0],
              revers_vocab(TEXT.vocab, b.sent_1[0][0], ''))
        print('b_sentx', b.sent_x[0].shape, b.sent_x[1].shape, b.sent_x[0][0],
              revers_vocab(TEXT.vocab, b.sent_x[0][0], ''))
        print('b_y', b.is_x_0.shape, b.is_x_0)
        print(b.sent_0[1])

    return ds_train, ds_eval, train_iter, eval_iter
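
# Usage sketch (assumes the TimeStyleDataset toy data class is importable):
#
#   ds_train, ds_eval, train_iter, eval_iter = build_time_ds()
#   batch = next(train_iter)
#   src_values, src_lengths = batch.sent_0   # SourceField batches are (values, lengths) tuples
#   target_values = batch.sent_0_target      # TargetField batches are plain value tensors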


# Example 3
def build_bible_datasets(verbose=False):
    """
    :return: bucket_iter_train, bucket_iter_valid
     To get an epoch-iterator , do iter= iter(bucket_iter_train). and then loop on next(iter)
     It easy to get dataset/fields from it , using bucket_iter_train.dataset.fields
    """
    def as_id_to_sentence(filename):
        d = {}
        with open(filename, 'r') as f:
            f.readline()  # skip the header row
            for l in csv.reader(f.readlines(),
                                quotechar='"',
                                delimiter=',',
                                quoting=csv.QUOTE_ALL,
                                skipinitialspace=True):
                # id,b,c,v,t
                # 1001001,1,1,1,At the first God made the heaven and the earth.
                d[l[0]] = l[4]
        return d

    bbe = as_id_to_sentence('t_bbe.csv')
    wbt = as_id_to_sentence('t_wbt.csv')
    print('num of sentences', len(bbe), len(wbt))

    # merge into a list with (s1,s2) tuple
    bibles = []
    for sent_id, sent_wbt in wbt.items():
        if sent_id in bbe:
            sent_bbe = bbe[sent_id]
            bibles.append((sent_wbt, sent_bbe))
    if verbose:
        print(len(bibles), bibles[0])

    tokenize = 'revtok'  # alternative: lambda x: x.split(' ')
    TEXT_TARGET = TargetField(batch_first=True,
                              sequential=True,
                              use_vocab=True,
                              lower=True,
                              init_token=TargetField.SYM_SOS,
                              eos_token=TargetField.SYM_EOS,
                              tokenize=tokenize)  # fix_length=30)
    TEXT = SourceField(batch_first=True,
                       sequential=True,
                       use_vocab=True,
                       lower=True,
                       tokenize=tokenize)  # , fix_length=20)
    LABEL = data.Field(batch_first=True,
                       sequential=False,
                       use_vocab=False,
                       tensor_type=torch.FloatTensor)

    bible_style_ds_trn = BibleStyleDS(
        [x for (i, x) in enumerate(bibles) if i % 10 != 9],
        TEXT,
        TEXT_TARGET,
        LABEL,
        label_smoothing=False)
    bible_style_ds_val = BibleStyleDS(
        [x for (i, x) in enumerate(bibles) if i % 10 == 9],
        TEXT,
        TEXT_TARGET,
        LABEL,
        label_smoothing=False)
    if verbose:
        for i in range(1):
            print("RAW SENTENCES", bible_style_ds_val[i])
        # print (type(bible_style_ds[i].sent_0),type(bible_style_ds[i].is_x_0),bible_style_ds[i])

    fields = [('sent_0', TEXT), ('sent_1', TEXT), ('sent_x', TEXT),
              ('is_x_0', LABEL), ('sent_0_target', TEXT_TARGET)]
    ds_train = data.Dataset(bible_style_ds_trn, fields)
    ds_val = data.Dataset(bible_style_ds_val, fields)

    # import pdb; pdb.set_trace()
    if verbose:
        print('printing dataset directly, before tokenizing:')
        print('sent_0', ds_train[2].sent_0)  # not processed
        print('is_x_0', ds_train[2].is_x_0)  # not processed

    print('\nbuilding vocab:')

    TEXT_TARGET.build_vocab(
        ds_train, vectors='fasttext.simple.300d', min_freq=20
    )  # alternatives tried: max_size=80000; vectors='fasttext.en.300d' or 'glove.twitter.27B.50d'
    TEXT.vocab = TEXT_TARGET.vocab  # same except for the added <sos>/<eos> tokens
    print('total', len(TEXT.vocab), 'after ignoring non-frequent words')
    if verbose:
        print('vocab TEXT: len', len(TEXT.vocab), 'common',
              TEXT.vocab.freqs.most_common()[:5])
        print('vocab TEXT: len', len(TEXT.vocab), 'uncommon',
              TEXT.vocab.freqs.most_common()[-5:])
        print('vocab TEXT_TARGET:', len(TEXT_TARGET.vocab),
              TEXT_TARGET.vocab.freqs.most_common()[:5])
        print('vocab ', TEXT_TARGET.SYM_SOS, TEXT_TARGET.sos_id,
              TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_SOS])
        print('vocab ', TEXT_TARGET.SYM_EOS, TEXT_TARGET.eos_id,
              TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_EOS])
        print('vocab ', 'out-of-vocab', TEXT_TARGET.vocab.stoi['out-of-vocab'])
        print('vocab ', 'i0',
              [(i, TEXT_TARGET.vocab.itos[i]) for i in range(6)])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # READ:  https://github.com/mjc92/TorchTextTutorial/blob/master/01.%20Getting%20started.ipynb
    print('device:', device)

    bucket_iter_train = data.BucketIterator(dataset=ds_train,
                                            shuffle=True,
                                            device=device,
                                            batch_size=32,
                                            sort_within_batch=False,
                                            sort_key=lambda x: len(x.sent_0))
    bucket_iter_valid = data.BucketIterator(
        dataset=ds_val,
        shuffle=False,
        device=device,
        batch_size=32,
        sort_within_batch=False,  #sort_key=lambda x: len(x.sent_0)
    )

    if verbose:  # show a few samples
        for i in range(1):
            # performance note: the first next() takes ~3.5 s; subsequent calls are fast (~10,000 in 1 s)
            b = next(iter(bucket_iter_train))
            # usage
            print('\nb.is_x_0', b.is_x_0[0], b.is_x_0.type())
            # print ('b.src is values+len tuple',b.src[0].shape,b.src[1].shape )
            print('b.sent_0_target',
                  b.sent_0_target.shape)  # ,b.sent_0_target[0])
            print('b.sent_0_target',
                  revers_vocab(TEXT_TARGET.vocab, b.sent_0_target[0], ' '))
            print('b_sent0', b.sent_0[0].shape, b.sent_0[1].shape,
                  revers_vocab(TEXT.vocab, b.sent_0[0][0], ' '))
            print('b_sent1', b.sent_1[0].shape, b.sent_1[1].shape,
                  revers_vocab(TEXT.vocab, b.sent_1[0][0], ' '))
            print('b_sentx', b.sent_x[0].shape, b.sent_x[1].shape,
                  revers_vocab(TEXT.vocab, b.sent_x[0][0], ' '),
                  b.sent_x[0][0])
            print('b_y', b.is_x_0.shape, b.is_x_0[0])

    return bucket_iter_train, bucket_iter_valid
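
# Usage sketch following the docstring of build_bible_datasets (assumes the
# t_bbe.csv / t_wbt.csv files are available in the working directory):
#
#   bucket_iter_train, bucket_iter_valid = build_bible_datasets()
#   epoch_iter = iter(bucket_iter_train)        # fresh epoch iterator
#   batch = next(epoch_iter)                    # one batch of 32 examples
#   fields = bucket_iter_train.dataset.fields   # recover the Field objects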


# Example 4
def build_quora_dataset(verbose=False):
    # Create a dataset which is only used as internal tsv reader
    SOURCE_INT = data.Field(batch_first=True,
                            sequential=False,
                            use_vocab=False)  # optionally: tensor_type=torch.IntTensor
    ds = data.TabularDataset('train.csv',
                             format='csv',
                             skip_header=True,
                             fields=[('id', SOURCE_INT), ('qid1', SOURCE_INT),
                                     ('qid2', SOURCE_INT),
                                     ('question1', SOURCE_INT),
                                     ('question2', SOURCE_INT),
                                     ('is_duplicate', SOURCE_INT)])

    tokenize = 'revtok'  # alternative: lambda x: x.split(' ')
    TEXT_TARGET = TargetField(batch_first=True,
                              sequential=True,
                              use_vocab=True,
                              lower=True,
                              init_token=TargetField.SYM_SOS,
                              eos_token=TargetField.SYM_EOS,
                              tokenize=tokenize)  # fix_length=30)
    TEXT = SourceField(batch_first=True,
                       sequential=True,
                       use_vocab=True,
                       lower=True,
                       tokenize=tokenize)  # , fix_length=20)
    LABEL = data.Field(batch_first=True,
                       sequential=False,
                       use_vocab=False,
                       tensor_type=torch.FloatTensor)

    sem_style_ds = SemStyleDS(ds, TEXT, TEXT_TARGET, LABEL, max_id=1000 * 1000)
    for i in range(5):
        print(type(sem_style_ds[i].sent_0), type(sem_style_ds[i].is_x_0),
              sem_style_ds[i])

    ds_train = data.Dataset(sem_style_ds,
                            fields=[('sent_0', TEXT), ('sent_1', TEXT),
                                    ('sent_x', TEXT), ('is_x_0', LABEL),
                                    ('sent_0_target', TEXT_TARGET)])
    # import pdb; pdb.set_trace()
    print('printing dataset directly, before tokenizing:')
    print('sent_0', ds_train[2].sent_0)  # not processed
    print('is_x_0', ds_train[2].is_x_0)  # not processed

    print('\nbuilding vocab:')

    TEXT_TARGET.build_vocab(
        ds_train, vectors='fasttext.simple.300d'
    )  # alternatives tried: max_size=80000; vectors='fasttext.en.300d' or 'glove.twitter.27B.50d'
    TEXT.vocab = TEXT_TARGET.vocab  # same except for the added <sos>/<eos> tokens
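
    # A hedged sketch (not part of the original snippet) of how the pretrained
    # vectors loaded above are typically consumed: copy TEXT.vocab.vectors into
    # an nn.Embedding of matching shape.
    #
    #   import torch.nn as nn
    #   embedding = nn.Embedding(len(TEXT.vocab), TEXT.vocab.vectors.size(1))
    #   embedding.weight.data.copy_(TEXT.vocab.vectors)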

    print('vocab TEXT: len', len(TEXT.vocab), 'common',
          TEXT.vocab.freqs.most_common()[:10])
    print('vocab TEXT_TARGET:', len(TEXT_TARGET.vocab),
          TEXT_TARGET.vocab.freqs.most_common()[:10])
    print('vocab ', TEXT_TARGET.SYM_SOS, TEXT_TARGET.sos_id,
          TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_SOS])
    print('vocab ', TEXT_TARGET.SYM_EOS, TEXT_TARGET.eos_id,
          TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_EOS])
    print('vocab ', 'out-of-vocab',
          TEXT_TARGET.vocab.stoi['out-of-vocab'])

    device = None if torch.cuda.is_available() else -1
    # READ:  https://github.com/mjc92/TorchTextTutorial/blob/master/01.%20Getting%20started.ipynb
    bucket_iter_train = data.BucketIterator(dataset=ds_train,
                                            shuffle=True,
                                            device=device,
                                            batch_size=32,
                                            sort_within_batch=False,
                                            sort_key=lambda x: len(x.sent_0))
    # No separate validation split is built for the Quora data here, so the
    # training iterator is reused; a dedicated BucketIterator over a held-out
    # ds_val could be created exactly as above.
    bucket_iter_valid = bucket_iter_train

    # performance note: the first next() takes ~3.5 s; subsequent calls are fast (~10,000 in 1 s)

    if verbose:
        training_batch_generator = iter(bucket_iter_train)
        for i in range(5):
            b = next(training_batch_generator)
            # usage
            print('\nb.is_x_0', b.is_x_0[0], b.is_x_0.type())
            # print ('b.src is values+len tuple',b.src[0].shape,b.src[1].shape )
            print('b.sent_0_target', b.sent_0_target.shape, b.sent_0_target[0])
            print('b.sent_0_target', b.sent_0_target.shape, b.sent_0_target[0],
                  revers_vocab(TEXT_TARGET.vocab, b.sent_0_target[0], ' '))
            print('b_sent0', b.sent_0[0].shape,
                  b.sent_0[1].shape, b.sent_0[0][0],
                  revers_vocab(TEXT.vocab, b.sent_0[0][0], ' '))
            print('b_sent1', b.sent_1[0].shape,
                  b.sent_1[1].shape, b.sent_1[0][0],
                  revers_vocab(TEXT.vocab, b.sent_1[0][0], ' '))
            print('b_sentx', b.sent_x[0].shape,
                  b.sent_x[1].shape, b.sent_x[0][0],
                  revers_vocab(TEXT.vocab, b.sent_x[0][0], ' '))
            print('b_y', b.is_x_0.shape, b.is_x_0[0])

    return bucket_iter_train, bucket_iter_valid


# Example 5
# NOTE: this snippet starts mid-script; the CUDA guard below is an assumption.
if torch.cuda.is_available():
    torch.cuda.set_device(opt.cuda_device)

#################################################################################
# load model

logging.info("loading checkpoint from {}".format(os.path.join(opt.checkpoint_path)))
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

############################################################################
# Prepare dataset and loss
src = SourceField()
tgt = TargetField()
src.vocab = input_vocab
tgt.vocab = output_vocab
max_len = opt.max_len

def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len

# generate test set
test = torchtext.data.TabularDataset(
    path=opt.test_data, format='tsv',
    fields=[('src', src), ('tgt', tgt)],
    filter_pred=len_filter
)

# Prepare loss
weight = torch.ones(len(output_vocab))
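
# A hedged sketch (not in the original excerpt) of how such a weight vector is
# typically turned into a criterion: zero the padding entry and wrap it in an
# NLL-style loss. `tgt.pad_token` is the standard torchtext pad symbol.
#
#   pad_id = output_vocab.stoi[tgt.pad_token]
#   weight[pad_id] = 0
#   criterion = torch.nn.NLLLoss(weight=weight, ignore_index=pad_id)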
if opt.resume:
    if opt.load_checkpoint is None:
        raise Exception(
            'load_checkpoint must be specified when --resume is specified')
    else:
        logging.info("loading checkpoint from {}".format(
            os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                         opt.load_checkpoint)))
        checkpoint_path = os.path.join(opt.expt_dir,
                                       Checkpoint.CHECKPOINT_DIR_NAME,
                                       opt.load_checkpoint)
        checkpoint = Checkpoint.load(checkpoint_path)
        seq2seq = checkpoint.model
        # input_vocab = checkpoint.input_vocab
        # output_vocab = checkpoint.output_vocab
        src.vocab = checkpoint.input_vocab
        tgt.vocab = checkpoint.output_vocab
else:
    src.build_vocab(train,
                    max_size=params['src_vocab_size'],
                    specials=replace_tokens)
    tgt.build_vocab(train, max_size=params['tgt_vocab_size'])
    # input_vocab = src.vocab
    # output_vocab = tgt.vocab

logging.info('Indices of special replace tokens:\n')
for rep in replace_tokens:
    logging.info("%s, %d; " % (rep, src.vocab.stoi[rep]))
logging.info('\n')

    # generate training set
    # NOTE: the opening of this TabularDataset call is elided in the original
    # excerpt; the variable name and path option below are assumptions, and the
    # 'beh' field is defined elsewhere.
    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format='tsv',
                                          fields=[('src', src), ('tgt', tgt),
                                                  ('beh', beh)],
                                          filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='tsv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=len_filter)

    if not os.path.exists(opt.ckpt_dir):
        os.makedirs(opt.ckpt_dir)

    if opt.resume:
        latest_checkpoint_path = Checkpoint.get_latest_checkpoint(opt.ckpt_dir)
        resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
        src.vocab = resume_checkpoint.input_vocab
        tgt.vocab = resume_checkpoint.output_vocab
    else:
        print('Building vocab')
        #src.build_vocab(train, max_size=50000)
        #tgt.build_vocab(train, max_size=opt.vocab_size, vectors='glove.840B.300d')
        if hidden_size == 300:
            vectors = 'glove.42B.300d'
        elif hidden_size == 100:
            vectors = 'glove.6B.100d'
        else:
            vectors = None

        tgt.build_vocab(train, max_size=vocab_size, vectors=vectors)
        src.vocab = tgt.vocab
        input_vocab = src.vocab