Example #1
 def train(self, dataset, devset, sub_set, epochs):
     self.config.logger.info('Start training...')
     nbatches = (len(dataset) + self.config.batch_size -
                 1) // self.config.batch_size
     for epoch in range(1, epochs + 1):
         self.config.logger.info('Epoch %2d/%2d:' % (epoch, epochs))
         prog = Progbar(target=nbatches)  # progress bar over this epoch's batches
         for i, (sl, sr, desc, cand,
                 y) in enumerate(batch_iter(dataset,
                                            self.config.batch_size)):
             feed_dict = self._get_feed_dict(
                 sl,
                 sr,
                 desc,
                 cand,
                 True,
                 y=y,
                 lr=self.config.lr,
                 keep_prob=self.config.keep_prob)
             _, train_loss = self.sess.run([self.train_op, self.loss],
                                           feed_dict=feed_dict)
             prog.update(i + 1, [("train loss", train_loss)])
             if i % 1000 == 0:
                 # periodically check progress on a small held-out subset
                 self.evaluate(sub_set, batch_size=self.config.batch_size)
         # after each epoch: decay the learning rate and evaluate on the dev set
         self.config.lr *= self.config.lr_decay
         self.evaluate(devset, self.config.batch_size)
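Both this method and the `evaluate` method in Example #2 unpack each batch from `batch_iter` into five parallel fields `(sl, sr, desc, cand, y)`. A minimal sketch of a generator with that contract, assuming the dataset is a list of 5-tuples (the project's real helper may shuffle, pad, or bucket differently), might look like this:

import random

def batch_iter(dataset, batch_size, shuffle=False):
    """Yield (sl, sr, desc, cand, y) field lists, one tuple per batch."""
    # Sketch only: dataset is assumed to be a list of 5-tuples.
    data = list(dataset)
    if shuffle:
        random.shuffle(data)
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        sl, sr, desc, cand, y = zip(*batch)
        yield list(sl), list(sr), list(desc), list(cand), list(y)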
Example #2
 def evaluate(self, dataset, batch_size):
     nbatches = (len(dataset) + batch_size - 1) // batch_size
     acc = []
     for sl, sr, desc, cand, y in batch_iter(dataset, batch_size):
         feed_dict = self._get_feed_dict(sl,
                                         sr,
                                         desc,
                                         cand,
                                         False,
                                         y=y,
                                         lr=self.config.lr,
                                         keep_prob=1.0)
         batch_acc = self.sess.run(self.accuracy, feed_dict=feed_dict)
         acc.append(batch_acc)
     # assert len(acc) == nbatches
     self.config.logger.info('\nAccuracy: {:04.2f}'.format(
         (sum(acc) / len(acc)) * 100))
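Note that the reported score is an unweighted mean of per-batch accuracies, which slightly over-weights a smaller final batch. A hedged variant of the same loop, assuming `batch_iter` yields parallel field lists so that `len(y)` is the true batch size (sketch only, not part of the original class):

total_correct, total_seen = 0.0, 0
for sl, sr, desc, cand, y in batch_iter(dataset, batch_size):
    feed_dict = self._get_feed_dict(sl, sr, desc, cand, False, y=y,
                                    lr=self.config.lr, keep_prob=1.0)
    batch_acc = self.sess.run(self.accuracy, feed_dict=feed_dict)
    # weight each batch's accuracy by its actual size
    total_correct += batch_acc * len(y)
    total_seen += len(y)
accuracy = total_correct / total_seen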
Example #3
    def initialize_rerank_features(self, examples, decode_results):
        hyp_examples = []
        # print('initializing features...', file=sys.stderr)
        for example, hyps in zip(examples, decode_results):
            for hyp_id, hyp in enumerate(hyps):
                hyp_example = Example(idx=None,
                                      src_sent=example.src_sent,
                                      tgt_code=hyp.code,
                                      tgt_actions=None,
                                      tgt_ast=None)
                hyp_examples.append(hyp_example)
                # hyp.tokenized_code = len(self.transition_system.tokenize_code(hyp.code))
                # hyp.code_token_count = len(hyp.code.split(' '))

                feat_vals = OrderedDict()
                hyp.rerank_feature_values = feat_vals

        for batch_examples in utils.batch_iter(hyp_examples, batch_size=128):
            for feat_name, feat in self.batched_features.items():
                batch_example_scores = feat.score(
                    batch_examples).data.cpu().tolist()
                for i, e in enumerate(batch_examples):
                    setattr(e, feat_name, batch_example_scores[i])

        e_ptr = 0
        for example, hyps in zip(examples, decode_results):
            for hyp in hyps:
                for feat_name, feat in self.batched_features.items():
                    hyp.rerank_feature_values[feat_name] = getattr(
                        hyp_examples[e_ptr], feat_name)
                e_ptr += 1

        for example, hyps in zip(examples, decode_results):
            for hyp_id, hyp in enumerate(hyps):
                for feat_name, feat in self.feat_map.items():
                    if not feat.is_batched:
                        feat_val = feat.get_feat_value(
                            example,
                            hyp,
                            hyp_id=hyp_id,
                            all_hyps=hyps,
                            transition_system=self.transition_system)
                        hyp.rerank_feature_values[feat_name] = feat_val
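The last loop assumes each entry of `self.feat_map` exposes an `is_batched` flag and, for non-batched features, a `get_feat_value(example, hyp, **kwargs)` method. A minimal hypothetical feature following that call signature (the class name and scoring rule are illustrative only, not from the original code):

class HypCodeLengthFeature:
    """Toy non-batched reranking feature: token count of the hypothesis code."""
    is_batched = False

    def get_feat_value(self, example, hyp, hyp_id=None, all_hyps=None,
                       transition_system=None):
        # Only the hypothesis is inspected here; the other keyword
        # arguments mirror the call site in the snippet above.
        return float(len(hyp.code.split()))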
Example #4
import numpy as np
import torch

# batch_iter comes from the surrounding project; a sketch follows this example.


def evaluate_ppl(model, dev_data, batch_size=32):
    """
    Calculate validation perplexity
    """
    was_training = model.training
    model.eval()

    cum_loss = 0.
    total_word = 0.
    with torch.no_grad():
        for sents in batch_iter(dev_data, batch_size):
            loss = -model(sents).sum()
            cum_loss += loss.item()
            target_word = sum(len(s[1:]) for s in sents)
            total_word += target_word

        ppl = np.exp(cum_loss / total_word)

    if was_training:
        model.train()

    return ppl
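Unlike the field-wise variant sketched after Example #1, the `batch_iter` used by `evaluate_ppl` and by the training loop in Example #6 yields whole lists of sentences. A minimal sketch under that assumption (the real helper may sort by length or shuffle each epoch):

import random

def batch_iter(data, batch_size, shuffle=False):
    """Yield successive lists of examples (sentences) from `data`."""
    indices = list(range(len(data)))
    if shuffle:
        random.shuffle(indices)
    for start in range(0, len(data), batch_size):
        yield [data[i] for i in indices[start:start + batch_size]]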
Example #5
from model.doub_enc import DoubEnc
from data.data_prepro import load_json
from model.utils import batch_iter

num_units = 300
lr = 0.001
grad_clip = 5.0
finetune_emb = True
ckpt_path = 'ckpt/'
model_name = 'doub_enc'
embedding_path = 'data_new/glove.840B.300d.filtered.npz'
batch_size = 32
epochs = 5

print("Loading data...")
train_set = load_json('data_new2/train.json')
dev_set = load_json('data_new2/dev.json')
print(len(train_set))
print(type(train_set[0]))
# print(train_set[0])
s1, _, idx, desc, cand, y = batch_iter(train_set, 2)
print(s1)
print(idx)
print(desc)
print(y)
print(cand)
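One caveat: the unpacking on the `batch_iter` line above only works if `model.utils.batch_iter` returns the six fields directly. If it is instead a generator over batches, like the helpers in the other examples, taking a single batch with `next()` would be the safer probe (sketch under that assumption):

# Assumes batch_iter yields one (s1, s2, idx, desc, cand, y) tuple per batch.
s1, _, idx, desc, cand, y = next(batch_iter(train_set, 2))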
Example #6
import math
import sys
import time

import numpy as np
import torch

# read_corpus, Vocab, RNNLM, batch_iter and evaluate_ppl are assumed to come
# from the surrounding project.


def train(args):
    """
    Perform training
    """
    train_data = read_corpus(args.data_dir)
    dev_data = read_corpus(args.valid_dir)
    train_batch_size = args.batch_size
    clip_grad = args.clip_grad
    valid_freq = args.valid_freq
    save_path = args.model_save_path
    vocab = Vocab.load('vocab.json')
    device = torch.device("cuda:0" if args.cuda else "cpu")
    max_patience = 5
    max_num_trial = 5
    learning_rate_decay = 0.5
    max_epoch = 5000
    model_save_path = args.save_path

    model = RNNLM(embed_size=args.embed_size,
                  hidden_size=args.hidden_size,
                  vocab=vocab,
                  dropout_rate=args.dropout,
                  device=device,
                  tie_embed=args.tie_embed)

    model.train()

    # Uniform initialization of all parameters in [-0.1, 0.1]
    for p in model.parameters():
        p.data.uniform_(-0.1, 0.1)

    model.to(device)

    # TODO: make the learning rate a tunable hyperparameter
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = total_word = report_total_word = 0

    cum_examples = report_examples = epoch = valid_num = 0

    hist_valid_scores = []
    train_time = begin_time = time.time()

    print("Begin training")

    while True:
        epoch += 1

        for sent_batch in batch_iter(train_data,
                                     batch_size=train_batch_size,
                                     shuffle=True):
            train_iter += 1

            optimizer.zero_grad()

            batch_size = len(sent_batch)

            example_losses = -model(sent_batch)

            batch_loss = example_losses.sum()
            loss = batch_loss / batch_size

            loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       clip_grad)

            optimizer.step()

            batch_losses_val = batch_loss.item()
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            tgt_word_num_to_predict = sum(len(s[1:]) for s in sent_batch)
            total_word += tgt_word_num_to_predict

            report_total_word += tgt_word_num_to_predict
            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % 10 == 0:
                print('epoch %d, iter %d, avg. loss %.2f, avg. ppl %.2f, '
                      'cum. examples %d, speed %.2f words/sec, '
                      'time elapsed %.2f sec'
                      % (epoch, train_iter,
                         report_loss / report_examples,
                         math.exp(report_loss / report_total_word),
                         cum_examples,
                         report_total_word / (time.time() - train_time),
                         time.time() - begin_time),
                      file=sys.stderr)

                train_time = time.time()
                report_loss = report_total_word = report_examples = 0.

            #VALIDATION
            if train_iter % valid_freq == 0:
                print(
                    "epoch %d, iter %d, cum. loss %.2f, cum. ppl %.2f, cum. examples %d"
                    % (epoch, train_iter, cum_loss / cum_examples,
                       np.exp(cum_loss / total_word), cum_examples),
                    file=sys.stderr)

                cum_loss = cum_examples = total_word = 0

                valid_num += 1

                print("Begin validation", file=sys.stderr)

                dev_ppl = evaluate_ppl(model, dev_data, batch_size=128)
                valid_metric = -dev_ppl

                print("validation: iter %d, dev. ppl %f" %
                      (train_iter, dev_ppl),
                      file=sys.stderr)

                is_better = (len(hist_valid_scores) == 0
                             or valid_metric > max(hist_valid_scores))
                # record the score so later comparisons can see it
                hist_valid_scores.append(valid_metric)

                if is_better:
                    patience = 0
                    print("Save currently best model")
                    model.save(model_save_path)

                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')

                elif patience < max_patience:
                    patience += 1

                    print("hit patience %d" % patience, file=sys.stderr)
                    if patience == max_patience:
                        num_trial += 1

                        if num_trial == max_num_trial:
                            print("early stop!", file=sys.stderr)
                            exit(0)

                        #Learning rate decay
                        lr = (optimizer.param_groups[0]['lr']
                              * learning_rate_decay)

                        #load previous best model
                        params = torch.load(
                            model_save_path,
                            map_location=lambda storage, loc: storage)

                        model.load_state_dict(params['state_dict'])

                        model = model.to(device)

                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim'))

                        #load learning rate
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr

                        patience = 0

                if epoch == max_epoch:
                    print("maximum epoch reached!", file=sys.stderr)
                    exit(0)
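For reference, a hypothetical argparse setup that supplies every attribute `train` reads from `args`; the flag names mirror the attribute accesses above, and the defaults are illustrative only:

import argparse

parser = argparse.ArgumentParser(description='Train the RNN language model')
parser.add_argument('--data_dir', required=True, help='training corpus')
parser.add_argument('--valid_dir', required=True, help='validation corpus')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--clip_grad', type=float, default=5.0)
parser.add_argument('--valid_freq', type=int, default=2000)
parser.add_argument('--model_save_path', default='model.bin')
parser.add_argument('--save_path', default='model.bin')  # train() reads both paths
parser.add_argument('--cuda', action='store_true')
parser.add_argument('--embed_size', type=int, default=256)
parser.add_argument('--hidden_size', type=int, default=512)
parser.add_argument('--dropout', type=float, default=0.3)
parser.add_argument('--tie_embed', action='store_true')
args = parser.parse_args()

train(args)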