Example #1
def main():
    parser = argparse.ArgumentParser()
    options.train_options(parser)
    opt = parser.parse_args()

    opt.cuda = torch.cuda.is_available()
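    # old torchtext integer-device convention: -1 selects the CPU,
    # None the default CUDA device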
    opt.device = None if opt.cuda else -1

    # quick configuration overrides for this experiment
    opt.exp_dir = './experiment/transformer-reinforce/use_billion'
    opt.load_vocab_from = './experiment/transformer/lang8-cor2err/vocab.pt'
    opt.build_vocab_from = './data/billion/billion.30m.model.vocab'

    opt.load_D_from = opt.exp_dir
    # opt.load_D_from = None

    # dataset params
    opt.max_len = 20

    # G params
    # opt.load_G_a_from = './experiment/transformer/lang8-err2cor/'
    # opt.load_G_b_from = './experiment/transformer/lang8-cor2err/'
    opt.d_word_vec = 300
    opt.d_model = 300
    opt.d_inner_hid = 600
    opt.n_head = 6
    opt.n_layers = 3
    opt.embs_share_weight = False
    opt.beam_size = 1
    opt.max_token_seq_len = opt.max_len + 2  # includes <BOS> and <EOS>
    opt.n_warmup_steps = 4000

    # D params
    opt.embed_dim = opt.d_model
    opt.num_kernel = 100
    opt.kernel_sizes = [3, 4, 5, 6, 7]
    opt.dropout_p = 0.25

    # train params
    opt.batch_size = 1
    opt.n_epoch = 10

    if not os.path.exists(opt.exp_dir):
        os.makedirs(opt.exp_dir)
    logging.basicConfig(filename=opt.exp_dir + '/.log',
                        format=LOG_FORMAT,
                        level=logging.DEBUG)
    logging.getLogger().addHandler(logging.StreamHandler())

    logging.info('Use CUDA? ' + str(opt.cuda))
    logging.info(opt)

    # ---------- prepare dataset ----------

    def len_filter(example):
        return len(example.src) <= opt.max_len and len(
            example.tgt) <= opt.max_len

    EN = SentencePieceField(init_token=Constants.BOS_WORD,
                            eos_token=Constants.EOS_WORD,
                            batch_first=True,
                            include_lengths=True)

    train = datasets.TranslationDataset(path='./data/dualgan/train',
                                        exts=('.billion.sp', '.use.sp'),
                                        fields=[('src', EN), ('tgt', EN)],
                                        filter_pred=len_filter)
    val = datasets.TranslationDataset(path='./data/dualgan/val',
                                      exts=('.billion.sp', '.use.sp'),
                                      fields=[('src', EN), ('tgt', EN)],
                                      filter_pred=len_filter)
    train_lang8, val_lang8 = Lang8.splits(exts=('.err.sp', '.cor.sp'),
                                          fields=[('src', EN), ('tgt', EN)],
                                          train='test',
                                          validation='test',
                                          test=None,
                                          filter_pred=len_filter)

    # load the vocabulary (kept consistent across experiments)
    try:
        logging.info('Load vocab from %s' % opt.load_vocab_from)
        EN.load_vocab(opt.load_vocab_from)
    except FileNotFoundError:
        EN.build_vocab_from(opt.build_vocab_from)
        EN.save_vocab(opt.load_vocab_from)

    logging.info('Vocab len: %d' % len(EN.vocab))

    # sanity-check that the Constants indices match the vocab
    assert EN.vocab.stoi[Constants.BOS_WORD] == Constants.BOS
    assert EN.vocab.stoi[Constants.EOS_WORD] == Constants.EOS
    assert EN.vocab.stoi[Constants.PAD_WORD] == Constants.PAD
    assert EN.vocab.stoi[Constants.UNK_WORD] == Constants.UNK

    # ---------- init model ----------

    # G = build_G(opt, EN, EN)
    hidden_size = 512
    bidirectional = True
    encoder = EncoderRNN(len(EN.vocab),
                         opt.max_len,
                         hidden_size,
                         n_layers=1,
                         bidirectional=bidirectional)
    decoder = DecoderRNN(len(EN.vocab),
                         opt.max_len,
                         hidden_size * 2 if bidirectional else hidden_size,
                         n_layers=1,
                         dropout_p=0.2,
                         use_attention=True,
                         bidirectional=bidirectional,
                         eos_id=Constants.EOS,
                         sos_id=Constants.BOS)
    G = Seq2seq(encoder, decoder)
    for param in G.parameters():
        param.data.uniform_(-0.08, 0.08)

    # optim_G = ScheduledOptim(optim.Adam(
    #     G.get_trainable_parameters(),
    #     betas=(0.9, 0.98), eps=1e-09),
    #     opt.d_model, opt.n_warmup_steps)
    optim_G = optim.Adam(G.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-09)
    loss_G = NLLLoss(size_average=False)
    if torch.cuda.is_available():
        loss_G.cuda()

    # pretrain D
    if opt.load_D_from:
        D = load_model(opt.load_D_from)
    else:
        D = build_D(opt, EN)
    optim_D = torch.optim.Adam(D.parameters(), lr=1e-4)

    def get_criterion(vocab_size):
        ''' With PAD token zero weight '''
        weight = torch.ones(vocab_size)
        weight[Constants.PAD] = 0
        return nn.CrossEntropyLoss(weight, size_average=False)

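    # PAD-masked cross-entropy for the generator; binary real/fake loss for D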
    crit_G = get_criterion(len(EN.vocab))
    crit_D = nn.BCELoss()

    if opt.cuda:
        G.cuda()
        D.cuda()
        crit_G.cuda()
        crit_D.cuda()

    # ---------- train ----------

    trainer_D = trainers.DiscriminatorTrainer()

    if not opt.load_D_from:
        for epoch in range(1):
            logging.info('[Pretrain D Epoch %d]' % epoch)

            pool = helper.DiscriminatorDataPool(opt.max_len, D.min_len,
                                                Constants.PAD)

            # fill the pool with training data
            train_iter = data.BucketIterator(dataset=train,
                                             batch_size=opt.batch_size,
                                             device=opt.device,
                                             sort_key=lambda x: len(x.src),
                                             repeat=False)
            pool.fill(train_iter)

            # train D
            trainer_D.train(D,
                            train_iter=pool.batch_gen(),
                            crit=crit_D,
                            optimizer=optim_D)
            pool.reset()

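        # save the pretrained D so later runs can skip pretraining by
        # pointing opt.load_D_from at this experiment directory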
        Checkpoint(model=D,
                   optimizer=optim_D,
                   epoch=0,
                   step=0,
                   input_vocab=EN.vocab,
                   output_vocab=EN.vocab).save(opt.exp_dir)

    def eval_D():
        pool = helper.DiscriminatorDataPool(opt.max_len, D.min_len,
                                            Constants.PAD)
        val_iter = data.BucketIterator(dataset=val,
                                       batch_size=opt.batch_size,
                                       device=opt.device,
                                       sort_key=lambda x: len(x.src),
                                       repeat=False)
        pool.fill(val_iter)
        trainer_D.evaluate(D, val_iter=pool.batch_gen(), crit=crit_D)

    # eval_D()

    # Train G
    ALPHA = 0
    for epoch in range(100):
        logging.info('[Epoch %d]' % epoch)
        train_iter = data.BucketIterator(dataset=train,
                                         batch_size=1,
                                         device=opt.device,
                                         sort_within_batch=True,
                                         sort_key=lambda x: len(x.src),
                                         repeat=False)

        for step, batch in enumerate(train_iter):
            src_seq = batch.src[0]
            src_length = batch.src[1]
            tgt_seq = src_seq[0].clone()
            # gold = tgt_seq[:, 1:]

            optim_G.zero_grad()
            loss_G.reset()

            decoder_outputs, decoder_hidden, other = G.rollout(src_seq,
                                                               None,
                                                               None,
                                                               n_rollout=1)
            for i, step_output in enumerate(decoder_outputs):
                batch_size = tgt_seq.size(0)
                # print(step_output)

                # loss_G.eval_batch(step_output.contiguous().view(batch_size, -1), tgt_seq[:, i + 1])

            softmax_output = torch.exp(
                torch.cat([x for x in decoder_outputs], dim=0)).unsqueeze(0)
            softmax_output = helper.stack(softmax_output, 8)

            print(softmax_output)
            rollout = softmax_output.multinomial(1)
            print(rollout)

            tgt_seq = helper.pad_seq(tgt_seq.data,
                                     max_len=len(decoder_outputs) + 1,
                                     pad_value=Constants.PAD)
            tgt_seq = autograd.Variable(tgt_seq)
            for i, step_output in enumerate(decoder_outputs):
                batch_size = tgt_seq.size(0)
                loss_G.eval_batch(
                    step_output.contiguous().view(batch_size, -1),
                    tgt_seq[:, i + 1])
            G.zero_grad()
            loss_G.backward()
            optim_G.step()

            if step % 100 == 0:
                pred = torch.cat([x for x in other['sequence']], dim=1)
                print('[step %d] loss_rest %.4f' %
                      (epoch * len(train_iter) + step, loss_G.get_loss()))
                print('%s -> %s' %
                      (EN.reverse(tgt_seq.data)[0], EN.reverse(pred.data)[0]))

    # Reinforce Train G
    for p in D.parameters():
        p.requires_grad = False
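
    # Not part of the original example: a minimal sketch of the inverse-sqrt
    # warm-up schedule that the commented-out ScheduledOptim above presumably
    # follows (the Transformer learning-rate rule), using the d_model and
    # n_warmup_steps values configured earlier.
    def transformer_lr(step, d_model=300, n_warmup_steps=4000):
        step = max(step, 1)  # avoid division by zero on the first step
        return d_model ** -0.5 * min(step ** -0.5,
                                     step * n_warmup_steps ** -1.5)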
Example #2
    logging.info('[Epoch %d]' % epoch)
    train_iter = data.BucketIterator(dataset=train,
                                     batch_size=16,
                                     device=opt.device,
                                     sort_within_batch=True,
                                     sort_key=lambda x: len(x.src),
                                     repeat=False)

    for step, batch in enumerate(train_iter):
        src_seq = batch.src[0]
        src_length = batch.src[1]
        tgt_seq = src_seq.clone()  # a -> b' -> a

        decoder_outputs, decoder_hiddens, other = G.forward(
            src_seq, src_length.tolist(), target_variable=None)
        crit_G.reset()
        for i, step_output in enumerate(decoder_outputs):
            batch_size = tgt_seq.size(0)
            crit_G.eval_batch(step_output.contiguous().view(batch_size, -1),
                              tgt_seq[:, i + 1])

        optim_G.zero_grad()
        crit_G.backward()
        optim_G.step()

        if step % 100 == 0:
            pred = torch.cat([x for x in other['sequence']], dim=1)
            print('[step %d] loss %.4f' %
                  (epoch * len(train_iter) + step, crit_G.get_loss()))
            print('%s -> %s' %
                  (EN.reverse(tgt_seq.data)[0], EN.reverse(pred.data)[0]))
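
    # Not part of the original snippet: a hypothetical helper illustrating how
    # the per-step decoder outputs relate to the printed predictions. Each
    # entry of decoder_outputs is assumed to be a (batch, vocab_size) tensor
    # of per-step log-probabilities; other['sequence'] already holds the
    # chosen token ids per step.
    def greedy_decode(decoder_outputs):
        # take the most likely token at every step, then stack along time
        return torch.cat([step.max(dim=1, keepdim=True)[1]
                          for step in decoder_outputs], dim=1)
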
def train(opt):
    LOG_FORMAT = '%(asctime)s %(levelname)-8s %(message)s'
    logging.basicConfig(format=LOG_FORMAT,
                        level=getattr(logging, opt.log_level.upper()))
    logging.info(opt)
    if int(opt.GPU) >= 0:
        torch.cuda.set_device(int(opt.GPU))
    if opt.load_checkpoint is not None:
        logging.info("loading checkpoint from {}".format(
            os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                         opt.load_checkpoint)))
        checkpoint_path = os.path.join(opt.expt_dir,
                                       Checkpoint.CHECKPOINT_DIR_NAME,
                                       opt.load_checkpoint)
        checkpoint = Checkpoint.load(checkpoint_path)
        seq2tree = checkpoint.model
        input_vocab = checkpoint.input_vocab

    else:
        # Prepare dataset
        src = SourceField()
        nt = NTField()
        pos = PosField()
        tgt_tree = TreeField()
        comp = CompField()
        max_len = opt.max_len

        def len_filter(example):
            return len(example.src) <= max_len

        train = torchtext.data.TabularDataset(path=opt.train_path,
                                              format='tsv',
                                              fields=[('src', src), ('nt', nt),
                                                      ('pos', pos),
                                                      ('tree', tgt_tree)],
                                              filter_pred=len_filter)
        dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                            format='tsv',
                                            fields=[('src', src), ('nt', nt),
                                                    ('pos', pos),
                                                    ('tree', tgt_tree)],
                                            filter_pred=len_filter)
        src.build_vocab(train, max_size=50000)
        comp.build_vocab(train, max_size=50000)
        nt.build_vocab(train, max_size=50000)
        pos.build_vocab(train, max_size=50000)
        # src_tree.build_vocab(train, max_size=50000)
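        # collect the nt-vocab ids of labels that also occur as POS tags
        # (ids 0 and 1, presumably the special tokens, are skipped)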
        pos_in_nt = set()
        for Pos in pos.vocab.stoi:
            if nt.vocab.stoi[Pos] > 1:
                pos_in_nt.add(nt.vocab.stoi[Pos])
        hidden_size = opt.hidden_size
        input_vocab = src.vocab
        nt_vocab = nt.vocab

        def tree_to_id(tree):
            tree.set_label(nt_vocab.stoi[tree.label()])
            if len(tree) == 1 and str(tree[0])[0] != '(':
                tree[0] = input_vocab.stoi[tree[0]]
                return
            else:
                for subtree in tree:
                    tree_to_id(subtree)
                tree.append(Tree(nt_vocab.stoi['<eos>'], []))
                return tree
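        # Hypothetical usage sketch (not in the original code): labels are
        # replaced by nt-vocab ids, leaf words by word ids, and an <eos>
        # child is appended to every internal node, e.g.
        #   t = Tree.fromstring('(S (NP (DT the) (NN cat)) (VP (VBZ sleeps)))')
        #   tree_to_id(t)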

        # train.examples = [str(tree_to_id(ex.tree)) for ex in train.examples]
        # dev.examples = [str(tree_to_id(ex.tree)) for ex in dev.examples]
        for ex in train.examples:
            ex.tree = str(tree_to_id(Tree.fromstring(ex.tree)))
        for ex in dev.examples:
            ex.tree = str(tree_to_id(Tree.fromstring(ex.tree)))
        # train.examples = [tree_to_id(Tree.fromstring(ex.tree)) for ex in train.examples]
        # dev.examples = [str(tree_to_id(Tree.fromstring(ex.tree))) for ex in dev.examples]
        if opt.word_embedding is not None:
            input_vocab.load_vectors([opt.word_embedding])

        loss = NLLLoss()
        if torch.cuda.is_available():
            loss.cuda()
        loss.reset()
        seq2tree = None
        optimizer = None
        if not opt.resume:
            # Initialize model
            bidirectional = opt.bidirectional_encoder
            encoder = EncoderRNN(len(src.vocab),
                                 opt.word_embedding_size,
                                 max_len,
                                 hidden_size,
                                 bidirectional=bidirectional,
                                 variable_lengths=True)
            decoder = DecoderTree(len(src.vocab),
                                  opt.word_embedding_size,
                                  opt.nt_embedding_size,
                                  len(nt.vocab),
                                  max_len,
                                  hidden_size *
                                  2 if bidirectional else hidden_size,
                                  sos_id=nt_vocab.stoi['<sos>'],
                                  eos_id=nt_vocab.stoi['<eos>'],
                                  dropout_p=0.2,
                                  use_attention=True,
                                  bidirectional=bidirectional,
                                  pos_in_nt=pos_in_nt)

            seq2tree = Seq2tree(encoder, decoder)
            if torch.cuda.is_available():
                seq2tree.cuda()

            for param in seq2tree.parameters():
                param.data.uniform_(-0.08, 0.08)
                # encoder.embedding.weight.data.set_(input_vocab.vectors)
                # encoder.embedding.weight.data.set_(output_vocab.vectors)

            # Optimizer and learning rate scheduler can be customized by
            # explicitly constructing the objects and pass to the trainer.
            #
            # optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
            # scheduler = StepLR(optimizer.optimizer, 1)
            # optimizer.set_scheduler(scheduler)

            optimizer = Optimizer(optim.Adam(seq2tree.parameters(), lr=opt.lr),
                                  max_grad_norm=5)
        # train
        t = SupervisedTrainer(loss=loss,
                              batch_size=opt.batch_size,
                              checkpoint_every=opt.checkpoint_every,
                              print_every=10,
                              expt_dir=opt.expt_dir,
                              lr=opt.lr)

        seq2tree = t.train(seq2tree,
                           train,
                           num_epochs=opt.epoch,
                           dev_data=dev,
                           optimizer=optimizer,
                           teacher_forcing_ratio=0,
                           resume=opt.resume)

    predictor = Predictor(seq2tree, input_vocab, nt_vocab)
    return predictor, dev, train
    # input_vocab.load_vectors([])
    #
    input_vocab.load_vectors(['glove.840B.300d'])

    #
    input_vocab.vectors[input_vocab.stoi['<unk>']] = torch.Tensor(
        hidden_size).uniform_(-0.8, 0.8)  # random init for <unk>

    # Prepare loss

    # loss = NLLLoss(weight, pad)#Perplexity(weight, pad)
    loss = NLLLoss()

    if torch.cuda.is_available():
        loss.cuda()
    loss.reset()
    seq2tree = None

    if not opt.resume:
        # Initialize model

        bidirectional = True
        encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                             bidirectional=bidirectional, variable_lengths=True)
        decoder = DecoderTree(len(src.vocab), len(nt.vocab), max_len,
                              hidden_size * 2 if bidirectional else hidden_size,
                              dropout_p=0.2, use_attention=True,
                              bidirectional=bidirectional, pos_in_nt=pos_in_nt)

        seq2tree = Seq2tree(encoder, decoder)
        if torch.cuda.is_available():
            seq2tree.cuda()
Example #5
def eval_fa_equiv(model, data, input_vocab, output_vocab):
    loss = NLLLoss()
    batch_size = 1

    model.eval()

    loss.reset()
    match = 0
    total = 0

    device = None if torch.cuda.is_available() else -1
    batch_iterator = torchtext.data.BucketIterator(
        dataset=data,
        batch_size=batch_size,
        sort=False,
        sort_key=lambda x: len(x.src),
        device=device,
        train=False)
    tgt_vocab = data.fields[seq2seq.tgt_field_name].vocab
    pad = tgt_vocab.stoi[data.fields[seq2seq.tgt_field_name].pad_token]

    predictor = Predictor(model, input_vocab, output_vocab)

    num_samples = 0
    perfect_samples = 0
    dfa_perfect_samples = 0

    match = 0
    total = 0

    with torch.no_grad():
        for batch in batch_iterator:
            num_samples = num_samples + 1

            input_variables, input_lengths = getattr(batch,
                                                     seq2seq.src_field_name)

            target_variables = getattr(batch, seq2seq.tgt_field_name)

            target_string = decode_tensor(target_variables, output_vocab)

            #target_string = target_string + " <eos>"

            input_string = decode_tensor(input_variables, input_vocab)

            generated_string = ' '.join([
                x for x in predictor.predict(input_string.strip().split())[:-1]
                if x != '<pad>'
            ])

            #str(pos_example)[2]

            generated_string = refine_outout(generated_string)

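            # regexDFAEquals.py presumably prints b'1' when the gold and
            # predicted regexes are DFA-equivalent, which is why index 2 of
            # str(output), i.e. "b'1'"[2], is checked below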
            pos_example = subprocess.check_output([
                'python2', 'regexDFAEquals.py', '--gold',
                '{}'.format(target_string), '--predicted',
                '{}'.format(generated_string)
            ])

            if target_string == generated_string:
                perfect_samples = perfect_samples + 1
                dfa_perfect_samples = dfa_perfect_samples + 1
            elif str(pos_example)[2] == '1':
                dfa_perfect_samples = dfa_perfect_samples + 1

            target_tokens = target_string.split()
            generated_tokens = generated_string.split()

            shorter_len = min(len(target_tokens), len(generated_tokens))

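            # token-level accuracy as computed here: every generated position
            # adds to the denominator (positions beyond the reference add two),
            # and a position counts as a match only when the aligned reference
            # token is identical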
            for idx in range(len(generated_tokens)):
                total = total + 1

                if idx >= len(target_tokens):
                    total = total + 1
                elif target_tokens[idx] == generated_tokens[idx]:
                    match = match + 1

            if total == 0:
                accuracy = float('nan')
            else:
                accuracy = match / total

            string_accuracy = perfect_samples / num_samples
            dfa_accuracy = dfa_perfect_samples / num_samples

        with open('./time_logs/log_score_time.txt', 'a') as f:
            f.write('{}\n'.format(dfa_accuracy))