Example #1
 def sort_key(ex):
     return data.interleave_keys(len(ex.src), len(ex.trg))
Example #2
 def sort_key(ex):
     return data.interleave_keys(len(ex.premise), len(ex.hypothesis))
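These one-line sort_key hooks are what torchtext's BucketIterator consults when it buckets examples of similar length into the same batch; interleave_keys folds the two lengths into a single comparable key so that both sides of a pair count. Below is a minimal, self-contained sketch of the same pattern on a toy dataset, assuming the legacy torchtext.data API (the field names and toy sentences are hypothetical):

from torchtext import data  # legacy torchtext.data API (torchtext.legacy.data in newer releases)

SRC = data.Field(tokenize=str.split, lower=True)
TRG = data.Field(tokenize=str.split, lower=True)
fields = [('src', SRC), ('trg', TRG)]

# A few hypothetical sentence pairs of different lengths.
pairs = [('a b c', 'x y'), ('a b', 'x y z w'), ('a', 'x')]
examples = [data.Example.fromlist(p, fields) for p in pairs]
toy_dataset = data.Dataset(examples, fields)
SRC.build_vocab(toy_dataset)
TRG.build_vocab(toy_dataset)

# interleave_keys(len(src), len(trg)) yields one sortable key per example, so
# each bucket holds pairs whose source AND target lengths are similar.
toy_iter = data.BucketIterator(
    toy_dataset,
    batch_size=2,
    sort_key=lambda ex: data.interleave_keys(len(ex.src), len(ex.trg)))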
Example #3
 def sort_key(ex):
     return data.interleave_keys(len(ex.premise), len(ex.hypothesis))
Example #4
def main():
    args_parser = argparse.ArgumentParser(
        description='Tuning with graph-based parsing')
    args_parser.add_argument('--cuda', action='store_true', help='using GPU')
    args_parser.add_argument('--num_epochs',
                             type=int,
                             default=200,
                             help='Number of training epochs')
    args_parser.add_argument('--batch_size',
                             type=int,
                             default=64,
                             help='Number of sentences in each batch')
    args_parser.add_argument('--hidden_size',
                             type=int,
                             default=256,
                             help='Number of hidden units in RNN')
    args_parser.add_argument('--num_layers',
                             type=int,
                             default=1,
                             help='Number of layers of RNN')
    args_parser.add_argument('--opt',
                             choices=['adam', 'sgd', 'adamax'],
                             help='optimization algorithm')
    args_parser.add_argument('--objective',
                             choices=['cross_entropy', 'crf'],
                             default='cross_entropy',
                             help='objective function of training procedure.')
    args_parser.add_argument('--learning_rate',
                             type=float,
                             default=0.01,
                             help='Learning rate')
    args_parser.add_argument('--decay_rate',
                             type=float,
                             default=0.05,
                             help='Decay rate of learning rate')
    args_parser.add_argument('--clip',
                             type=float,
                             default=5.0,
                             help='gradient clipping')
    args_parser.add_argument('--gamma',
                             type=float,
                             default=0.0,
                             help='weight for regularization')
    args_parser.add_argument('--epsilon',
                             type=float,
                             default=1e-8,
                             help='epsilon for adam or adamax')
    args_parser.add_argument('--p_rnn',
                             nargs=2,
                             type=float,
                             default=[0.1, 0.1],  # nargs=2 expects a pair of values
                             help='dropout rate for RNN')
    args_parser.add_argument('--p_in',
                             type=float,
                             default=0.33,
                             help='dropout rate for input embeddings')
    args_parser.add_argument('--p_out',
                             type=float,
                             default=0.33,
                             help='dropout rate for output layer')
    args_parser.add_argument('--schedule',
                             type=int,
                             help='schedule for learning rate decay')
    args_parser.add_argument(
        '--unk_replace',
        type=float,
        default=0.,
        help='The rate to replace a singleton word with UNK')
    # args_parser.add_argument('--punctuation', nargs='+', type=str, help='List of punctuations')
    args_parser.add_argument('--word_path',
                             help='path for word embedding dict')
    args_parser.add_argument(
        '--freeze',
        action='store_true',
        help='freeze the word embeddings (disable fine-tuning).')
    # args_parser.add_argument('--char_path', help='path for character embedding dict')
    args_parser.add_argument(
        '--train')  # "data/POS-penn/wsj/split1/wsj1.train.original"
    args_parser.add_argument(
        '--dev')  # "data/POS-penn/wsj/split1/wsj1.dev.original"
    args_parser.add_argument(
        '--test')  # "data/POS-penn/wsj/split1/wsj1.test.original"
    args_parser.add_argument('--model_path',
                             help='path for saving model file.',
                             default='models/temp')
    args_parser.add_argument('--model_name',
                             help='name for saving model file.',
                             default='generator')

    args_parser.add_argument('--seq2seq_save_path',
                             default='checkpoints/seq2seq_save_model',
                             type=str,
                             help='seq2seq_save_path')
    args_parser.add_argument('--seq2seq_load_path',
                             default='checkpoints/seq2seq_save_model',
                             type=str,
                             help='seq2seq_load_path')

    args_parser.add_argument('--direct_eval',
                             action='store_true',
                             help='direct eval without generation process')
    args_parser.add_argument('--single_seq2seq',
                             action='store_true',
                             help='use a single-source (1-to-1) seq2seq instead of the dual-source (2-to-1) model')
    args = args_parser.parse_args()

    spacy_en = spacy.load('en_core_web_sm')  # python -m spacy download en_core_web_sm
    spacy_de = spacy.load('de_core_news_sm')  # python -m spacy download de_core_news_sm
    spacy_fr = spacy.load('fr_core_news_sm')  # python -m spacy download fr_core_news_sm

    SEED = random.randint(1, 100000)
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    device = torch.device('cpu')  # or: torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def tokenizer_en(text):  # create a tokenizer function
        return [tok.text for tok in spacy_en.tokenizer(text)]

    def tokenizer_de(text):  # create a tokenizer function
        return [tok.text for tok in spacy_de.tokenizer(text)]

    def tokenizer_fr(text):  # create a tokenizer function
        return [tok.text for tok in spacy_fr.tokenizer(text)]

    en_field = data.Field(sequential=True,
                          tokenize=tokenizer_en,
                          lower=True,
                          include_lengths=True,
                          batch_first=True)  # use_vocab=False fix_length=10
    de_field = data.Field(sequential=True,
                          tokenize=tokenizer_de,
                          lower=True,
                          include_lengths=True,
                          batch_first=True)  # use_vocab=False
    fr_field = data.Field(sequential=True,
                          tokenize=tokenizer_fr,
                          lower=True,
                          include_lengths=True,
                          batch_first=True)  # use_vocab=False
    print('begin loading training data-----')
    print('time: ', time.asctime(time.localtime(time.time())))
    seq2seq_train_data = MultiSourceTranslationDataset(
        path='wmt14_3/train',
        exts=('.de', '.fr', '.en'),
        fields=(de_field, fr_field, en_field))
    print('begin loading validation data-----')
    print('time: ', time.asctime(time.localtime(time.time())))
    seq2seq_dev_data = MultiSourceTranslationDataset(
        path='wmt14_3/valid',
        exts=('.de', '.fr', '.en'),
        fields=(de_field, fr_field, en_field))
    print('end loading data-----')
    print('time: ', time.asctime(time.localtime(time.time())))

    # vocab_thread = 20000 + 2
    # with open(str(vocab_thread) + '_vocab_en.pickle', 'rb') as f:
    #     en_field.vocab = pickle.load(f)
    # with open(str(vocab_thread) + '_vocab_de.pickle', 'rb') as f:
    #     de_field.vocab = pickle.load(f)
    # with open(str(vocab_thread) + '_vocab_fr.pickle', 'rb') as f:
    #     fr_field.vocab = pickle.load(f)
    # print('end build vocab-----')
    # print('time: ', time.asctime(time.localtime(time.time())))

    en_train_data = datasets.TranslationDataset(path='wmt14_3/train',
                                                exts=('.en', '.en'),
                                                fields=(en_field, en_field))
    print('end en data-----')
    print('time: ', time.asctime(time.localtime(time.time())))
    de_train_data = datasets.TranslationDataset(path='wmt14_3/train',
                                                exts=('.de', '.de'),
                                                fields=(de_field, de_field))
    fr_train_data = datasets.TranslationDataset(path='wmt14_3/train',
                                                exts=('.fr', '.fr'),
                                                fields=(fr_field, fr_field))
    en_field.build_vocab(en_train_data,
                         max_size=80000)  # ,vectors="glove.6B.100d"
    de_field.build_vocab(de_train_data,
                         max_size=80000)  # ,vectors="glove.6B.100d"
    fr_field.build_vocab(fr_train_data,
                         max_size=80000)  # ,vectors="glove.6B.100d"

    train_iter = data.BucketIterator(
        dataset=seq2seq_train_data,
        batch_size=16,
        sort_key=lambda x: data.interleave_keys(len(x.src), len(x.trg)),
        device=device,
        shuffle=True
    )  # Note: older torchtext versions expected device=-1 for CPU and a GPU index otherwise; here a torch.device is passed directly.
    dev_iter = data.BucketIterator(
        dataset=seq2seq_dev_data,
        batch_size=16,
        sort_key=lambda x: data.interleave_keys(len(x.src), len(x.trg)),
        device=device,
        shuffle=False)

    num_words_en = len(en_field.vocab.stoi)
    # Pretrain the seq2seq model using a denoising autoencoder.

    EPOCHS = 150  # 150
    DECAY = 0.97
    # TODO: #len(en_field.vocab.stoi)  # ?? word_embedd ??
    word_dim = 10  #300  # ??
    if args.single_seq2seq:
        seq2seq = Single_Seq2seq_Model(EMB=word_dim,
                                       HID=args.hidden_size,
                                       DPr=0.5,
                                       vocab_size1=len(de_field.vocab.stoi),
                                       vocab_size2=len(fr_field.vocab.stoi),
                                       vocab_size3=len(en_field.vocab.stoi),
                                       word_embedd=None,
                                       device=device).to(
                                           device)  # TODO: random init vocab
    else:
        seq2seq = Seq2seq_Model(EMB=word_dim,
                                HID=args.hidden_size,
                                DPr=0.5,
                                vocab_size1=len(de_field.vocab.stoi),
                                vocab_size2=len(fr_field.vocab.stoi),
                                vocab_size3=len(en_field.vocab.stoi),
                                word_embedd=None,
                                device=device).to(
                                    device)  # TODO: random init vocab

    # seq2seq.emb.weight.requires_grad = False
    print(seq2seq)

    loss_seq2seq = torch.nn.CrossEntropyLoss(reduction='none').to(device)
    parameters_need_update = filter(lambda p: p.requires_grad,
                                    seq2seq.parameters())
    optim_seq2seq = torch.optim.Adam(parameters_need_update, lr=0.0003)

    # seq2seq.load_state_dict(torch.load(args.seq2seq_load_path + str(2) + '.pt'))  # TODO: 10.7
    seq2seq.to(device)

    def count_parameters(model: torch.nn.Module):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f'The model has {count_parameters(seq2seq):,} trainable parameters')
    PAD_IDX = en_field.vocab.stoi['<pad>']
    # criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    for i in range(EPOCHS):
        ls_seq2seq_ep = 0
        seq2seq.train()
        # seq2seq.emb.weight.requires_grad = False
        print('----------' + str(i) + ' iter----------')
        for _, batch in enumerate(train_iter):
            src1, lengths_src1 = batch.src1  # word:(32,50)  150,64
            src2, lengths_src2 = batch.src2  # word:(32,50)  150,64
            trg, lengths_trg = batch.trg

            # max_len1 = src1.size()[1]  # batch_first
            # masks1 = torch.arange(max_len1).expand(len(lengths_src1), max_len1) < lengths_src1.unsqueeze(1)
            # masks1 = masks1.long()
            # max_len2 = src2.size()[1]  # batch_first
            # masks2 = torch.arange(max_len2).expand(len(lengths_src2), max_len2) < lengths_src2.unsqueeze(1)
            # masks2 = masks2.long()
            dec_out = trg
            start_list = torch.ones((trg.shape[0], 1)).long().to(device)
            dec_inp = torch.cat((start_list, trg[:, 0:-1]),
                                dim=1)  # maybe wrong
            # train_seq2seq
            if args.single_seq2seq:
                out = seq2seq(src1.long().to(device),
                              is_tr=True,
                              dec_inp=dec_inp.long().to(device))
            else:
                out = seq2seq(src1.long().to(device),
                              src2.long().to(device),
                              is_tr=True,
                              dec_inp=dec_inp.long().to(device))

            out = out.view((out.shape[0] * out.shape[1], out.shape[2]))
            dec_out = dec_out.view((dec_out.shape[0] * dec_out.shape[1], ))

            # max_len_trg = trg.size()[1]  # batch_first
            # masks_trg = torch.arange(max_len_trg).expand(len(lengths_trg), max_len_trg) < lengths_trg.unsqueeze(1)
            # masks_trg = masks_trg.float().to(device)
            # wgt = masks_trg.view(-1)
            # wgt = seq2seq.add_stop_token(masks, lengths_src)  # TODO
            # wgt = wgt.view((wgt.shape[0] * wgt.shape[1],)).float().to(device)
            # wgt = masks.view(-1)

            ls_seq2seq_bh = loss_seq2seq(
                out,
                dec_out.long().to(device))  # 9600, 8133
            # ls_seq2seq_bh = (ls_seq2seq_bh * wgt).sum() / wgt.sum()  # TODO
            ls_seq2seq_bh = ls_seq2seq_bh.sum() / ls_seq2seq_bh.numel()

            optim_seq2seq.zero_grad()
            ls_seq2seq_bh.backward()
            optim_seq2seq.step()

            ls_seq2seq_bh = ls_seq2seq_bh.cpu().detach().numpy()
            ls_seq2seq_ep += ls_seq2seq_bh
        print('ls_seq2seq_ep: ', ls_seq2seq_ep)
        for pg in optim_seq2seq.param_groups:
            pg['lr'] *= DECAY

        # test the BLEU of the seq2seq model
        if i > 40:
            print('ss')
        if i > 0:  # i%1 == 0:
            seq2seq.eval()
            bleu_ep = 0
            acc_numerator_ep = 0
            acc_denominator_ep = 0
            testi = 0
            for _, batch in enumerate(
                    train_iter
            ):  # for _ in range(1, num_batches + 1):  word, char, pos, heads, types, masks, lengths = conllx_data.get_batch_tensor(data_dev, batch_size, unk_replace=unk_replace)  # word:(32,50)  char:(32,50,35)
                src1, lengths_src1 = batch.src1  # word:(32,50)  150,64
                src2, lengths_src2 = batch.src2  # word:(32,50)  150,64
                trg, lengths_trg = batch.trg
                if args.single_seq2seq:
                    sel, _ = seq2seq(src1.long().to(device),
                                     LEN=src1.size()[1] + 5)  # TODO:
                else:
                    sel, _ = seq2seq(src1.long().to(device),
                                     src2.long().to(device),
                                     LEN=max(src1.size()[1],
                                             src2.size()[1]))  # TODO:
                sel = sel.detach().cpu().numpy()
                dec_out = trg.cpu().numpy()

                bleus = []

                for j in range(sel.shape[0]):
                    bleu = get_bleu(sel[j], dec_out[j], PAD_IDX)  # sel
                    bleus.append(bleu)
                    numerator, denominator = get_correct(
                        sel[j], dec_out[j], PAD_IDX)
                    acc_numerator_ep += numerator
                    acc_denominator_ep += denominator  # .detach().cpu().numpy() TODO: 10.8
                bleu_bh = np.average(bleus)
                bleu_ep += bleu_bh
                testi += 1
            bleu_ep /= testi  # num_batches
            print('testi: ', testi)
            print('Valid bleu: %.4f%%' % (bleu_ep * 100))
            # print(acc_denominator_ep)
            if acc_denominator_ep > 0:
                print('Valid acc: %.4f%%' %
                      ((acc_numerator_ep * 1.0 / acc_denominator_ep) * 100))
        # for debug TODO:
        if i % 5 == 0:
            torch.save(seq2seq.state_dict(),
                       args.seq2seq_save_path + str(i) + '.pt')
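The loop above averages the cross-entropy over every position, padding included, because the masking code (masks_trg / wgt) is commented out; the commented-out criterion with ignore_index=PAD_IDX points at the simpler alternative. Below is a small standalone sketch of a padding-aware loss, independent of the model above; the tensor shapes and PAD index are illustrative assumptions.

import torch

# Hypothetical shapes: batch of 3 sequences, max length 5, vocabulary of 11 tokens.
PAD_IDX = 1
logits = torch.randn(3, 5, 11)            # decoder output, batch_first
targets = torch.randint(0, 11, (3, 5))
targets[0, 3:] = PAD_IDX                  # pretend the tails are padding
targets[2, 2:] = PAD_IDX

# Option 1: let CrossEntropyLoss skip padded positions entirely.
criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
loss = criterion(logits.view(-1, 11), targets.view(-1))

# Option 2: per-token loss with an explicit mask, analogous to the wgt logic
# that is commented out in the training loop above.
per_token = torch.nn.CrossEntropyLoss(reduction='none')(
    logits.view(-1, 11), targets.view(-1))
mask = (targets.view(-1) != PAD_IDX).float()
masked_loss = (per_token * mask).sum() / mask.sum()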
Example #5
 def sort_key(ex):
     return data.interleave_keys(len(ex.src), len(ex.trg))
Example #6
def load_train_data(data_path,
                    batch_size,
                    max_src_len,
                    max_trg_len,
                    use_cuda=False):
    # Note: use_vocab=False, since the inputs are already numericalized (preprocessed).
    src_field = Field(
        sequential=True,
        use_vocab=False,
        include_lengths=True,
        batch_first=True,
        pad_token=PAD,
        unk_token=UNK,
        init_token=None,
        eos_token=None,
    )
    trg_field = Field(
        sequential=True,
        use_vocab=False,
        include_lengths=True,
        batch_first=True,
        pad_token=PAD,
        unk_token=UNK,
        init_token=BOS,
        eos_token=EOS,
    )
    fields = (src_field, trg_field)
    device = torch.device("cuda:0" if use_cuda else "cpu")

    def filter_pred(example):
        if len(example.src) <= max_src_len and len(example.trg) <= max_trg_len:
            return True
        return False

    dataset = torch.load(data_path)
    train_src, train_tgt = dataset['train_src'], dataset['train_tgt']
    dev_src, dev_tgt = dataset['dev_src'], dataset['dev_tgt']

    train_data = ParallelDataset(
        train_src,
        train_tgt,
        fields=fields,
        filter_pred=filter_pred,
    )
    train_iter = Iterator(
        dataset=train_data,
        batch_size=batch_size,
        train=True,  # Variable(volatile=False)
        sort_key=lambda x: data.interleave_keys(len(x.src), len(x.trg)),
        repeat=False,
        shuffle=True,
        device=device)
    dev_data = ParallelDataset(
        dev_src,
        dev_tgt,
        fields=fields,
    )
    dev_iter = Iterator(
        dataset=dev_data,
        batch_size=batch_size,
        train=False,  # Variable(volatile=True)
        repeat=False,
        device=device,
        shuffle=False,
        sort=False,
    )

    return src_field, trg_field, train_iter, dev_iter
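Since both fields are created with include_lengths=True and batch_first=True, every batch attribute comes back as a (tensor, lengths) pair. Here is a minimal usage sketch for the loader defined above; the data path and batch size are placeholder assumptions.

src_field, trg_field, train_iter, dev_iter = load_train_data(
    data_path='data/preprocessed.pt',   # placeholder path
    batch_size=32,
    max_src_len=100,
    max_trg_len=100,
    use_cuda=False)

for batch in train_iter:
    src, src_lengths = batch.src        # include_lengths=True -> (tensor, lengths)
    trg, trg_lengths = batch.trg
    # src: (batch_size, max_src_len_in_batch) because batch_first=True
    break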
Example #7
 def sort_key(ex):
     return interleave_keys(len(ex.sentence_a), len(ex.sentence_b))
Example #8
                pretrained_embeddings=TEXT.vocab.vectors,
                freeze_embeddings=False,
                gpu=True)

# Annealing for KL term
kld_start_inc = 3000
kld_weight = 0.01
kld_max = 0.15
kld_inc = (kld_max - kld_weight) / (n_iter - kld_start_inc)

trainer = optim.Adam(model.vae_params, lr=lr)

train_iter = data.BucketIterator(
    dataset=train,
    batch_size=mb_size,
    sort_key=lambda x: data.interleave_keys(len(x.src), len(x.trg)))

#print("loading previous 200yelp_nofix_max40_predif.bin model")
model.load_state_dict(torch.load('models/{}.bin'.format('pre_yelp128_100')))


def save_model():
    if not os.path.exists('models/'):
        os.makedirs('models/')

    torch.save(model.state_dict(),
               'models/{}.bin'.format('pre_yelp128_100_addsamplepredif'))


pre_weight = 10
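The constants above describe a linear KL-annealing schedule: the weight on the KL term stays at kld_weight until iteration kld_start_inc, then grows by kld_inc per iteration until it caps at kld_max. The helper below reproduces that schedule for inspection; n_iter is not shown in the snippet, so the value used here is only an assumption.

def kl_weight_at(it, n_iter=100000, kld_start_inc=3000,
                 kld_weight=0.01, kld_max=0.15):
    # Hold kld_weight before kld_start_inc, then ramp linearly up to kld_max.
    kld_inc = (kld_max - kld_weight) / (n_iter - kld_start_inc)
    if it <= kld_start_inc:
        return kld_weight
    return min(kld_max, kld_weight + (it - kld_start_inc) * kld_inc)

print(kl_weight_at(1000))    # still 0.01 before annealing starts
print(kl_weight_at(100000))  # reaches the cap of 0.15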
Example #9
 def sort_key(ex):
     return data.interleave_keys(len(ex.sgn), len(ex.txt))
Example #10
prefix_f = './escape.en-de.tok.5k'

parallel_dataset = TranslationDataset(path=prefix_f, exts=('.en', '.de'), fields=[('src', src), ('tgt', tgt)])

src.build_vocab(parallel_dataset, min_freq=5, max_size=15000)
tgt.build_vocab(parallel_dataset, min_freq=5, max_size=15000)

train, valid = parallel_dataset.split(split_ratio=0.97)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 20

train_iterator, valid_iterator = BucketIterator.splits((train, valid), batch_size=BATCH_SIZE,
                                                    sort_key=lambda x: interleave_keys(len(x.src), len(x.tgt)),
                                                    device=device)



class Encoder(nn.Module):
    def __init__(self, hidden_dim: int, src_ntoken: int, dropout: float):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.src_ntoken = src_ntoken

        self.embedding = nn.Embedding(src_ntoken, hidden_dim, padding_idx=src.vocab.stoi['<pad>'])
        self.rnn = nn.GRU(hidden_dim, hidden_dim, bidirectional = True, batch_first=True)
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)
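The fc layer in this Encoder maps the concatenation of the forward and backward GRU final states (hidden_dim * 2) back down to hidden_dim, the usual way to seed a unidirectional decoder from a bidirectional encoder. Below is a standalone sketch of that shape bookkeeping; the dimensions and dummy batch are illustrative assumptions, not the class's actual forward method.

import torch
import torch.nn as nn

hidden_dim, vocab_size = 8, 20
emb = nn.Embedding(vocab_size, hidden_dim, padding_idx=0)
rnn = nn.GRU(hidden_dim, hidden_dim, bidirectional=True, batch_first=True)
fc = nn.Linear(hidden_dim * 2, hidden_dim)

src = torch.randint(1, vocab_size, (4, 7))   # (batch, seq_len)
outputs, hidden = rnn(emb(src))              # outputs: (4, 7, 2 * hidden_dim)
# hidden: (num_directions, batch, hidden_dim); concatenate both directions
# and project, e.g. as the decoder's initial hidden state.
init_hidden = torch.tanh(fc(torch.cat((hidden[0], hidden[1]), dim=1)))  # (4, hidden_dim)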
Example #11
                              wv_type=word_vectors)
    answer_field.build_vocab(train_dataset,
                             test_dataset,
                             wv_dir=DATA_DIRECTORY,
                             wv_type=word_vectors)
    support_field.build_vocab(train_dataset, test_dataset)

    return train_dataset, test_dataset, context_field.vocab, answer_field.vocab, support_field.vocab


if __name__ == "__main__":
    from torchtext.data.iterator import Iterator

    train, test, text_vocab, answer_vocab, tag_vocab = bAbI()

    sort_key = lambda batch: data.interleave_keys(len(batch.context),
                                                  len(batch.question))
    train_iterator = Iterator(train,
                              1,
                              shuffle=True,
                              device=-1,
                              repeat=False,
                              sort_key=sort_key)
    valid_iterator = Iterator(test,
                              1,
                              device=-1,
                              train=False,
                              sort_key=sort_key)

    train_batch = next(iter(train_iterator))
    valid_batch = next(iter(valid_iterator))
Example #12
def MAKE_ITER():
    print('start prepare data file')
    src_TEXT = data.Field(
        sequential=True,
        include_lengths=True,
        eos_token='<eos>',  #init_token='<sos>', 
        #lower=True, tokenize=lambda s: mecab.parse(s).rstrip().split()))
        lower=True)
    trg_TEXT = data.Field(sequential=True,
                          include_lengths=True,
                          init_token='<sos>',
                          eos_token='<eos>',
                          lower=True)
    #lower=True, tokenize='spacy')
    fields = [('src', src_TEXT), ('trg', trg_TEXT)]

    # Load Dataset Using torchtext
    import csv
    csv.field_size_limit(100000000)
    #print(csv.field_size_limit())

    train, val, test = data.TabularDataset.splits(
        path='./',
        train=TRAIN_FILE,
        validation=VALID_FILE,
        test=TEST_FILE,
        format='tsv',
        fields=fields,
        csv_reader_params={"quotechar": None})

    # Build Vocablary
    src_TEXT.build_vocab(
        train, max_size=VOCAB_SIZE)  #max_size=50000) min_freq=MIN_FREQ)
    trg_TEXT.build_vocab(train, max_size=VOCAB_SIZE)  #max_size=50000)

    #if REVERSE_TRANSLATION: #from target to source
    #    src_TEXT, trg_TEXT = trg_TEXT, src_TEXT

    # Make iterator
    if SET in ['continue', 'moto']:
        train_iter, valid_iter = data.Iterator.splits(
            (train, val),
            batch_size=BATCH_SIZE,
            repeat=False,
            sort_key=lambda x: data.interleave_keys(len(x.src), len(x.trg)),
            device=torch.device(DEVICE))
        test_iter = None
        # Check Dataset
        print('we have', len(train), 'training dataset')
        print('train', train[0].src, train[0].trg)
        print('val', val[1].src, val[1].trg)
        print('test', test[1].src, test[1].trg)

    else:
        train_iter = None
        valid_iter = None
        test_iter = data.Iterator(test,
                                  batch_size=BATCH_SIZE_TEST,
                                  repeat=False,
                                  shuffle=False,
                                  device=torch.device(DEVICE))

    # Check Vocab
    print('source vocabulary:', len(src_TEXT.vocab.itos))
    print('target vocabulary:', len(trg_TEXT.vocab.itos))
    print(src_TEXT.vocab.itos[:10])
    print(trg_TEXT.vocab.itos[:10])
    return train_iter, valid_iter, test_iter, src_TEXT, trg_TEXT