Example #1
    def run_epoch(self, sess, train, dev, tags, epoch):
        """
        Performs one complete pass over the train set and evaluates on the dev set
        Args:
            sess: tensorflow session
            train: dataset that yields tuple of sentences, tags
            dev: dataset
            tags: {tag: index} dictionary
            epoch: (int) number of the epoch
        """
        nbatches = (
            len(train) + self.config.batch_size - 1) // self.config.batch_size
        prog = Progbar(target=nbatches)
        for i, (words, labels) in enumerate(
                minibatches(train, self.config.batch_size)):
            fd, _ = self.get_feed_dict(words, labels, self.config.LR,
                                       self.config.dropout)

            _, train_loss, summary = sess.run(
                [self.train_op, self.loss, self.merged], feed_dict=fd)

            prog.update(i + 1, [("train loss", train_loss)])

            # tensorboard
            if i % 10 == 0:
                self.file_writer.add_summary(summary, epoch * nbatches + i)

        acc, f1 = self.run_evaluate(sess, dev, tags)
        self.logger.info(
            "- dev acc {:04.2f} - f1 {:04.2f}".format(100 * acc, 100 * f1))
        return acc, f1
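All of the snippets on this page assume roughly the same progress-bar contract. Below is a minimal sketch of the Keras-style interface used in Example #1 (a constructor taking target and an update(current, values) method); the real Progbar class in each project may add extras such as titles, batch sizes, or running averages, so treat this only as an illustration of the assumed API.

import sys
import time


class Progbar(object):
    """Minimal stand-in for the progress bar used above (illustrative only)."""

    def __init__(self, target):
        self.target = target      # total number of steps (e.g. batches per epoch)
        self.start = time.time()

    def update(self, current, values=None):
        # values is a list of (name, value) pairs shown next to the step counter
        values = values or []
        metrics = ' - '.join('%s: %.4f' % (name, val) for name, val in values)
        elapsed = time.time() - self.start
        sys.stdout.write('\r%d/%d - %.1fs - %s' % (current, self.target, elapsed, metrics))
        if current >= self.target:
            sys.stdout.write('\n')
        sys.stdout.flush()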
def _valid(data_loader, model, criterion, optimizer, epoch, opt, is_train=False):
    progbar = Progbar(title='Validating', target=len(data_loader), batch_size=opt.batch_size,
                      total_examples=len(data_loader.dataset))
    if is_train:
        model.train()
    else:
        model.eval()

    losses = []

    # Note that the data should be shuffled every time
    for i, batch in enumerate(data_loader):
        src = batch.src
        trg = batch.trg

        if torch.cuda.is_available():
            src = src.cuda()
            trg = trg.cuda()

        decoder_probs, _, _ = model.forward(src, trg, must_teacher_forcing=True)

        start_time = time.time()

        loss = criterion(
            decoder_probs.contiguous().view(-1, opt.vocab_size),
            trg[:, 1:].contiguous().view(-1)
        )
        print("--loss calculation --- %s" % (time.time() - start_time))

        start_time = time.time()
        if is_train:
            optimizer.zero_grad()
            loss.backward()
            if opt.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(), opt.max_grad_norm)
            optimizer.step()

        print("--backward function - %s seconds ---" % (time.time() - start_time))

        losses.append(loss.data[0])

        start_time = time.time()
        progbar.update(epoch, i, [('valid_loss', loss.data[0])])
        print("-progbar.update --- %s" % (time.time() - start_time))

    return losses
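The _valid helper above targets an old PyTorch release: loss.data[0], torch.nn.utils.clip_grad_norm and the .cuda() calls all date from that era. As a rough sketch (not the project's code), the same step written against current PyTorch would look like this, assuming the same model, criterion, optimizer and opt objects as in the example:

import torch


def valid_step(model, criterion, optimizer, src, trg, opt, is_train=False):
    """One _valid-style step on modern PyTorch (sketch only)."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    src, trg = src.to(device), trg.to(device)    # .to()/.cuda() are not in-place, so reassign

    decoder_probs, _, _ = model(src, trg, must_teacher_forcing=True)
    loss = criterion(
        decoder_probs.contiguous().view(-1, opt.vocab_size),
        trg[:, 1:].contiguous().view(-1),
    )

    if is_train:
        optimizer.zero_grad()
        loss.backward()
        if opt.max_grad_norm > 0:
            # clip_grad_norm_ (trailing underscore) replaces the deprecated clip_grad_norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
        optimizer.step()

    return loss.item()                           # replaces the old loss.data[0]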
def evaluate_greedy(model, data_loader, test_examples, opt):
    model.eval()

    logging.info('======================  Checking GPU Availability  =========================')
    if torch.cuda.is_available():
        logging.info('Running on GPU!')
        model.cuda()
    else:
        logging.info('Running on CPU!')

    logging.info('======================  Start Predicting  =========================')
    progbar = Progbar(title='Testing', target=len(data_loader), batch_size=data_loader.batch_size,
                      total_examples=len(data_loader.dataset))

    '''
    Note here each batch only contains one data example, thus decoder_probs is flattened
    '''
    for i, (batch, example) in enumerate(zip(data_loader, test_examples)):
        src = batch.src

        logging.info('======================  %d  =========================' % (i + 1))
        logging.info('\nSource text: \n %s\n' % (' '.join([opt.id2word[wi] for wi in src.data.numpy()[0]])))

        if torch.cuda.is_available():
            src = src.cuda()

        # trg = Variable(torch.from_numpy(np.zeros((src.size(0), opt.max_sent_length), dtype='int64')))
        trg = Variable(torch.LongTensor([[opt.word2id[pykp.io.BOS_WORD]] * opt.max_sent_length]))

        max_words_pred = model.greedy_predict(src, trg)
        progbar.update(None, i, [])

        sentence_pred = [opt.id2word[x] for x in max_words_pred]
        sentence_real = example['trg_str']

        if '</s>' in sentence_real:
            index = sentence_real.index('</s>')
            sentence_pred = sentence_pred[:index]

        logging.info('\t\tPredicted : %s ' % (' '.join(sentence_pred)))
        logging.info('\t\tReal : %s ' % (sentence_real))
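evaluate_greedy above likewise uses the deprecated Variable wrapper. A hedged sketch of the same setup on current PyTorch, with inference wrapped in torch.no_grad() so no autograd graph is built (greedy_predict, opt.word2id and opt.max_sent_length are assumed from the example; bos_word stands in for pykp.io.BOS_WORD):

import torch


def greedy_predict_once(model, src, opt, bos_word='<s>'):
    """Greedy decoding for a single example on modern PyTorch (sketch only)."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    src = src.to(device)
    # a row of BOS ids as the initial decoder input; Variable is no longer needed
    trg = torch.full((1, opt.max_sent_length), opt.word2id[bos_word],
                     dtype=torch.long, device=device)
    with torch.no_grad():                        # no gradients needed at test time
        return model.greedy_predict(src, trg)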
def _valid_error(data_loader, model, criterion, epoch, opt):
    progbar = Progbar(title='Validating', target=len(data_loader), batch_size=data_loader.batch_size,
                      total_examples=len(data_loader.dataset))
    model.eval()

    losses = []

    # Note that the data should be shuffled every time
    for i, batch in enumerate(data_loader):
        # if i >= 100:
        #     break

        one2many_batch, one2one_batch = batch
        src, trg, trg_target, trg_copy_target, src_ext, oov_lists = one2one_batch

        if torch.cuda.is_available():
            src                = src.cuda()
            trg                = trg.cuda()
            trg_target         = trg_target.cuda()
            trg_copy_target    = trg_copy_target.cuda()
            src_ext            = src_ext.cuda()

        decoder_log_probs, _, _ = model.forward(src, trg, src_ext)

        if not opt.copy_model:
            loss = criterion(
                decoder_log_probs.contiguous().view(-1, opt.vocab_size),
                trg_target.contiguous().view(-1)
            )
        else:
            loss = criterion(
                decoder_log_probs.contiguous().view(-1, opt.vocab_size + opt.max_unk_words),
                trg_copy_target.contiguous().view(-1)
            )
        losses.append(loss.data[0])

        progbar.update(epoch, i, [('valid_loss', loss.data[0]), ('PPL', loss.data[0])])

    return losses
Example #5
    def train(self, trainset, devset, testset, batch_size=64, epochs=50, shuffle=True):
        self.logger.info('Start training...')
        init_lr = self.cfg.lr  # initial learning rate, used for decay learning rate
        best_score = 0.0  # record the best score
        best_score_epoch = 1  # record the epoch of the best score obtained
        no_imprv_epoch = 0  # no improvement patience counter
        for epoch in range(self.start_epoch, epochs + 1):
            self.logger.info('Epoch %2d/%2d:' % (epoch, epochs))
            progbar = Progbar(target=(len(trainset) + batch_size - 1) // batch_size)  # number of batches
            if shuffle:
                np.random.shuffle(trainset)  # shuffle training dataset each epoch
            # training each epoch
            for i, (words, labels) in enumerate(batch_iter(trainset, batch_size)):
                feed_dict = self._get_feed_dict(words, labels, lr=self.cfg.lr, is_train=True)
                _, train_loss = self.sess.run([self.train_op, self.loss], feed_dict=feed_dict)
                progbar.update(i + 1, [("train loss", train_loss)])
            if devset is not None:
                self.evaluate(devset, batch_size)
            cur_score = self.evaluate(testset, batch_size, is_devset=False)
            # learning rate decay
            if self.cfg.decay_lr:
                self.cfg.lr = init_lr / (1 + self.cfg.lr_decay * epoch)
            # performs model saving and evaluating on test dataset
            if cur_score > best_score:
                no_imprv_epoch = 0
                self.save_session(epoch)
                best_score = cur_score
                best_score_epoch = epoch
                self.logger.info(' -- new BEST score on TEST dataset: {:05.3f}'.format(best_score))
            else:
                no_imprv_epoch += 1
                if no_imprv_epoch >= self.cfg.no_imprv_patience:
                    self.logger.info('early stop at {}th epoch without improvement for {} epochs, BEST score: '
                                     '{:05.3f} at epoch {}'.format(epoch, no_imprv_epoch, best_score, best_score_epoch))
                    break
        self.logger.info('Training process done...')
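The decay_lr branch above applies a simple inverse-time schedule, lr_t = init_lr / (1 + lr_decay * t). A small standalone sketch of that schedule (the numbers are illustrative, not taken from any config):

def inverse_time_decay(init_lr, lr_decay, epoch):
    """The learning-rate schedule used in train() above: lr_t = lr_0 / (1 + decay * t)."""
    return init_lr / (1.0 + lr_decay * epoch)


# e.g. with init_lr=0.01 and lr_decay=0.05 the rate falls off gradually:
# epoch 1 -> ~0.00952, epoch 10 -> ~0.00667, epoch 50 -> ~0.00286
for epoch in (1, 10, 50):
    print(epoch, inverse_time_decay(0.01, 0.05, epoch))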
Example #6
def run_simultrans(model,
                   options_file=None,
                   config=None,
                   policy=None,
                   id=None,
                   remote=False):
    # check environments
    check_env()
    if id is not None:
        fcon = '.config/{}.conf'.format(id)
        if os.path.exists(fcon):
            print('load config files')
            policy, config = pkl.load(open(fcon, 'rb'))

    # ======================================================================= #
    # load model and model_options
    # ======================================================================= #
    _model = model
    model  = '.pretrained/{}'.format(model)

    if options_file is not None:
        with open(options_file, 'rb') as f:
            options = pkl.load(f)
    else:
        with open('%s.pkl' % model, 'rb') as f:
            options = pkl.load(f)
    options['birnn'] = True

    print('load options...')
    for w, p in sorted(options.items(), key=lambda x: x[0]):
        print('{}: {}'.format(w, p))

    # load detail settings from option file:
    dictionary, dictionary_target = options['dictionaries']

    def _iter(fname):
        with open(fname, 'r') as f:
            for line in f:
                words = line.strip().split()
                x = [word_dict[w] if w in word_dict else 1 for w in words]
                x = [ii if ii < options['n_words'] else 1 for ii in x]
                x += [0]
                yield x

    def _check_length(fname):
        f = open(fname, 'r')
        count = 0
        for _ in f:
            count += 1
        f.close()

        return count

    # load source dictionary and invert
    with open(dictionary, 'rb') as f:
        word_dict = pkl.load(f)
    word_idict = dict()
    for kk, vv in word_dict.items():
        word_idict[vv] = kk
    word_idict[0] = '<eos>'
    word_idict[1] = 'UNK'

    # load target dictionary and invert
    with open(dictionary_target, 'rb') as f:
        word_dict_trg = pkl.load(f)
    word_idict_trg = dict()
    for kk, vv in word_dict_trg.items():
        word_idict_trg[vv] = kk
    word_idict_trg[0] = '<eos>'
    word_idict_trg[1] = 'UNK'

    # ======================================================================== #
    # Build a Translator
    # ======================================================================== #

    # allocate model parameters
    params  = init_params(options)
    params  = load_params(model, params)
    tparams = init_tparams(params)

    # print 'build the model for computing cost (full source sentence).'
    trng, use_noise, \
    _x, _x_mask, _y, _y_mask, \
    opt_ret, \
    cost, f_cost = build_model(tparams, options)
    print('done.')

    # functions for sampler
    # f_sim_ctx, f_sim_init, f_sim_next = build_simultaneous_sampler(tparams, options, trng)
    f_sim_ctx, f_sim_init, f_sim_next = build_noisy_sampler(tparams, options, trng)
    print('build sampler done.')

    # check the ID:
    policy['base'] = _model
    _policy        = Policy(trng, options, policy, config,
                            n_out=options['dim'],
                            recurrent=True, id=id)


    # DATASET
    trainIter = TextIterator(options['datasets'][0], options['datasets'][1],
                             options['dictionaries'][0], options['dictionaries'][1],
                             n_words_source=options['n_words_src'], n_words_target=options['n_words'],
                             batch_size=config['batchsize'],
                             maxlen=options['maxlen'])

    train_num = trainIter.num

    validIter = TextIterator(options['valid_datasets'][0], options['valid_datasets'][1],
                             options['dictionaries'][0], options['dictionaries'][1],
                             n_words_source=options['n_words_src'], n_words_target=options['n_words'],
                             batch_size=1,
                             maxlen=options['maxlen'])

    valid_num = validIter.num

    valid_    = options['valid_datasets'][0]
    valid_num = _check_length(valid_)
    print('training set {} lines / validation set {} lines'.format(train_num, valid_num))
    print('use the reward function {}'.format(chr(config['Rtype'] + 65)))

    # Translator model
    def _translate(src, trg, train=False, samples=80):
        ret = noisy_decoding(
            f_sim_ctx, f_sim_init,
            f_sim_next, f_cost,
            src, trg, word_idict_trg, n_samples=samples,
            train=train,
            _policy=_policy)

        if not train:
            sample, score, actions, R, tracks, attentions = ret
            return sample, score, actions, R, tracks
        else:
            sample, score, actions, R, info = ret
            return sample, score, actions, R, info


    # ======================================================================== #
    # Main Loop: Run
    # ======================================================================== #
    print('Start Simultaneous Translator...')
    progbar          = Progbar(train_num // config['batchsize'], with_history=False)

    # freqs
    save_freq        = 2000
    sample_freq      = 10
    valid_freq       = 1000
    valid_size       = 200
    display_freq     = 50

    history, last_it = _policy.load()
    time0            = timer()

    for it, (srcs, trgs) in enumerate(trainIter):  # only one sentence each iteration
        if it < last_it:  # go over the scanned lines.
            continue

        samples, scores, actions, rewards, info = _translate(srcs, trgs, train=True)
        if it % sample_freq == 0:

            print('\nModel has been trained for {} seconds'.format(timer() - time0))
            print('source: ', _bpe2words(_seqs2words([srcs[0]], word_idict))[0])
            print('target: ', _bpe2words(_seqs2words([trgs[0]], word_idict_trg))[0])

            # obtain the translation results
            samples = _bpe2words(_seqs2words(samples, word_idict_trg))

            print('---')
            print('sample: ', samples[40])
            print('sample: ', samples[60])

        values = [(w, info[w]) for w in info]
        progbar.update(it + 1, values=values)

        # NaN detector
        for w in info:
            if numpy.isnan(info[w]) or numpy.isinf(info[w]):
                raise RuntimeError('NaN/INF is detected!! {} : ID={}'.format(w, id))
def train_model(model, optimizer, criterion, train_data_loader, valid_data_loader, test_data_loader, opt):
    generator = SequenceGenerator(model,
                                  eos_id=opt.word2id[pykp.io.EOS_WORD],
                                  beam_size=opt.beam_size,
                                  max_sequence_length=opt.max_sent_length
                                  )

    logging.info('======================  Checking GPU Availability  =========================')
    if torch.cuda.is_available():
        if isinstance(opt.gpuid, int):
            opt.gpuid = [opt.gpuid]
        logging.info('Running on GPU! devices=%s' % str(opt.gpuid))
        # model = nn.DataParallel(model, device_ids=opt.gpuid)
    else:
        logging.info('Running on CPU!')

    logging.info('======================  Start Training  =========================')

    checkpoint_names        = []
    train_history_losses    = []
    valid_history_losses    = []
    test_history_losses     = []
    # best_loss = sys.float_info.max # for normal training/testing loss (likelihood)
    best_loss               = 0.0 # for f-score
    stop_increasing         = 0

    train_losses = []
    total_batch = 0
    early_stop_flag = False

    if opt.train_from:
        state_path = opt.train_from.replace('.model', '.state')
        logging.info('Loading training state from: %s' % state_path)
        if os.path.exists(state_path):
            (epoch, total_batch, best_loss, stop_increasing, checkpoint_names, train_history_losses, valid_history_losses,
                        test_history_losses) = torch.load(open(state_path, 'rb'))
            opt.start_epoch = epoch

    for epoch in range(opt.start_epoch , opt.epochs):
        if early_stop_flag:
            break

        progbar = Progbar(title='Training', target=len(train_data_loader), batch_size=train_data_loader.batch_size,
                          total_examples=len(train_data_loader.dataset))

        for batch_i, batch in enumerate(train_data_loader):
            model.train()
            batch_i += 1 # for the aesthetics of printing
            total_batch += 1
            one2many_batch, one2one_batch = batch
            src, trg, trg_target, trg_copy_target, src_ext, oov_lists = one2one_batch
            max_oov_number = max([len(oov) for oov in oov_lists])

            print("src size - ",src.size())
            print("target size - ",trg.size())

            if torch.cuda.is_available():
                src = src.cuda()
                trg = trg.cuda()
                trg_target = trg_target.cuda()
                trg_copy_target = trg_copy_target.cuda()
                src_ext = src_ext.cuda()

            optimizer.zero_grad()

            '''
            Training with Maximum Likelihood (word-level error)
            '''
            decoder_log_probs, _, _ = model.forward(src, trg, src_ext, oov_lists)

            # simply average the losses of all the predictions
            # IMPORTANT: use logits instead of probs to compute the loss, otherwise training is very slow at the beginning (gradients of probs are small)!
            start_time = time.time()

            if not opt.copy_model:
                ml_loss = criterion(
                    decoder_log_probs.contiguous().view(-1, opt.vocab_size),
                    trg_target.contiguous().view(-1)
                )
            else:
                ml_loss = criterion(
                    decoder_log_probs.contiguous().view(-1, opt.vocab_size + max_oov_number),
                    trg_copy_target.contiguous().view(-1)
                )

            '''
            Training with Reinforcement Learning (instance-level reward f-score)
            '''
            src_list, trg_list, _, trg_copy_target_list, src_oov_map_list, oov_list, src_str_list, trg_str_list = one2many_batch

            if torch.cuda.is_available():
                src_list = src_list.cuda()
                src_oov_map_list = src_oov_map_list.cuda()
            rl_loss = get_loss_rl()

            start_time = time.time()
            ml_loss.backward()
            print("--backward- %s seconds ---" % (time.time() - start_time))

            if opt.max_grad_norm > 0:
                pre_norm = torch.nn.utils.clip_grad_norm(model.parameters(), opt.max_grad_norm)
                after_norm = (sum([p.grad.data.norm(2) ** 2 for p in model.parameters() if p.grad is not None])) ** (1.0 / 2)
                logging.info('clip grad (%f -> %f)' % (pre_norm, after_norm))

            optimizer.step()

            train_losses.append(ml_loss.data[0])

            progbar.update(epoch, batch_i, [('train_loss', ml_loss.data[0]), ('PPL', ml_loss.data[0])])

            if batch_i > 1 and batch_i % opt.report_every == 0:
                logging.info('======================  %d  =========================' % (batch_i))

                logging.info('Epoch : %d Minibatch : %d, Loss=%.5f' % (epoch, batch_i, np.mean(ml_loss.data[0])))
                sampled_size = 2
                logging.info('Printing predictions on %d sampled examples by greedy search' % sampled_size)

                if torch.cuda.is_available():
                    src                 = src.data.cpu().numpy()
                    decoder_log_probs   = decoder_log_probs.data.cpu().numpy()
                    max_words_pred      = decoder_log_probs.argmax(axis=-1)
                    trg_target          = trg_target.data.cpu().numpy()
                    trg_copy_target     = trg_copy_target.data.cpu().numpy()
                else:
                    src                 = src.data.numpy()
                    decoder_log_probs   = decoder_log_probs.data.numpy()
                    max_words_pred      = decoder_log_probs.argmax(axis=-1)
                    trg_target          = trg_target.data.numpy()
                    trg_copy_target     = trg_copy_target.data.numpy()

                sampled_trg_idx     = np.random.random_integers(low=0, high=len(trg) - 1, size=sampled_size)
                src                 = src[sampled_trg_idx]
                oov_lists           = [oov_lists[i] for i in sampled_trg_idx]
                max_words_pred      = [max_words_pred[i] for i in sampled_trg_idx]
                decoder_log_probs   = decoder_log_probs[sampled_trg_idx]
                if not opt.copy_model:
                    trg_target      = [trg_target[i] for i in sampled_trg_idx] # use the real target trg_loss (the starting <BOS> has been removed and contains oov ground-truth)
                else:
                    trg_target      = [trg_copy_target[i] for i in sampled_trg_idx]

                for i, (src_wi, pred_wi, trg_i, oov_i) in enumerate(zip(src, max_words_pred, trg_target, oov_lists)):
                    nll_prob = -np.sum([decoder_log_probs[i][l][pred_wi[l]] for l in range(len(trg_i))])
                    find_copy       = np.any([x >= opt.vocab_size for x in src_wi])
                    has_copy        = np.any([x >= opt.vocab_size for x in trg_i])

                    sentence_source = [opt.id2word[x] if x < opt.vocab_size else oov_i[x-opt.vocab_size] for x in src_wi]
                    sentence_pred   = [opt.id2word[x] if x < opt.vocab_size else oov_i[x-opt.vocab_size] for x in pred_wi]
                    sentence_real   = [opt.id2word[x] if x < opt.vocab_size else oov_i[x-opt.vocab_size] for x in trg_i]

                    sentence_source = sentence_source[:sentence_source.index('<pad>')] if '<pad>' in sentence_source else sentence_source
                    sentence_pred   = sentence_pred[:sentence_pred.index('<pad>')] if '<pad>' in sentence_pred else sentence_pred
                    sentence_real   = sentence_real[:sentence_real.index('<pad>')] if '<pad>' in sentence_real else sentence_real

                    logging.info('==================================================')
                    logging.info('Source: %s '          % (' '.join(sentence_source)))
                    logging.info('\t\tPred : %s (%.4f)' % (' '.join(sentence_pred), nll_prob) + (' [FIND COPY]' if find_copy else ''))
                    logging.info('\t\tReal : %s '       % (' '.join(sentence_real)) + (' [HAS COPY]' + str(trg_i) if has_copy else ''))

            if total_batch > 1 and total_batch % opt.run_valid_every == 0:
                logging.info('*' * 50)
                logging.info('Running validation and testing @Epoch=%d, #(Total batch)=%d' % (epoch, total_batch))
                # valid_losses    = _valid_error(valid_data_loader, model, criterion, epoch, opt)
                # valid_history_losses.append(valid_losses)
                valid_score_dict  = evaluate_beam_search(generator, valid_data_loader, opt, title='valid', epoch=epoch, save_path=opt.exp_path + '/epoch%d_batch%d_total_batch%d' % (epoch, batch_i, total_batch))
                test_score_dict   = evaluate_beam_search(generator, test_data_loader, opt, title='test', epoch=epoch, save_path=opt.exp_path + '/epoch%d_batch%d_total_batch%d' % (epoch, batch_i, total_batch))

                checkpoint_names.append('epoch=%d-batch=%d-total_batch=%d' % (epoch, batch_i, total_batch))
                train_history_losses.append(copy.copy(train_losses))
                valid_history_losses.append(valid_score_dict)
                test_history_losses.append(test_score_dict)
                train_losses = []

                scores = [train_history_losses]
                curve_names = ['Training Error']
                scores += [[result_dict[name] for result_dict in valid_history_losses] for name in opt.report_score_names]
                curve_names += ['Valid-'+name for name in opt.report_score_names]
                scores += [[result_dict[name] for result_dict in test_history_losses] for name in opt.report_score_names]
                curve_names += ['Test-'+name for name in opt.report_score_names]

                scores = [np.asarray(s) for s in scores]
                # Plot the learning curve
                plot_learning_curve(scores=scores,
                                    curve_names=curve_names,
                                    checkpoint_names=checkpoint_names,
                                    title='Training Validation & Test',
                                    save_path=opt.exp_path + '/[epoch=%d,batch=%d,total_batch=%d]train_valid_test_curve.png' % (epoch, batch_i, total_batch))

                '''
                determine whether to early-stop training (now based on whether the f-score increased; previously, on whether the valid error decreased)
                '''
                valid_loss      = np.average(valid_history_losses[-1][opt.report_score_names[0]])
                is_best_loss    = valid_loss > best_loss
                rate_of_change  = float(valid_loss - best_loss) / float(best_loss) if float(best_loss) > 0 else 0.0

                # valid error doesn't increase
                if rate_of_change <= 0:
                    stop_increasing += 1
                else:
                    stop_increasing = 0

                if is_best_loss:
                    logging.info('Validation: update best loss (%.4f --> %.4f), rate of change (ROC)=%.2f' % (
                        best_loss, valid_loss, rate_of_change * 100))
                else:
                    logging.info('Validation: best loss is not updated for %d times (%.4f --> %.4f), rate of change (ROC)=%.2f' % (
                        stop_increasing, best_loss, valid_loss, rate_of_change * 100))

                best_loss = max(valid_loss, best_loss)

                # only store the checkpoints that make better validation performances
                if total_batch > 1 and (total_batch % opt.save_model_every == 0 or is_best_loss): #epoch >= opt.start_checkpoint_at and
                    # Save the checkpoint
                    logging.info('Saving checkpoint to: %s' % os.path.join(opt.save_path, '%s.epoch=%d.batch=%d.total_batch=%d.error=%f' % (opt.exp, epoch, batch_i, total_batch, valid_loss) + '.model'))
                    torch.save(
                        model.state_dict(),
                        open(os.path.join(opt.save_path, '%s.epoch=%d.batch=%d.total_batch=%d' % (opt.exp, epoch, batch_i, total_batch) + '.model'), 'wb')
                    )
                    torch.save(
                        (epoch, total_batch, best_loss, stop_increasing, checkpoint_names, train_history_losses, valid_history_losses, test_history_losses),
                        open(os.path.join(opt.save_path, '%s.epoch=%d.batch=%d.total_batch=%d' % (opt.exp, epoch, batch_i, total_batch) + '.state'), 'wb')
                    )

                if stop_increasing >= opt.early_stop_tolerance:
                    logging.info('No improvement for %d epochs, stopping training early' % stop_increasing)
                    early_stop_flag = True
                    break
                logging.info('*' * 50)
Example #8
def train_model(model, optimizer, criterion, train_data_loader,
                valid_data_loader, test_data_loader, opt):
    generator = SequenceGenerator(model,
                                  eos_id=opt.word2id[pykp.io.EOS_WORD],
                                  beam_size=opt.beam_size,
                                  max_sequence_length=opt.max_sent_length)

    logging.info(
        '======================  Checking GPU Availability  ========================='
    )
    if torch.cuda.is_available():
        if isinstance(opt.gpuid, int):
            opt.gpuid = [opt.gpuid]
        logging.info('Running on GPU! devices=%s' % str(opt.gpuid))
        # model = nn.DataParallel(model, device_ids=opt.gpuid)
    else:
        logging.info('Running on CPU!')

    logging.info(
        '======================  Start Training  =========================')

    checkpoint_names = []
    train_history_losses = []
    valid_history_losses = []
    test_history_losses = []
    # best_loss = sys.float_info.max # for normal training/testing loss (likelihood)
    best_loss = 0.0  # for f-score
    stop_increasing = 0

    train_losses = []
    total_batch = 0
    early_stop_flag = False

    if opt.train_from:
        state_path = opt.train_from.replace('.model', '.state')
        logging.info('Loading training state from: %s' % state_path)
        if os.path.exists(state_path):
            (epoch, total_batch, best_loss, stop_increasing, checkpoint_names,
             train_history_losses, valid_history_losses,
             test_history_losses) = torch.load(open(state_path, 'rb'))
            opt.start_epoch = epoch

    for epoch in range(opt.start_epoch, opt.epochs):
        if early_stop_flag:
            break

        progbar = Progbar(title='Training',
                          target=len(train_data_loader),
                          batch_size=train_data_loader.batch_size,
                          total_examples=len(train_data_loader.dataset))

        for batch_i, batch in enumerate(train_data_loader):
            model.train()
            batch_i += 1  # for the aesthetics of printing
            total_batch += 1
            one2many_batch, one2one_batch = batch
            src, trg, trg_target, trg_copy_target, src_ext, oov_lists = one2one_batch
            max_oov_number = max([len(oov) for oov in oov_lists])

            print("src size - ", src.size())
            print("target size - ", trg.size())

            if torch.cuda.is_available():
                src = src.cuda()
                trg = trg.cuda()
                trg_target = trg_target.cuda()
                trg_copy_target = trg_copy_target.cuda()
                src_ext = src_ext.cuda()

            optimizer.zero_grad()
            '''
            Training with Maximum Likelihood (word-level error)
            '''
            decoder_log_probs, _, _ = model.forward(src, trg, src_ext,
                                                    oov_lists)

            # simply average the losses of all the predictions
            # IMPORTANT: use logits instead of probs to compute the loss, otherwise training is very slow at the beginning (gradients of probs are small)!
            start_time = time.time()

            if not opt.copy_model:
                ml_loss = criterion(
                    decoder_log_probs.contiguous().view(-1, opt.vocab_size),
                    trg_target.contiguous().view(-1))
            else:
                ml_loss = criterion(
                    decoder_log_probs.contiguous().view(
                        -1, opt.vocab_size + max_oov_number),
                    trg_copy_target.contiguous().view(-1))
            '''
            Training with Reinforcement Learning (instance-level reward f-score)
            '''
            src_list, trg_list, _, trg_copy_target_list, src_oov_map_list, oov_list, src_str_list, trg_str_list = one2many_batch

            if torch.cuda.is_available():
                src_list = src_list.cuda()
                src_oov_map_list = src_oov_map_list.cuda()
            rl_loss = get_loss_rl()

            start_time = time.time()
            ml_loss.backward()
            print("--backward- %s seconds ---" % (time.time() - start_time))

            if opt.max_grad_norm > 0:
                pre_norm = torch.nn.utils.clip_grad_norm(
                    model.parameters(), opt.max_grad_norm)
                after_norm = (sum([
                    p.grad.data.norm(2)**2 for p in model.parameters()
                    if p.grad is not None
                ]))**(1.0 / 2)
                logging.info('clip grad (%f -> %f)' % (pre_norm, after_norm))

            optimizer.step()

            train_losses.append(ml_loss.data[0])

            progbar.update(epoch, batch_i, [('train_loss', ml_loss.data[0]),
                                            ('PPL', ml_loss.data[0])])

            if batch_i > 1 and batch_i % opt.report_every == 0:
                logging.info(
                    '======================  %d  =========================' %
                    (batch_i))

                logging.info('Epoch : %d Minibatch : %d, Loss=%.5f' %
                             (epoch, batch_i, np.mean(ml_loss.data[0])))
                sampled_size = 2
                logging.info(
                    'Printing predictions on %d sampled examples by greedy search'
                    % sampled_size)

                if torch.cuda.is_available():
                    src = src.data.cpu().numpy()
                    decoder_log_probs = decoder_log_probs.data.cpu().numpy()
                    max_words_pred = decoder_log_probs.argmax(axis=-1)
                    trg_target = trg_target.data.cpu().numpy()
                    trg_copy_target = trg_copy_target.data.cpu().numpy()
                else:
                    src = src.data.numpy()
                    decoder_log_probs = decoder_log_probs.data.numpy()
                    max_words_pred = decoder_log_probs.argmax(axis=-1)
                    trg_target = trg_target.data.numpy()
                    trg_copy_target = trg_copy_target.data.numpy()

                sampled_trg_idx = np.random.random_integers(low=0,
                                                            high=len(trg) - 1,
                                                            size=sampled_size)
                src = src[sampled_trg_idx]
                oov_lists = [oov_lists[i] for i in sampled_trg_idx]
                max_words_pred = [max_words_pred[i] for i in sampled_trg_idx]
                decoder_log_probs = decoder_log_probs[sampled_trg_idx]
                if not opt.copy_model:
                    trg_target = [
                        trg_target[i] for i in sampled_trg_idx
                    ]  # use the real target trg_loss (the starting <BOS> has been removed and contains oov ground-truth)
                else:
                    trg_target = [trg_copy_target[i] for i in sampled_trg_idx]

                for i, (src_wi, pred_wi, trg_i, oov_i) in enumerate(
                        zip(src, max_words_pred, trg_target, oov_lists)):
                    nll_prob = -np.sum([
                        decoder_log_probs[i][l][pred_wi[l]]
                        for l in range(len(trg_i))
                    ])
                    find_copy = np.any([x >= opt.vocab_size for x in src_wi])
                    has_copy = np.any([x >= opt.vocab_size for x in trg_i])

                    sentence_source = [
                        opt.id2word[x]
                        if x < opt.vocab_size else oov_i[x - opt.vocab_size]
                        for x in src_wi
                    ]
                    sentence_pred = [
                        opt.id2word[x]
                        if x < opt.vocab_size else oov_i[x - opt.vocab_size]
                        for x in pred_wi
                    ]
                    sentence_real = [
                        opt.id2word[x]
                        if x < opt.vocab_size else oov_i[x - opt.vocab_size]
                        for x in trg_i
                    ]

                    sentence_source = sentence_source[:sentence_source.index(
                        '<pad>'
                    )] if '<pad>' in sentence_source else sentence_source
                    sentence_pred = sentence_pred[:sentence_pred.index(
                        '<pad>'
                    )] if '<pad>' in sentence_pred else sentence_pred
                    sentence_real = sentence_real[:sentence_real.index(
                        '<pad>'
                    )] if '<pad>' in sentence_real else sentence_real

                    logging.info(
                        '==================================================')
                    logging.info('Source: %s ' % (' '.join(sentence_source)))
                    logging.info('\t\tPred : %s (%.4f)' %
                                 (' '.join(sentence_pred), nll_prob) +
                                 (' [FIND COPY]' if find_copy else ''))
                    logging.info('\t\tReal : %s ' % (' '.join(sentence_real)) +
                                 (' [HAS COPY]' +
                                  str(trg_i) if has_copy else ''))

            if total_batch > 1 and total_batch % opt.run_valid_every == 0:
                logging.info('*' * 50)
                logging.info(
                    'Running validation and testing @Epoch=%d, #(Total batch)=%d' %
                    (epoch, total_batch))
                # valid_losses    = _valid_error(valid_data_loader, model, criterion, epoch, opt)
                # valid_history_losses.append(valid_losses)
                valid_score_dict = evaluate_beam_search(
                    generator,
                    valid_data_loader,
                    opt,
                    title='valid',
                    epoch=epoch,
                    save_path=opt.exp_path + '/epoch%d_batch%d_total_batch%d' %
                    (epoch, batch_i, total_batch))
                test_score_dict = evaluate_beam_search(
                    generator,
                    test_data_loader,
                    opt,
                    title='test',
                    epoch=epoch,
                    save_path=opt.exp_path + '/epoch%d_batch%d_total_batch%d' %
                    (epoch, batch_i, total_batch))

                checkpoint_names.append('epoch=%d-batch=%d-total_batch=%d' %
                                        (epoch, batch_i, total_batch))
                train_history_losses.append(copy.copy(train_losses))
                valid_history_losses.append(valid_score_dict)
                test_history_losses.append(test_score_dict)
                train_losses = []

                scores = [train_history_losses]
                curve_names = ['Training Error']
                scores += [[
                    result_dict[name] for result_dict in valid_history_losses
                ] for name in opt.report_score_names]
                curve_names += [
                    'Valid-' + name for name in opt.report_score_names
                ]
                scores += [[
                    result_dict[name] for result_dict in test_history_losses
                ] for name in opt.report_score_names]
                curve_names += [
                    'Test-' + name for name in opt.report_score_names
                ]

                scores = [np.asarray(s) for s in scores]
                # Plot the learning curve
                plot_learning_curve(
                    scores=scores,
                    curve_names=curve_names,
                    checkpoint_names=checkpoint_names,
                    title='Training Validation & Test',
                    save_path=opt.exp_path +
                    '/[epoch=%d,batch=%d,total_batch=%d]train_valid_test_curve.png'
                    % (epoch, batch_i, total_batch))
                '''
                determine whether to early-stop training (now based on whether the f-score increased; previously, on whether the valid error decreased)
                '''
                valid_loss = np.average(
                    valid_history_losses[-1][opt.report_score_names[0]])
                is_best_loss = valid_loss > best_loss
                rate_of_change = float(valid_loss - best_loss) / float(
                    best_loss) if float(best_loss) > 0 else 0.0

                # valid error doesn't increase
                if rate_of_change <= 0:
                    stop_increasing += 1
                else:
                    stop_increasing = 0

                if is_best_loss:
                    logging.info(
                        'Validation: update best loss (%.4f --> %.4f), rate of change (ROC)=%.2f'
                        % (best_loss, valid_loss, rate_of_change * 100))
                else:
                    logging.info(
                        'Validation: best loss is not updated for %d times (%.4f --> %.4f), rate of change (ROC)=%.2f'
                        % (stop_increasing, best_loss, valid_loss,
                           rate_of_change * 100))

                best_loss = max(valid_loss, best_loss)

                # only store the checkpoints that make better validation performances
                if total_batch > 1 and (
                        total_batch % opt.save_model_every == 0 or
                        is_best_loss):  #epoch >= opt.start_checkpoint_at and
                    # Save the checkpoint
                    logging.info('Saving checkpoint to: %s' % os.path.join(
                        opt.save_path,
                        '%s.epoch=%d.batch=%d.total_batch=%d.error=%f' %
                        (opt.exp, epoch, batch_i, total_batch, valid_loss) +
                        '.model'))
                    torch.save(
                        model.state_dict(),
                        open(
                            os.path.join(
                                opt.save_path,
                                '%s.epoch=%d.batch=%d.total_batch=%d' %
                                (opt.exp, epoch, batch_i, total_batch) +
                                '.model'), 'wb'))
                    torch.save((epoch, total_batch, best_loss, stop_increasing,
                                checkpoint_names, train_history_losses,
                                valid_history_losses, test_history_losses),
                               open(
                                   os.path.join(
                                       opt.save_path,
                                       '%s.epoch=%d.batch=%d.total_batch=%d' %
                                       (opt.exp, epoch, batch_i, total_batch) +
                                       '.state'), 'wb'))

                if stop_increasing >= opt.early_stop_tolerance:
                    logging.info(
                        'No improvement for %d epochs, stopping training early'
                        % stop_increasing)
                    early_stop_flag = True
                    break
                logging.info('*' * 50)
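Both train_model variants persist two files per checkpoint: a .model file holding the weights and a .state file holding a bookkeeping tuple, each written through an explicitly opened file object. A minimal sketch of the same pattern with torch.save/torch.load taking paths directly and a dict instead of a positional tuple (the names here are illustrative, not the project's):

import torch


def save_checkpoint(model, path_prefix, epoch, total_batch, best_loss, histories):
    # weights and training state live in two separate files, as in the example above
    torch.save(model.state_dict(), path_prefix + '.model')
    torch.save({'epoch': epoch,
                'total_batch': total_batch,
                'best_loss': best_loss,
                'histories': histories},
               path_prefix + '.state')


def load_checkpoint(model, path_prefix):
    model.load_state_dict(torch.load(path_prefix + '.model'))
    return torch.load(path_prefix + '.state')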
def evaluate_beam_search(generator,
                         data_loader,
                         opt,
                         title='',
                         epoch=1,
                         predict_save_path=None):
    logger = config.init_logging(title,
                                 predict_save_path + '/%s.log' % title,
                                 redirect_to_stdout=False)
    progbar = Progbar(logger=logger,
                      title=title,
                      target=len(data_loader.dataset.examples),
                      batch_size=data_loader.batch_size,
                      total_examples=len(data_loader.dataset.examples))

    topk_range = [5, 10]
    score_names = ['precision', 'recall', 'f_score']

    example_idx = 0
    score_dict = {
    }  # {'precision@5':[],'recall@5':[],'f1score@5':[], 'precision@10':[],'recall@10':[],'f1score@10':[]}

    for i, batch in enumerate(data_loader):
        # if i > 5:
        #     break

        one2many_batch, one2one_batch = batch
        src_list, src_len, trg_list, _, trg_copy_target_list, src_oov_map_list, oov_list, src_str_list, trg_str_list = one2many_batch

        if torch.cuda.is_available():
            src_list = src_list.cuda()
            src_oov_map_list = src_oov_map_list.cuda()

        print("batch size - %s" % str(src_list.size(0)))
        print("src size - %s" % str(src_list.size()))
        print("target size - %s" % len(trg_copy_target_list))

        pred_seq_list = generator.beam_search(src_list, src_len,
                                              src_oov_map_list, oov_list,
                                              opt.word2id)
        '''
        process each example in current batch
        '''
        for src, src_str, trg, trg_str_seqs, trg_copy, pred_seq, oov in zip(
                src_list, src_str_list, trg_list, trg_str_list,
                trg_copy_target_list, pred_seq_list, oov_list):
            logger.info(
                '======================  %d =========================' % (i))
            print_out = ''
            print_out += '[Source][%d]: %s \n' % (len(src_str),
                                                  ' '.join(src_str))
            src = src.cpu().data.numpy() if torch.cuda.is_available(
            ) else src.data.numpy()
            print_out += '\nSource Input: \n %s\n' % (' '.join(
                [opt.id2word[x] for x in src[:len(src_str) + 5]]))
            print_out += 'Real Target String [%d] \n\t\t%s \n' % (
                len(trg_str_seqs), trg_str_seqs)
            print_out += 'Real Target Input:  \n\t\t%s \n' % str(
                [[opt.id2word[x] for x in t] for t in trg])
            print_out += 'Real Target Copy:   \n\t\t%s \n' % str([[
                opt.id2word[x] if x < opt.vocab_size else oov[x -
                                                              opt.vocab_size]
                for x in t
            ] for t in trg_copy])
            trg_str_is_present_flags, _ = if_present_duplicate_phrases(
                src_str, trg_str_seqs)

            # ignore cases where there are no present phrases
            if opt.must_appear_in_src and np.sum(
                    trg_str_is_present_flags) == 0:
                logger.error('found no present targets')
                continue

            print_out += '[GROUND-TRUTH] #(present)/#(all targets)=%d/%d\n' % (
                sum(trg_str_is_present_flags), len(trg_str_is_present_flags))
            print_out += '\n'.join([
                '\t\t[%s]' % ' '.join(phrase) if is_present else '\t\t%s' %
                ' '.join(phrase) for phrase, is_present in zip(
                    trg_str_seqs, trg_str_is_present_flags)
            ])
            print_out += '\noov_list:   \n\t\t%s \n' % str(oov)

            # 1st filtering
            pred_is_valid_flags, processed_pred_seqs, processed_pred_str_seqs, processed_pred_score = process_predseqs(
                pred_seq, oov, opt.id2word, opt)
            # 2nd filtering: filter out phrases that don't appear in the text, and keep unique ones after stemming
            if opt.must_appear_in_src:
                pred_is_present_flags, _ = if_present_duplicate_phrases(
                    src_str, processed_pred_str_seqs)
                filtered_trg_str_seqs = np.asarray(
                    trg_str_seqs)[trg_str_is_present_flags]
            else:
                pred_is_present_flags = [True] * len(processed_pred_str_seqs)
                filtered_trg_str_seqs = np.asarray(trg_str_seqs)

            valid_and_present = np.asarray(pred_is_valid_flags) * np.asarray(
                pred_is_present_flags)
            match_list = get_match_result(true_seqs=filtered_trg_str_seqs,
                                          pred_seqs=processed_pred_str_seqs)
            print_out += '[PREDICTION] #(valid)=%d, #(present)=%d, #(retained&present)=%d, #(all)=%d\n' % (
                sum(pred_is_valid_flags), sum(pred_is_present_flags),
                sum(valid_and_present), len(pred_seq))
            print_out += ''
            '''
            Print and export predictions
            '''
            preds_out = ''
            for p_id, (seq, word, score, match, is_valid,
                       is_present) in enumerate(
                           zip(processed_pred_seqs, processed_pred_str_seqs,
                               processed_pred_score, match_list,
                               pred_is_valid_flags, pred_is_present_flags)):
                # if p_id > 5:
                #     break

                preds_out += '%s\n' % (' '.join(word))
                if is_present:
                    print_phrase = '[%s]' % ' '.join(word)
                else:
                    print_phrase = ' '.join(word)

                if is_valid:
                    print_phrase = '*%s' % print_phrase

                if match == 1.0:
                    correct_str = '[correct!]'
                else:
                    correct_str = ''
                if any([t >= opt.vocab_size for t in seq.sentence]):
                    copy_str = '[copied!]'
                else:
                    copy_str = ''

                print_out += '\t\t[%.4f]\t%s \t %s %s%s\n' % (
                    -score, print_phrase, str(
                        seq.sentence), correct_str, copy_str)
            '''
            Evaluate predictions w.r.t different filterings and metrics
            '''
            processed_pred_seqs = np.asarray(
                processed_pred_seqs)[valid_and_present]
            filtered_processed_pred_str_seqs = np.asarray(
                processed_pred_str_seqs)[valid_and_present]
            filtered_processed_pred_score = np.asarray(
                processed_pred_score)[valid_and_present]

            # 3rd round filtering (one-word phrases)
            num_oneword_seq = -1
            filtered_pred_seq, filtered_pred_str_seqs, filtered_pred_score = post_process_predseqs(
                (processed_pred_seqs, filtered_processed_pred_str_seqs,
                 filtered_processed_pred_score), num_oneword_seq)

            match_list_exact = get_match_result(
                true_seqs=filtered_trg_str_seqs,
                pred_seqs=filtered_pred_str_seqs,
                type='exact')
            match_list_soft = get_match_result(
                true_seqs=filtered_trg_str_seqs,
                pred_seqs=filtered_pred_str_seqs,
                type='partial')

            assert len(filtered_pred_seq) == len(
                filtered_pred_str_seqs) == len(filtered_pred_score) == len(
                    match_list_exact) == len(match_list_soft)

            print_out += "\n ======================================================="
            print_pred_str_seqs = [
                " ".join(item) for item in filtered_pred_str_seqs
            ]
            print_trg_str_seqs = [
                " ".join(item) for item in filtered_trg_str_seqs
            ]
            # print_out += "\n PREDICTION: " + " / ".join(print_pred_str_seqs)
            # print_out += "\n GROUND TRUTH: " + " / ".join(print_trg_str_seqs)

            for topk in topk_range:
                results_exact = evaluate(match_list_exact,
                                         filtered_pred_str_seqs,
                                         filtered_trg_str_seqs,
                                         topk=topk)
                for k, v in zip(score_names, results_exact):
                    if '%s@%d_exact' % (k, topk) not in score_dict:
                        score_dict['%s@%d_exact' % (k, topk)] = []
                    score_dict['%s@%d_exact' % (k, topk)].append(v)

                print_out += "\n ------------------------------------------------- EXACT, k=%d" % (
                    topk)
                print_out += "\n --- batch precision, recall, fscore: " + str(
                    results_exact[0]) + " , " + str(
                        results_exact[1]) + " , " + str(results_exact[2])
                print_out += "\n --- total precision, recall, fscore: " + str(np.average(score_dict['precision@%d_exact' % (topk)])) + " , " +\
                            str(np.average(score_dict['recall@%d_exact' % (topk)])) + " , " +\
                            str(np.average(score_dict['f_score@%d_exact' % (topk)]))

            for topk in topk_range:
                results_soft = evaluate(match_list_soft,
                                        filtered_pred_str_seqs,
                                        filtered_trg_str_seqs,
                                        topk=topk)
                for k, v in zip(score_names, results_soft):
                    if '%s@%d_soft' % (k, topk) not in score_dict:
                        score_dict['%s@%d_soft' % (k, topk)] = []
                    score_dict['%s@%d_soft' % (k, topk)].append(v)

                print_out += "\n ------------------------------------------------- SOFT, k=%d" % (
                    topk)
                print_out += "\n --- batch precision, recall, fscore: " + str(
                    results_soft[0]) + " , " + str(
                        results_soft[1]) + " , " + str(results_soft[2])
                print_out += "\n --- total precision, recall, fscore: " + str(np.average(score_dict['precision@%d_soft' % (topk)])) + " , " +\
                            str(np.average(score_dict['recall@%d_soft' % (topk)])) + " , " +\
                            str(np.average(score_dict['f_score@%d_soft' % (topk)]))

            print_out += "\n ======================================================="
            logger.info(print_out)
            '''
            write predictions to disk
            '''
            if predict_save_path:
                if not os.path.exists(
                        os.path.join(predict_save_path, title + '_detail')):
                    os.makedirs(
                        os.path.join(predict_save_path, title + '_detail'))
                with open(
                        os.path.join(predict_save_path, title + '_detail',
                                     str(example_idx) + '_print.txt'),
                        'w') as f_:
                    f_.write(print_out)
                with open(
                        os.path.join(predict_save_path, title + '_detail',
                                     str(example_idx) + '_prediction.txt'),
                        'w') as f_:
                    f_.write(preds_out)

                out_dict = {}
                out_dict['src_str'] = src_str
                out_dict['trg_str'] = trg_str_seqs
                out_dict['trg_present_flag'] = trg_str_is_present_flags
                out_dict['pred_str'] = processed_pred_str_seqs
                out_dict['pred_score'] = [
                    float(s) for s in processed_pred_score
                ]
                out_dict['present_flag'] = pred_is_present_flags
                out_dict['valid_flag'] = pred_is_valid_flags
                out_dict['match_flag'] = [float(m) for m in match_list]

                for k, v in out_dict.items():
                    out_dict[k] = list(v)
                    # print('len(%s) = %d' % (k, len(v)))

                # print(out_dict)

                assert len(out_dict['trg_str']) == len(
                    out_dict['trg_present_flag'])
                assert len(out_dict['pred_str']) == len(out_dict['present_flag']) \
                       == len(out_dict['valid_flag']) == len(out_dict['match_flag']) == len(out_dict['pred_score'])

                with open(
                        os.path.join(predict_save_path, title + '_detail',
                                     str(example_idx) + '.json'), 'w') as f_:
                    f_.write(json.dumps(out_dict))

            progbar.update(epoch, example_idx, [
                ('f_score@5_exact', np.average(score_dict['f_score@5_exact'])),
                ('f_score@5_soft', np.average(score_dict['f_score@5_soft'])),
                ('f_score@10_exact', np.average(
                    score_dict['f_score@10_exact'])),
                ('f_score@10_soft', np.average(score_dict['f_score@10_soft'])),
            ])

            example_idx += 1

    # print('#(f_score@5#oneword=-1)=%d, sum=%f' % (len(score_dict['f_score@5#oneword=-1']), sum(score_dict['f_score@5#oneword=-1'])))
    # print('#(f_score@10#oneword=-1)=%d, sum=%f' % (len(score_dict['f_score@10#oneword=-1']), sum(score_dict['f_score@10#oneword=-1'])))
    # print('#(f_score@5#oneword=1)=%d, sum=%f' % (len(score_dict['f_score@5#oneword=1']), sum(score_dict['f_score@5#oneword=1'])))
    # print('#(f_score@10#oneword=1)=%d, sum=%f' % (len(score_dict['f_score@10#oneword=1']), sum(score_dict['f_score@10#oneword=1'])))

    if predict_save_path:
        # export scores: each row holds the precision, recall and f-score for one way of filtering predictions (how many one-word predictions to keep)
        with open(predict_save_path + os.path.sep + title + '_result.csv',
                  'w') as result_csv:
            csv_lines = []
            for mode in ["exact", "soft"]:
                for topk in topk_range:
                    csv_line = ""
                    for k in score_names:
                        csv_line += ',%f' % np.average(
                            score_dict['%s@%d_%s' % (k, topk, mode)])
                    csv_lines.append(csv_line + '\n')

            result_csv.writelines(csv_lines)

    # precision, recall, f_score = macro_averaged_score(precisionlist=score_dict['precision'], recalllist=score_dict['recall'])
    # logging.info("Macro@5\n\t\tprecision %.4f\n\t\tmacro recall %.4f\n\t\tmacro fscore %.4f " % (np.average(score_dict['precision@5']), np.average(score_dict['recall@5']), np.average(score_dict['f1score@5'])))
    # logging.info("Macro@10\n\t\tprecision %.4f\n\t\tmacro recall %.4f\n\t\tmacro fscore %.4f " % (np.average(score_dict['precision@10']), np.average(score_dict['recall@10']), np.average(score_dict['f1score@10'])))
    # precision, recall, f_score = evaluate(true_seqs=target_all, pred_seqs=prediction_all, topn=5)
    # logging.info("micro precision %.4f , micro recall %.4f, micro fscore %.4f " % (precision, recall, f_score))

    for k, v in score_dict.items():
        print('#(%s) = %d' % (k, len(v)))

    return score_dict
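
The CSV export in the snippet above writes one row per (mode, topk) pair, each row holding the averaged metrics pulled out of score_dict. Below is a minimal, self-contained sketch of that row-assembly logic; the score_names, topk_range and score_dict contents are illustrative stand-ins, not the values produced by the real evaluation.

import numpy as np

# Illustrative stand-ins for the evaluation's score_names / topk_range,
# plus a fake score_dict with random per-example scores.
score_names = ['precision', 'recall', 'f_score']
topk_range = [5, 10]
score_dict = {'%s@%d_%s' % (name, k, mode): np.random.rand(100).tolist()
              for name in score_names
              for k in topk_range
              for mode in ['exact', 'soft']}

csv_lines = []
for mode in ['exact', 'soft']:
    for topk in topk_range:
        # one row per (mode, topk): averaged precision, recall and f-score
        csv_line = ''
        for k in score_names:
            csv_line += ',%f' % np.average(score_dict['%s@%d_%s' % (k, topk, mode)])
        csv_lines.append(csv_line + '\n')

print(''.join(csv_lines))
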
Exemple #10
0
    def run_epoch(self, sess, src_train, src_dev, tags, target_train,
                  target_dev, n_epoch_noimprove):
        nbatches = (len(target_train) + self.target_batch_size -
                    1) // self.target_batch_size
        prog = Progbar(target=nbatches)
        total_loss = 0

        src = minibatches(src_train, self.src_batch_size, circle=True)
        target = minibatches(target_train, self.target_batch_size, circle=True)

        for i in range(nbatches):
            src_words, src_tags, _ = next(src)
            target_words, target_tags, _ = next(target)
            labels = src_tags + target_tags

            feed_dict, _ = self.get_feed_dict(src_words,
                                              labels,
                                              target_words,
                                              self.args.learning_rate,
                                              self.args.dropout,
                                              self.src_batch_size,
                                              is_training=True)

            if self.args.penalty_ratio > 0:
                _, src_crf_loss, target_crf_loss, penalty_loss, loss = sess.run(
                    [
                        self.train_op, self.src_crf_loss, self.target_crf_loss,
                        self.penalty_loss, self.loss
                    ],
                    feed_dict=feed_dict)
                try:
                    prog.update(
                        i + 1,
                        [("train loss", loss[0]), ("src crf", src_crf_loss),
                         ("target crf", target_crf_loss),
                         ("{} loss".format(self.args.penalty), penalty_loss)])
                except:
                    prog.update(
                        i + 1,
                        [("train loss", loss), ("src crf", src_crf_loss),
                         ("target crf", target_crf_loss),
                         ("{} loss".format(self.args.penalty), penalty_loss)])
            else:
                _, src_crf_loss, target_crf_loss, loss = sess.run(
                    [
                        self.train_op, self.src_crf_loss, self.target_crf_loss,
                        self.loss
                    ],
                    feed_dict=feed_dict)
                try:
                    prog.update(i + 1, [("train loss", loss[0]),
                                        ("src crf", src_crf_loss),
                                        ("target crf", target_crf_loss)])
                except:
                    prog.update(i + 1, [("train loss", loss),
                                        ("src crf", src_crf_loss),
                                        ("target crf", target_crf_loss)])
            total_loss += loss

        self.info['loss'] += [total_loss / nbatches]
        acc, p, r, f1 = self.run_evaluate(sess,
                                          target_train,
                                          tags,
                                          target='target')
        self.info['dev'].append((acc, p, r, f1))
        self.logger.critical(
            "target train acc {:04.2f}  f1  {:04.2f}  p {:04.2f}  r  {:04.2f}".
            format(100 * acc, 100 * f1, 100 * p, 100 * r))
        acc, p, r, f1 = self.run_evaluate(sess,
                                          target_dev,
                                          tags,
                                          target='target')
        self.info['dev'].append((acc, p, r, f1))
        self.logger.info(
            "dev acc {:04.2f}  f1  {:04.2f}  p {:04.2f}  r  {:04.2f}".format(
                100 * acc, 100 * f1, 100 * p, 100 * r))
        return acc, p, r, f1
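
This example draws source and target minibatches with circle=True so that the shorter dataset wraps around while the longer one is consumed. The minibatches helper itself is not shown, so the following is only a sketch of what such a circular generator could look like, assuming it simply wraps indices modulo the dataset size.

# Hypothetical circular minibatch generator, assuming behaviour similar to
# minibatches(data, batch_size, circle=True): iteration never stops and
# indices wrap around once the data is exhausted, so next() is always valid.
def circular_minibatches(data, batch_size):
    i, n = 0, len(data)
    while True:
        yield [data[(i + j) % n] for j in range(batch_size)]
        i = (i + batch_size) % n

gen = circular_minibatches(list(range(10)), 4)
print(next(gen))  # [0, 1, 2, 3]
print(next(gen))  # [4, 5, 6, 7]
print(next(gen))  # [8, 9, 0, 1]  <- wraps around
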
Exemple #11
0
def run(
    kn=lambda x: 1.0,
    sigma=lambda v, w: 1.0,
    G=lambda x: 0.0,
    xmin=0.0,
    xmax=1.0,
    nx=40,
    dt=0.01,
    nt=1000,
    coll="linear",
    scheme="Euler",
    BC="periodic",
    f_l=lambda v: 0.0,
    f_r=lambda v: 0.0,
    init_func=lambda vmesh, u, T, rho: 0.0,
):
    # Load config
    config = collision.utils.CollisionConfig.from_json(
        "./linear_boltz/configs/" + "linear" + ".json"
    )

    # Collision
    vmesh = collision.CartesianMesh(config)
    if coll == "linear":
        coll_op = collision.LinearBotlzmannCollision(config, vmesh, sigma=sigma)
    elif coll == "rbm":
        coll_op = collision.RandomBatchLinearBoltzmannCollision(
            config, vmesh, sigma=sigma
        )
    elif coll == "rbm_symm":
        coll_op = collision.SymmetricRBMLinearCollision(config, vmesh, sigma=sigma)
    else:
        raise NotImplementedError(
            "Collision method {} is not implemented.".format(coll)
        )

    # x domain
    x = pykinetic.Dimension(xmin, xmax, nx, name="x")
    domain = pykinetic.Domain([x])

    # Riemann solver
    rp = pykinetic.riemann.advection_1D
    solver = DiffusiveRegimeSolver1D(
        rp,
        [coll_op],
        kn=kn(x.centers),
        G=G(x.centers),
    )
    solver.order = 1
    # solver.lim_type = 2
    # Time integrator
    if "RK" in scheme:
        solver.time_integrator = "RK"
        solver.a = rkcoeff[scheme]["a"]
        solver.b = rkcoeff[scheme]["b"]
        solver.c = rkcoeff[scheme]["c"]
    else:
        solver.time_integrator = scheme
    solver.dt = dt
    print("dt is {}".format(solver.dt))

    # Boundary condition
    def dirichlet_lower_BC(state, dim, t, qbc, auxbc, num_ghost):
        v = state.problem_data["v"][0]
        for i in range(num_ghost):
            qbc[0, i, v > 0] = f_l(v[v > 0])

    def dirichlet_upper_BC(state, dim, t, qbc, auxbc, num_ghost):
        v = state.problem_data["v"][0]
        for i in range(num_ghost):
            qbc[0, -i - 1, v < 0] = f_r(v[v < 0])

    if BC == "periodic":
        solver.bc_lower[0] = pykinetic.BC.periodic
        solver.bc_upper[0] = pykinetic.BC.periodic
    elif BC == "dirichlet":
        solver.bc_lower[0] = pykinetic.BC.custom
        solver.bc_upper[0] = pykinetic.BC.custom
        solver.user_bc_lower = dirichlet_lower_BC
        solver.user_bc_upper = dirichlet_upper_BC
    else:
        raise ValueError("Given BC type is not available!")

    state = pykinetic.State(domain, vmesh, 1)
    state.problem_data["v"] = vmesh.centers
    qinit(state, vmesh, init_func)
    sol = pykinetic.Solution(state, domain)

    output_dict = {}
    sol_frames, macro_frames, ts = (
        [copy.deepcopy(sol)],
        [compute_rho(sol.state, vmesh)],
        [0.0],
    )
    pbar = Progbar(nt)
    for t in range(nt):
        solver.evolve_to_time(sol)
        if (t + 1) % 1 == 0:
            sol_frames.append(copy.deepcopy(sol))
            macro_frames.append(compute_rho(sol.state, vmesh))
            ts.append(sol.t)
        pbar.update(t + 1, finalize=False)
    pbar.update(nt, finalize=True)

    output_dict["macro_frames"] = np.asarray(macro_frames)
    output_dict["x"] = np.asarray(x.centers)
    output_dict["t"] = np.asarray(ts)

    return output_dict
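
The time loop above snapshots the solution and the macroscopic density every step while driving a progress bar that is finalized only once, after the last step. A self-contained sketch of that pattern follows; the Progbar class and the "solver" update below are hypothetical stand-ins for the project's utilities and for pykinetic.

import copy
import numpy as np

class Progbar:
    """Hypothetical stand-in exposing the same update(current, finalize=...) call."""
    def __init__(self, target):
        self.target = target
    def update(self, current, finalize=True):
        print('%d/%d' % (current, self.target), end='\n' if finalize else '\r')

nt, dt = 5, 0.01
sol = {'q': np.zeros(4), 't': 0.0}                 # stand-in for pykinetic.Solution
sol_frames, macro_frames, ts = [copy.deepcopy(sol)], [float(sol['q'].sum())], [0.0]

pbar = Progbar(nt)
for t in range(nt):
    sol['q'] += dt                                 # stand-in for solver.evolve_to_time(sol)
    sol['t'] += dt
    sol_frames.append(copy.deepcopy(sol))          # keep a snapshot of the solution
    macro_frames.append(float(sol['q'].sum()))     # stand-in for compute_rho(sol.state, vmesh)
    ts.append(sol['t'])
    pbar.update(t + 1, finalize=False)             # refresh the bar in place
pbar.update(nt, finalize=True)                     # finalize once, after the last step
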
Exemple #12
0
    def train(self,
              trainset,
              devset,
              testset,
              batch_size=64,
              epochs=50,
              shuffle=True):
        '''
        :param trainset: the full training data; splitting into minibatches is done inside this function
        :param devset: development dataset
        :param testset: test dataset
        :param batch_size: number of examples per minibatch
        :param epochs: maximum number of training epochs
        :param shuffle: whether to shuffle the training data at each epoch
        :return:
        '''
        self.logger.info('Start training...')
        init_lr = self.cfg.lr  # initial learning rate, used for decay learning rate
        best_score = 0.0  # record the best score
        best_score_epoch = 1  # record the epoch of the best score obtained
        no_imprv_epoch = 0  # no improvement patience counter
        for epoch in range(self.start_epoch, epochs + 1):
            self.logger.info('Epoch %2d/%2d:' % (epoch, epochs))

            progbar = Progbar(target=(len(trainset) + batch_size - 1) //
                              batch_size)  # number of batches
            if shuffle:
                np.random.shuffle(
                    trainset)  # shuffle training dataset each epoch
            # training each epoch
            for i, (words,
                    labels) in enumerate(batch_iter(trainset, batch_size)):
                feed_dict = self._get_feed_dict(words,
                                                labels,
                                                lr=self.cfg.lr,
                                                is_train=True)
                _, train_loss = self.sess.run([self.train_op, self.loss],
                                              feed_dict=feed_dict)
                progbar.update(i + 1, [("train loss", train_loss)])
            if devset is not None:
                self.evaluate(devset, batch_size)
            cur_score = self.evaluate(testset, batch_size, is_devset=False)
            # learning rate decay
            if self.cfg.decay_lr:
                self.cfg.lr = init_lr / (1 + self.cfg.lr_decay * epoch)
            # performs model saving and evaluating on test dataset
            if cur_score > best_score:
                no_imprv_epoch = 0
                self.save_session(epoch)
                best_score = cur_score
                best_score_epoch = epoch
                self.logger.info(
                    ' -- new BEST score on TEST dataset: {:05.3f}'.format(
                        best_score))
            else:
                no_imprv_epoch += 1
                if no_imprv_epoch >= self.cfg.no_imprv_patience:
                    self.logger.info(
                        'early stop at {}th epoch without improvement for {} epochs, BEST score: '
                        '{:05.3f} at epoch {}'.format(epoch, no_imprv_epoch,
                                                      best_score,
                                                      best_score_epoch))
                    break
        self.logger.info('Training process done...')
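
The learning-rate decay in this example is an inverse-time schedule, lr = init_lr / (1 + lr_decay * epoch), applied once per epoch. A tiny sketch with illustrative init_lr and lr_decay values (not the config defaults):

# Inverse-time decay as used above: lr = init_lr / (1 + lr_decay * epoch).
init_lr, lr_decay = 0.01, 0.05
for epoch in range(1, 6):
    lr = init_lr / (1 + lr_decay * epoch)
    print('epoch %d: lr = %.5f' % (epoch, lr))
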
Exemple #13
0
def train_model(model, optimizer_ml, optimizer_rl, criterion,
                train_data_loader, valid_data_loader, test_data_loader, opt):
    generator = SequenceGenerator(model,
                                  eos_id=opt.word2id[pykp.io.EOS_WORD],
                                  beam_size=opt.beam_size,
                                  max_sequence_length=opt.max_sent_length)

    logging.info(
        '======================  Checking GPU Availability  ========================='
    )
    if torch.cuda.is_available():
        if isinstance(opt.gpuid, int):
            opt.gpuid = [opt.gpuid]
        logging.info('Running on GPU! devices=%s' % str(opt.gpuid))
        # model = nn.DataParallel(model, device_ids=opt.gpuid)
    else:
        logging.info('Running on CPU!')

    logging.info(
        '======================  Start Training  =========================')

    checkpoint_names = []
    train_ml_history_losses = []
    train_rl_history_losses = []
    valid_history_losses = []
    test_history_losses = []
    # best_loss = sys.float_info.max # for normal training/testing loss (likelihood)
    best_loss = 0.0  # for f-score
    stop_increasing = 0

    train_ml_losses = []
    train_rl_losses = []
    total_batch = -1
    early_stop_flag = False
    if opt.train_rl:
        reward_cache = RewardCache(2000)

    if False:  # opt.train_from:
        state_path = opt.train_from.replace('.model', '.state')
        logging.info('Loading training state from: %s' % state_path)
        if os.path.exists(state_path):
            (epoch, total_batch, best_loss, stop_increasing, checkpoint_names,
             train_ml_history_losses, train_rl_history_losses,
             valid_history_losses,
             test_history_losses) = torch.load(open(state_path, 'rb'))
            opt.start_epoch = epoch

    for epoch in range(opt.start_epoch, opt.epochs):
        if early_stop_flag:
            break

        progbar = Progbar(logger=logging,
                          title='Training',
                          target=len(train_data_loader),
                          batch_size=train_data_loader.batch_size,
                          total_examples=len(
                              train_data_loader.dataset.examples))

        for batch_i, batch in enumerate(train_data_loader):
            model.train()
            total_batch += 1
            one2many_batch, one2one_batch = batch
            report_loss = []

            # Training
            if opt.train_ml:
                loss_ml, decoder_log_probs = train_ml(one2one_batch, model,
                                                      optimizer_ml, criterion,
                                                      opt)
                train_ml_losses.append(loss_ml)
                report_loss.append(('train_ml_loss', loss_ml))
                report_loss.append(('PPL', loss_ml))

                # Brief report
                if batch_i % opt.report_every == 0:
                    brief_report(epoch, batch_i, one2one_batch, loss_ml,
                                 decoder_log_probs, opt)

            # do not apply RL in the 0th epoch; a reasonable model is needed first.
            if opt.train_rl:
                if epoch >= opt.rl_start_epoch:
                    loss_rl = train_rl(one2many_batch, model, optimizer_rl,
                                       generator, opt, reward_cache)
                else:
                    loss_rl = 0.0
                train_rl_losses.append(loss_rl)
                report_loss.append(('train_rl_loss', loss_rl))

            progbar.update(epoch, batch_i, report_loss)

            # Validate and save checkpoint
            if (opt.run_valid_every == -1 and batch_i == len(train_data_loader) - 1) or\
               (opt.run_valid_every > -1 and total_batch > 1 and total_batch % opt.run_valid_every == 0):
                logging.info('*' * 50)
                logging.info(
                    'Run validation and testing @Epoch=%d,#(Total batch)=%d' %
                    (epoch, total_batch))
                # valid_losses    = _valid_error(valid_data_loader, model, criterion, epoch, opt)
                # valid_history_losses.append(valid_losses)
                valid_score_dict = evaluate_beam_search(
                    generator,
                    valid_data_loader,
                    opt,
                    title='Validating, epoch=%d, batch=%d, total_batch=%d' %
                    (epoch, batch_i, total_batch),
                    epoch=epoch,
                    predict_save_path=opt.pred_path +
                    '/epoch%d_batch%d_total_batch%d' %
                    (epoch, batch_i, total_batch))
                test_score_dict = evaluate_beam_search(
                    generator,
                    test_data_loader,
                    opt,
                    title='Testing, epoch=%d, batch=%d, total_batch=%d' %
                    (epoch, batch_i, total_batch),
                    epoch=epoch,
                    predict_save_path=opt.pred_path +
                    '/epoch%d_batch%d_total_batch%d' %
                    (epoch, batch_i, total_batch))

                checkpoint_names.append('epoch=%d-batch=%d-total_batch=%d' %
                                        (epoch, batch_i, total_batch))

                curve_names = []
                scores = []
                if opt.train_ml:
                    train_ml_history_losses.append(copy.copy(train_ml_losses))
                    scores += [train_ml_history_losses]
                    curve_names += ['Training ML Error']
                    train_ml_losses = []

                if opt.train_rl:
                    train_rl_history_losses.append(copy.copy(train_rl_losses))
                    scores += [train_rl_history_losses]
                    curve_names += ['Training RL Reward']
                    train_rl_losses = []

                valid_history_losses.append(valid_score_dict)
                test_history_losses.append(test_score_dict)

                scores += [[
                    result_dict[name] for result_dict in valid_history_losses
                ] for name in opt.report_score_names]
                curve_names += [
                    'Valid-' + name for name in opt.report_score_names
                ]
                scores += [[
                    result_dict[name] for result_dict in test_history_losses
                ] for name in opt.report_score_names]
                curve_names += [
                    'Test-' + name for name in opt.report_score_names
                ]

                scores = [np.asarray(s) for s in scores]
                # Plot the learning curve
                plot_learning_curve_and_write_csv(
                    scores=scores,
                    curve_names=curve_names,
                    checkpoint_names=checkpoint_names,
                    title='Training Validation & Test',
                    save_path=opt.exp_path +
                    '/[epoch=%d,batch=%d,total_batch=%d]train_valid_test_curve.png'
                    % (epoch, batch_i, total_batch))
                '''
                determine whether to early-stop training (now: whether the f-score increased; previously: whether the validation error decreased)
                '''
                valid_loss = np.average(
                    valid_history_losses[-1][opt.report_score_names[0]])
                is_best_loss = valid_loss > best_loss
                rate_of_change = float(valid_loss - best_loss) / float(
                    best_loss) if float(best_loss) > 0 else 0.0

                # validation score did not improve
                if rate_of_change <= 0:
                    stop_increasing += 1
                else:
                    stop_increasing = 0

                if is_best_loss:
                    logging.info(
                        'Validation: update best loss (%.4f --> %.4f), rate of change (ROC)=%.2f'
                        % (best_loss, valid_loss, rate_of_change * 100))
                else:
                    logging.info(
                        'Validation: best loss is not updated for %d times (%.4f --> %.4f), rate of change (ROC)=%.2f'
                        % (stop_increasing, best_loss, valid_loss,
                           rate_of_change * 100))

                best_loss = max(valid_loss, best_loss)

                # only store the checkpoints that make better validation performances
                if total_batch > 1 and (
                        total_batch % opt.save_model_every == 0 or
                        is_best_loss):  # epoch >= opt.start_checkpoint_at and
                    # Save the checkpoint
                    logging.info('Saving checkpoint to: %s' % os.path.join(
                        opt.model_path,
                        '%s.epoch=%d.batch=%d.total_batch=%d.error=%f' %
                        (opt.exp, epoch, batch_i, total_batch, valid_loss) +
                        '.model'))
                    torch.save(
                        model.state_dict(),
                        open(
                            os.path.join(
                                opt.model_path,
                                '%s.epoch=%d.batch=%d.total_batch=%d' %
                                (opt.exp, epoch, batch_i, total_batch) +
                                '.model'), 'wb'))
                    torch.save((epoch, total_batch, best_loss, stop_increasing,
                                checkpoint_names, train_ml_history_losses,
                                train_rl_history_losses, valid_history_losses,
                                test_history_losses),
                               open(
                                   os.path.join(
                                       opt.model_path,
                                       '%s.epoch=%d.batch=%d.total_batch=%d' %
                                       (opt.exp, epoch, batch_i, total_batch) +
                                       '.state'), 'wb'))

                if stop_increasing >= opt.early_stop_tolerance:
                    logging.info(
                        'Have not increased for %d epochs, early stop training'
                        % stop_increasing)
                    early_stop_flag = True
                    break
                logging.info('*' * 50)
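
The early-stop test above is score-based: the latest averaged validation f-score is compared with the best one seen so far, and training stops after a fixed number of consecutive non-improving checks. A minimal sketch of that logic with illustrative scores (note that, as in the original, a zero rate of change also counts as no improvement):

# Illustrative walk-through of the score-based early-stop test above.
best_loss = 0.0          # best f-score seen so far (kept under the original name)
stop_increasing = 0      # consecutive non-improving validation rounds
early_stop_tolerance = 3

for valid_loss in [0.10, 0.12, 0.11, 0.11, 0.10]:   # averaged f-scores per round
    rate_of_change = (valid_loss - best_loss) / best_loss if best_loss > 0 else 0.0
    if rate_of_change <= 0:      # score did not improve (zero counts as no improvement)
        stop_increasing += 1
    else:
        stop_increasing = 0
    best_loss = max(valid_loss, best_loss)
    print('score=%.2f  best=%.2f  no-improve streak=%d' % (valid_loss, best_loss, stop_increasing))
    if stop_increasing >= early_stop_tolerance:
        print('early stop')
        break
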
Exemple #14
0
def run(
    kn=lambda x: 1.0,
    sigma_s=lambda x: 1.0,
    sigma_a=lambda x: 0.0,
    Q=lambda x: 0.0,
    xmin=0.0,
    xmax=1.0,
    nx=40,
    dt=0.01,
    nt=1000,
    coll="linear",
    scheme="Euler",
    BC="periodic",
    f_l=lambda v: 1.0,
    f_r=lambda v: 0.0,
    init_func=lambda vmesh, u, T, rho: 0.0,
):
    # Load config
    config = collision.utils.CollisionConfig.from_json(
        "./linear_transport/configs/" + "parity" + ".json")

    # Collision
    vmesh = collision.CartesianMesh(config)
    # print(vmesh.centers[0], vmesh.weights)
    if coll == "linear":
        coll_op = collision.LinearCollision(config, vmesh)
    elif coll == "rbm":
        coll_op = collision.RandomBatchLinearCollision(config, vmesh)
    elif coll == "rbm_symm":
        coll_op = collision.SymmetricRBMLinearCollision(config, vmesh)
    else:
        raise NotImplementedError(
            "Collision method {} is not implemented.".format(coll))

    # x domain
    x = pykinetic.Dimension(xmin, xmax, nx, name="x")
    # print(x.centers_with_ghost(2))
    domain = pykinetic.Domain([x])

    # Riemann solver
    rp = pykinetic.riemann.parity_1D
    solver = APNeutronTransportSolver1D(
        rp,
        [coll_op],
        kn=kn(x.centers),
        sigma_s=sigma_s(x.centers),
        sigma_a=sigma_a(x.centers),
        Q=Q(x.centers),
    )
    # print(solver.kn)
    solver.order = 1
    # solver.lim_type = -1
    solver.time_integrator = scheme
    solver.dt = dt
    print("dt is {}".format(solver.dt))

    sigma = sigma_s(x.centers) + kn(x.centers)**2 * sigma_a(x.centers)
    sigma_l, sigma_r = None, None
    if isinstance(sigma, np.ndarray):
        sigma_l, sigma_r = sigma[0], sigma[-1]
    else:
        sigma_l = sigma_r = sigma
    kn_l, kn_r = None, None
    if isinstance(solver.kn, np.ndarray):
        kn_l, kn_r = solver.kn[0], solver.kn[-1]
    else:
        kn_l = kn_r = solver.kn

    # Boundary conditions
    def dirichlet_lower_BC(state, dim, t, qbc, auxbc, num_ghost):
        v = state.problem_data["v"][0] / sigma_l
        kn_dx = kn_l / state.grid.delta[0]
        for i in range(num_ghost):
            qbc[0,
                i, :] = (f_l(v) - (0.5 - kn_dx * v) * qbc[0, num_ghost, :]) / (
                    0.5 + kn_dx * v)
        for i in range(num_ghost):
            qbc[1,
                i, :] = (2 * f_l(v) -
                         (qbc[0, num_ghost - 1, :] +
                          qbc[0, num_ghost, :])) / kn_l - qbc[1, num_ghost, :]

    def dirichlet_upper_BC(state, dim, t, qbc, auxbc, num_ghost):
        v = state.problem_data["v"][0] / sigma_r
        kn_dx = kn_r / state.grid.delta[0]
        for i in range(num_ghost):
            qbc[0, -i -
                1, :] = (f_r(v) -
                         (0.5 - kn_dx * v) * qbc[0, -num_ghost - 1, :]) / (
                             0.5 + kn_dx * v)
        for i in range(num_ghost):
            qbc[1, -i -
                1, :] = ((qbc[0, -num_ghost, :] + qbc[0, -num_ghost - 1, :]) -
                         2 * f_r(v)) / kn_r - qbc[1, -num_ghost - 1, :]

    if BC == "periodic":
        solver.bc_lower[0] = pykinetic.BC.periodic
        solver.bc_upper[0] = pykinetic.BC.periodic
    elif BC == "dirichlet":
        solver.bc_lower[0] = pykinetic.BC.custom
        solver.bc_upper[0] = pykinetic.BC.custom
        solver.user_bc_lower = dirichlet_lower_BC
        solver.user_bc_upper = dirichlet_upper_BC
    else:
        raise ValueError("Given BC type is not available!")

    state = pykinetic.State(domain, vmesh, 2)
    state.problem_data["v"] = vmesh.centers
    state.problem_data["phi"] = phi(solver.kn)
    state.problem_data["sqrt_phi"] = np.sqrt(phi(solver.kn))

    qinit(state, vmesh, solver.kn, init_func)
    sol = pykinetic.Solution(state, domain)

    output_dict = {}
    sol_frames, macro_frames, ts = (
        [copy.deepcopy(sol.q)],
        [compute_rho(sol.state, vmesh)],
        [0.0],
    )
    pbar = Progbar(nt)
    for t in range(nt):
        solver.evolve_to_time(sol)
        sol_frames.append(copy.deepcopy(sol.q))
        macro_frames.append(compute_rho(sol.state, vmesh))
        # Test
        # qbc = solver.qbc
        # print(
        #     np.max(
        #         np.abs(
        #             5.0 * np.sin(vmesh.centers[0])
        #             - (
        #                 0.5 * (qbc[0, 1] + qbc[0, 2])
        #                 + kn * 0.5 * (qbc[1, 1] + qbc[1, 2])
        #             )
        #         )
        #     )
        # )
        ts.append(0.0 + (t + 1) * dt)
        pbar.update(t + 1, finalize=False)
    pbar.update(nt, finalize=True)

    output_dict["f_frames"] = np.asarray(sol_frames)
    output_dict["macro_frames"] = macro_frames
    output_dict["x"] = x.centers
    output_dict["t"] = ts
    output_dict["v"] = np.asarray(vmesh.centers[0])

    return output_dict
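
In this example kn, sigma_s and sigma_a are passed in as callables, evaluated on the cell centers, and then reduced to left/right endpoint values for the Dirichlet boundary conditions. A small sketch of that coefficient handling, using an illustrative array of cell centers in place of x.centers:

import numpy as np

centers = np.linspace(0.05, 0.95, 10)   # illustrative stand-in for x.centers

kn = lambda x: 1.0                       # may return a scalar ...
sigma_s = lambda x: 1.0 + 0.5 * x        # ... or an array shaped like x
sigma_a = lambda x: 0.0

sigma = sigma_s(centers) + kn(centers) ** 2 * sigma_a(centers)

# Reduce to endpoint values for the boundary conditions, as in the example.
if isinstance(sigma, np.ndarray):
    sigma_l, sigma_r = sigma[0], sigma[-1]
else:
    sigma_l = sigma_r = sigma
print(sigma_l, sigma_r)
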
Exemple #15
0
def train_model(model, optimizer, criterion, training_data_loader,
                validation_data_loader, opt):
    logging.info(
        '======================  Checking GPU Availability  ========================='
    )
    if torch.cuda.is_available():
        if isinstance(opt.gpuid, int):
            opt.gpuid = [opt.gpuid]
        logging.info('Running on GPU! devices=%s' % str(opt.gpuid))
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=opt.gpuid)
        criterion.cuda()
    else:
        logging.info('Running on CPU!')

    logging.info(
        '======================  Start Training  =========================')

    train_history_losses = []
    valid_history_losses = []
    best_loss = sys.float_info.max

    train_losses = []
    total_batch = 0
    stop_increasing = 0  # counts consecutive validation rounds without improvement
    early_stop_flag = False

    for epoch in range(opt.start_epoch, opt.epochs):
        if early_stop_flag:
            break

        progbar = Progbar(title='Training',
                          target=len(training_data_loader),
                          batch_size=opt.batch_size,
                          total_examples=len(training_data_loader.dataset))
        model.train()

        for batch_i, batch in enumerate(training_data_loader):
            batch_i += 1  # for the aesthetics of printing
            total_batch += 1
            src = batch.src
            trg = batch.trg

            # print("src size - ",src.size())
            # print("target size - ",trg.size())
            if torch.cuda.is_available():
                src.cuda()
                trg.cuda()

            optimizer.zero_grad()
            decoder_logits, _, _ = model.forward(src,
                                                 trg,
                                                 must_teacher_forcing=False)

            start_time = time.time()

            # drop the first token of trg (the leading <BOS>) so predictions align with the target
            loss = criterion(
                decoder_logits.contiguous().view(-1, opt.vocab_size),
                trg[:, 1:].contiguous().view(-1))
            print("--loss calculation- %s seconds ---" %
                  (time.time() - start_time))

            start_time = time.time()
            loss.backward()
            print("--backward- %s seconds ---" % (time.time() - start_time))

            if opt.max_grad_norm > 0:
                pre_norm = torch.nn.utils.clip_grad_norm(
                    model.parameters(), opt.max_grad_norm)
                after_norm = (sum([
                    p.grad.data.norm(2)**2 for p in model.parameters()
                    if p.grad is not None
                ]))**(1.0 / 2)
                logging.info('clip grad (%f -> %f)' % (pre_norm, after_norm))

            optimizer.step()

            train_losses.append(loss.data[0])
            perplexity = np.math.exp(loss.data[0])

            progbar.update(epoch, batch_i, [('train_loss', loss.data[0]),
                                            ('perplexity', perplexity)])

            if batch_i > 1 and batch_i % opt.report_every == 0:
                logging.info(
                    '======================  %d  =========================' %
                    (batch_i))

                logging.info(
                    'Epoch : %d Minibatch : %d, Loss=%.5f, PPL=%.5f' %
                    (epoch, batch_i, np.mean(loss.data[0]), perplexity))
                sampled_size = 2
                logging.info(
                    'Printing predictions on %d sampled examples by greedy search'
                    % sampled_size)

                # softmax logits to get probabilities (batch_size, trg_len, vocab_size)
                # decoder_probs = torch.nn.functional.softmax(decoder_logits.view(trg.size(0) * trg.size(1), -1)).view(*trg.size(), -1)

                if torch.cuda.is_available():
                    src = src.data.cpu().numpy()
                    decoder_logits = decoder_logits.data.cpu().numpy()
                    max_words_pred = decoder_logits.argmax(axis=-1)
                    trg = trg.data.cpu().numpy()
                else:
                    src = src.data.numpy()
                    decoder_logits = decoder_logits.data.numpy()
                    max_words_pred = decoder_logits.argmax(axis=-1)
                    trg = trg.data.numpy()

                sampled_trg_idx = np.random.random_integers(low=0,
                                                            high=len(trg) - 1,
                                                            size=sampled_size)
                src = src[sampled_trg_idx]
                max_words_pred = [max_words_pred[i] for i in sampled_trg_idx]
                decoder_logits = decoder_logits[sampled_trg_idx]
                trg = [trg[i][1:] for i in sampled_trg_idx
                       ]  # the real target with the leading <BOS> removed

                for i, (src_wi, pred_wi,
                        real_wi) in enumerate(zip(src, max_words_pred, trg)):
                    nll_prob = -np.sum(
                        np.log2([
                            decoder_logits[i][l][pred_wi[l]]
                            for l in range(len(real_wi))
                        ]))
                    sentence_source = [opt.id2word[x] for x in src_wi]
                    sentence_pred = [opt.id2word[x] for x in pred_wi]
                    sentence_real = [opt.id2word[x] for x in real_wi]

                    logging.info(
                        '==================================================')
                    logging.info('Source: %s ' % (' '.join(sentence_source)))
                    logging.info('\t\tPred : %s (%.4f)' %
                                 (' '.join(sentence_pred), nll_prob))
                    logging.info('\t\tReal : %s ' % (' '.join(sentence_real)))

            if total_batch > 1 and total_batch % opt.run_valid_every == 0:
                logging.info('*' * 50)
                logging.info(
                    'Run validation test @Epoch=%d,#(Total batch)=%d' %
                    (epoch, total_batch))
                valid_losses = _valid(validation_data_loader,
                                      model,
                                      criterion,
                                      optimizer,
                                      epoch,
                                      opt,
                                      is_train=False)

                train_history_losses.append(copy.copy(train_losses))
                valid_history_losses.append(valid_losses)
                train_losses = []

                # Plot the learning curve
                plot_learning_curve(
                    train_history_losses,
                    valid_history_losses,
                    'Training and Validation',
                    curve1_name='Training Error',
                    curve2_name='Validation Error',
                    save_path=opt.exp_path +
                    '/[epoch=%d,batch=%d,total_batch=%d]train_valid_curve.png'
                    % (epoch, batch_i, total_batch))
                '''
                determine if early stop training
                '''
                valid_loss = np.average(valid_history_losses[-1])
                is_best_loss = valid_loss < best_loss
                rate_of_change = float(valid_loss -
                                       best_loss) / float(best_loss)

                # only store the checkpoints that make better validation performances
                if total_batch > 1 and epoch >= opt.start_checkpoint_at and (
                        total_batch % opt.save_model_every == 0
                        or is_best_loss):
                    # Save the checkpoint
                    logging.info('Saving checkpoint to: %s' % os.path.join(
                        opt.save_path,
                        '%s.epoch=%d.batch=%d.total_batch=%d.error=%f' %
                        (opt.exp, epoch, batch_i, total_batch, valid_loss) +
                        '.model'))
                    torch.save(
                        model.state_dict(),
                        open(
                            os.path.join(
                                opt.save_path,
                                '%s.epoch=%d.batch=%d.total_batch=%d' %
                                (opt.exp, epoch, batch_i, total_batch) +
                                '.model'), 'wb'))

                # valid error doesn't decrease
                if rate_of_change >= 0:
                    stop_increasing += 1
                else:
                    stop_increasing = 0

                if is_best_loss:
                    logging.info(
                        'Validation: update best loss (%.4f --> %.4f), rate of change (ROC)=%.2f'
                        % (best_loss, valid_loss, rate_of_change * 100))
                else:
                    logging.info(
                        'Validation: best loss is not updated for %d times (%.4f --> %.4f), rate of change (ROC)=%.2f'
                        % (stop_increasing, best_loss, valid_loss,
                           rate_of_change * 100))

                best_loss = min(valid_loss, best_loss)
                if stop_increasing >= opt.early_stop_tolerance:
                    logging.info(
                        'Loss has not decreased for %d epochs, early stop training'
                        % stop_increasing)
                    early_stop_flag = True
                    break
                logging.info('*' * 50)
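
Two quantities reported in the loop above are worth spelling out: the perplexity is simply exp of the cross-entropy loss, and the post-clip value logged after clip_grad_norm is the global L2 norm over all parameter gradients. A short sketch with illustrative numbers:

import math
import numpy as np

# Perplexity as reported above: PPL = exp(cross-entropy loss).
for loss in [4.6, 3.0, 1.5]:                        # illustrative loss values
    print('loss=%.2f  perplexity=%.2f' % (loss, math.exp(loss)))

# The post-clip value logged after clip_grad_norm is the global L2 norm:
# the square root of the sum of each parameter gradient's squared L2 norm.
grads = [np.array([0.3, -0.4]), np.array([1.2, 0.0, -0.5])]   # illustrative gradients
global_norm = sum(float(np.linalg.norm(g)) ** 2 for g in grads) ** 0.5
print('global grad norm = %.4f' % global_norm)
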
Exemple #16
0
def main(_):

    random_seed = 12345
    os.environ["PYTHONHASHSEED"] = str(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)
    tf.set_random_seed(random_seed)

    start_logger(FLAGS.model_save_filename + ".train_log")
    atexit.register(stop_logger)

    print("-- Building vocabulary")
    #embeddings, token2id, id2token = load_glove(args.vectors_filename, args.max_vocab, args.embeddings_size)

    label2id = {"neutral": 0, "entailment": 1, "contradiction": 2}
    id2label = {v: k for k, v in label2id.items()}

    #num_tokens = len(token2id)

    num_labels = len(label2id)

    #print("Number of tokens: {}".format(num_tokens))
    print("Number of labels: {}".format(num_labels))

    # Load e_vsnli
    # Explanations are encoded/padded, we ignore original explanations
    print("-- Loading training set")

    train_labels, train_explanations, train_premises, train_hypotheses, train_img_names, _, _, _, train_max_length, embeddings, token2id, id2token, _ = \
        load_e_vsnli_dataset_and_glove(
            FLAGS.train_filename,
            label2id,
            FLAGS.vectors_filename,
            FLAGS.max_vocab,
            model_config.embedding_size,
            buffer_size=FLAGS.buffer_size,
            min_threshold=FLAGS.min_threshold,
        )

    num_tokens = len(token2id)
    print("Number of tokens after filtering: ", num_tokens)

    print("-- Loading development set")
    dev_labels, dev_explanations, dev_premises, dev_hypotheses, dev_img_names, dev_original_explanations, _, _, dev_max_length, _ = \
        load_e_vsnli_dataset(
            FLAGS.dev_filename,
            token2id,
            label2id,
            buffer_size=FLAGS.buffer_size,
            padding_length=train_max_length,
        )

    if FLAGS.imbalance == True:
        dev_num_examples = dev_labels.shape[0]
        class_freqs = np.bincount(dev_labels) / dev_num_examples
        class_weights = 1 / (class_freqs * num_labels)
        print("Class frequencies: ", class_freqs)
        print("Weights: ", class_weights)
        np.save(FLAGS.model_save_filename + '_class_freqs.npy', class_freqs)
    print("-- Loading images")
    image_reader = ImageReader(FLAGS.img_names_filename,
                               FLAGS.img_features_filename)

    print("-- Saving parameters")
    with open(FLAGS.model_save_filename + ".params", mode="w") as out_file:
        json.dump(vars(FLAGS), out_file)
        print("Params saved to: {}".format(FLAGS.model_save_filename +
                                           ".params"))

        with open(FLAGS.model_save_filename + ".index", mode="wb") as out_file:
            pickle.dump(
                {
                    "token2id": token2id,
                    "id2token": id2token,
                    "label2id": label2id,
                    "id2label": id2label
                }, out_file)
            print("Index saved to: {}".format(FLAGS.model_save_filename +
                                              ".index"))

    model_config.set_vocab_size(num_tokens)
    print("Vocab size set to %d" % model_config.vocab_size)
    model_config.set_alpha(FLAGS.alpha)
    print("alpha set to %f" % model_config.alpha)

    ilabel2itoken = {}
    for i in id2label:
        label = id2label[i]
        if label in token2id:
            j = token2id[label]
        else:
            j = token2id["#unk#"]
        ilabel2itoken[i] = j

    print("label_id --> token_id: constructed!")

    num_examples = train_labels.shape[0]
    num_batches = num_examples // FLAGS.batch_size

    dev_num_examples = dev_labels.shape[0]
    dev_batches_indexes = np.arange(dev_num_examples)
    num_batches_dev = dev_num_examples // FLAGS.dev_batch_size

    tf.reset_default_graph()

    # Build the TensorFlow graph and train it
    g = tf.Graph()
    with g.as_default():

        model = build_model(model_config,
                            embeddings,
                            ilabel2itoken=ilabel2itoken,
                            mode=mode)

        # Set up the learning rate.
        learning_rate_decay_fn = None
        learning_rate = tf.constant(training_config.initial_learning_rate)
        if training_config.learning_rate_decay_factor > 0:
            num_batches_per_epoch = (num_examples / FLAGS.batch_size)
            decay_steps = int(num_batches_per_epoch *
                              training_config.num_epochs_per_decay)

            def _learning_rate_decay_fn(learning_rate, global_step):
                return tf.train.exponential_decay(
                    learning_rate,
                    global_step,
                    decay_steps=decay_steps,
                    decay_rate=training_config.learning_rate_decay_factor,
                    staircase=True)

            learning_rate_decay_fn = _learning_rate_decay_fn

        # Set up the training ops.
        train_op = tf.contrib.layers.optimize_loss(
            loss=model['total_loss'],
            global_step=model['global_step'],
            learning_rate=learning_rate,
            optimizer=training_config.optimizer,
            clip_gradients=training_config.clip_gradients,
            learning_rate_decay_fn=learning_rate_decay_fn)

        dev_best_accuracy = -1
        stopping_step = 0
        best_epoch = None
        should_stop = False

        # initialize all variables
        init = tf.global_variables_initializer()

        with tf.Session() as session:
            session.run(init)
            #session.run(tf.initializers.tables_initializer(name='init_all_tables'))

            t = 0  # counting iterations

            time_now = datetime.now()

            for epoch in range(training_config.total_num_epochs):
                if should_stop:
                    break

                print("\n==> Online epoch # {0}".format(epoch + 1))
                progress = Progbar(num_batches)
                batches_indexes = np.arange(num_examples)
                np.random.shuffle(batches_indexes)

                batch_index = 1
                loss_history = []
                epoch_loss = 0

                for indexes in batch(batches_indexes, FLAGS.batch_size):

                    t += 1
                    batch_hypotheses = train_hypotheses[indexes]
                    batch_labels = train_labels[indexes]

                    # explanations have been encoded / padded when loaded
                    batch_explanations = train_explanations[indexes]
                    batch_explanation_lengths = [
                        len(expl) for expl in batch_explanations
                    ]

                    batch_img_names = [train_img_names[i] for i in indexes]
                    batch_img_features = image_reader.get_features(
                        batch_img_names)

                    total_loss_value = _step(
                        session, batch_hypotheses, batch_labels,
                        batch_explanations, batch_img_features, train_op,
                        model, model_config.lstm_dropout_keep_prob
                    )  # run each training step

                    progress.update(batch_index, [("Loss", total_loss_value)])
                    loss_history.append(total_loss_value)
                    epoch_loss += total_loss_value
                    batch_index += 1

                    if FLAGS.print_every > 0 and t % FLAGS.print_every == 0:
                        print(
                            '(Iteration %d) loss: %f, and time elapsed: %.2f minutes'
                            % (t + 1, float(loss_history[-1]),
                               (datetime.now() - time_now).seconds / 60.0))

                print("Current mean training loss: {}\n".format(epoch_loss /
                                                                num_batches))

                print("-- Validating model")

                progress = Progbar(num_batches_dev)

                dev_num_correct = 0
                dev_batch_index = 0

                for indexes in batch(dev_batches_indexes,
                                     FLAGS.dev_batch_size):

                    t += 1

                    dev_batch_num_correct = 0

                    dev_batch_index += 1
                    dev_batch_hypotheses = dev_hypotheses[indexes]
                    dev_batch_labels = dev_labels[indexes]

                    # explanations have been encoded / padded when loaded
                    dev_batch_explanations = dev_explanations[indexes]
                    dev_batch_img_names = [dev_img_names[i] for i in indexes]
                    dev_batch_img_features = image_reader.get_features(
                        dev_batch_img_names)

                    pred_explanations, pred_labels = _run_validation(
                        session, dev_batch_hypotheses, dev_batch_labels,
                        dev_batch_explanations, dev_batch_img_features,
                        len(indexes), ilabel2itoken, model, 1.0)

                    if FLAGS.imbalance == True:
                        dev_batch_num_correct += np.dot(
                            pred_labels == dev_batch_labels,
                            class_weights[dev_batch_labels])
                    else:
                        dev_batch_num_correct += (
                            pred_labels == dev_batch_labels).sum()
                    dev_num_correct += dev_batch_num_correct

                    progress.update(
                        dev_batch_index,
                        [("Proportion of correct labels",
                          float(dev_batch_num_correct) / len(indexes))])
                    if FLAGS.sample_every > 0 and (
                            t + 1) % FLAGS.sample_every == 0:
                        pred_explanations = [
                            unpack.reshape(-1, 1)
                            for unpack in pred_explanations
                        ]
                        pred_explanations = np.concatenate(
                            pred_explanations, 1)
                        pred_explanations_decoded = [
                            decode(pred_explanations[i], id2token)
                            for i in range(len(indexes))
                        ]
                        print("\nExample generated explanation: ",
                              pred_explanations_decoded[0])  #TODO: decode it
                        #print("Original explanation: ", dev_original_explanations[indexes][0])

                dev_accuracy = float(dev_num_correct) / dev_num_examples
                print("Current mean validation accuracy: {}".format(
                    dev_accuracy))

                #if True:
                if dev_accuracy > dev_best_accuracy:
                    stopping_step = 0
                    best_epoch = epoch + 1
                    dev_best_accuracy = dev_accuracy
                    model['saver'].save(session,
                                        FLAGS.model_save_filename + ".ckpt")
                    print(
                        "Best mean validation accuracy: {} (reached at epoch {})"
                        .format(dev_best_accuracy, best_epoch))
                    print("Best model saved to: {}".format(
                        FLAGS.model_save_filename))
                else:
                    stopping_step += 1
                    print("Current stopping step: {}".format(stopping_step))
                if stopping_step >= FLAGS.patience:
                    print("Early stopping at epoch {}!".format(epoch + 1))
                    print(
                        "Best mean validation accuracy: {} (reached at epoch {})"
                        .format(dev_best_accuracy, best_epoch))
                    should_stop = True
                if epoch + 1 >= training_config.total_num_epochs:
                    print("Stopping at epoch {}!".format(epoch + 1))
                    print(
                        "Best mean validation accuracy: {} (reached at epoch {})"
                        .format(dev_best_accuracy, best_epoch))
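
When FLAGS.imbalance is set, each correct dev prediction above is weighted by the inverse frequency of its true class, so the reported accuracy is balanced across labels. A small numeric sketch of that weighting with illustrative labels and predictions:

import numpy as np

num_labels = 3
dev_labels = np.array([0, 0, 0, 0, 1, 1, 2, 2, 2, 2])    # illustrative true labels
pred_labels = np.array([0, 0, 1, 0, 1, 1, 2, 0, 2, 2])   # illustrative predictions

# Inverse-frequency class weights, as computed in the example.
class_freqs = np.bincount(dev_labels) / len(dev_labels)
class_weights = 1 / (class_freqs * num_labels)

# Each correct prediction counts by the weight of its true class, so rare
# classes contribute as much to the accuracy as frequent ones.
weighted_correct = np.dot(pred_labels == dev_labels, class_weights[dev_labels])
print('class weights:', class_weights)
print('balanced accuracy: %.3f' % (weighted_correct / len(dev_labels)))
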
Exemple #17
0
def train_model(model, optimizer_ml, optimizer_rl, criterion,
                train_data_loader, valid_data_loaders, test_data_loaders, opt):
    generator = SequenceGenerator(model,
                                  eos_id=opt.word2id[pykp.io.EOS_WORD],
                                  beam_size=opt.beam_size,
                                  max_sequence_length=opt.max_sent_length)
    logger = logging.getLogger('train.py')
    logger.info(
        '======================  Checking GPU Availability  ========================='
    )
    if torch.cuda.is_available():
        if isinstance(opt.gpuid, int):
            opt.gpuid = [opt.gpuid]
            logger.info('Running on GPU! devices=%s' % str(opt.gpuid))
        # model = nn.DataParallel(model, device_ids=opt.gpuid)
        model = model.cuda()
    else:
        logger.info('Running on CPU!')

        logger.info(
            '======================  Start Training  ========================='
        )

    checkpoint_names = []
    train_ml_history_losses = []
    train_rl_history_losses = []
    valid_history_scores = {}
    test_history_scores = {}
    # best_loss = sys.float_info.max # for normal training/testing loss (likelihood)
    best_loss = 0.0  # for f-score
    stop_increasing = 0

    train_ml_losses = []
    train_rl_losses = []
    total_batch = -1
    early_stop_flag = False
    if opt.train_rl:
        reward_cache = RewardCache(2000)

    # if False:  # opt.train_from:
    #     state_path = opt.train_from.replace('.model', '.state')
    #     logger.info('Loading training state from: %s' % state_path)
    #     if os.path.exists(state_path):
    #         (epoch, total_batch, best_loss, stop_increasing, checkpoint_names, train_ml_history_losses, train_rl_history_losses, valid_history_scores,
    #          test_history_losses) = torch.load(open(state_path, 'rb'))
    #         opt.start_epoch = epoch

    for epoch in range(opt.start_epoch, opt.epochs):
        if early_stop_flag:
            break

        progbar = Progbar(logger=logger,
                          title='Training',
                          target=len(train_data_loader),
                          batch_size=train_data_loader.batch_size,
                          total_examples=len(train_data_loader.dataset))

        for batch_i, batch in enumerate(train_data_loader):
            model.train()
            total_batch += 1
            one2many_batch, one2one_batch = batch
            report_loss = []

            # Training
            if opt.train_ml:
                loss_ml, decoder_log_probs = train_ml(one2one_batch, model,
                                                      optimizer_ml, criterion,
                                                      opt)

                # len(decoder_log_probs) == 0 if encountered OOM
                if len(decoder_log_probs) == 0:
                    continue

                train_ml_losses.append(loss_ml)
                report_loss.append(('train_ml_loss', loss_ml))
                report_loss.append(('PPL', loss_ml))

                # Brief report
                if batch_i % opt.report_every == 0:
                    brief_report(epoch, batch_i, one2one_batch, loss_ml,
                                 decoder_log_probs, opt)

            # do not apply RL in the 0th epoch; a reasonable model is needed first.
            if opt.train_rl:
                if epoch >= opt.rl_start_epoch:
                    loss_rl = train_rl(one2many_batch, model, optimizer_rl,
                                       generator, opt, reward_cache)
                else:
                    loss_rl = 0.0
                train_rl_losses.append(loss_rl)
                report_loss.append(('train_rl_loss', loss_rl))

            progbar.update(epoch, batch_i, report_loss)
            '''
            Validate and save checkpoint
            '''
            if (opt.run_valid_every == -1 and batch_i == len(train_data_loader) - 1) or\
               (opt.run_valid_every > -1 and total_batch > 1 and total_batch % opt.run_valid_every == 0):
                logger.info('*' * 50)
                logger.info(
                    'Run validation and testing @Epoch=%d,#(Total batch)=%d' %
                    (epoch, total_batch))

                # returns a dict keyed by dataset name; each value is a score dict
                valid_score_dict = evaluate.evaluate_multiple_datasets(
                    generator,
                    valid_data_loaders,
                    opt,
                    epoch=epoch,
                    title='valid.epoch=%d.total_batch=%d' %
                    (epoch, total_batch),
                    predict_save_path=os.path.join(
                        opt.pred_path, 'epoch%d_batch%d_total_batch%d' %
                        (epoch, batch_i, total_batch)))
                test_score_dict = evaluate.evaluate_multiple_datasets(
                    generator,
                    test_data_loaders,
                    opt,
                    epoch=epoch,
                    title='test.epoch=%d.total_batch=%d' %
                    (epoch, total_batch),
                    predict_save_path=os.path.join(
                        opt.pred_path, 'epoch%d_batch%d_total_batch%d' %
                        (epoch, batch_i, total_batch)))
                '''
                Merge scores of current round into history_score
                '''
                for dataset_name, score_dict in valid_score_dict.items():
                    # each history score is a dict specific to one dataset:
                    # key is the score name, value is a list whose elements are the per-example scores (e.g. f1_score) from each validation round
                    valid_history_score = valid_history_scores.get(
                        dataset_name, {})
                    for score_name, score_values in score_dict.items():
                        history_score_values = valid_history_score.get(
                            score_name, [])
                        history_score_values.append(score_values)
                        valid_history_score[score_name] = history_score_values
                    valid_history_scores[dataset_name] = valid_history_score

                for dataset_name, score_dict in test_score_dict.items():
                    test_history_score = test_history_scores.get(
                        dataset_name, {})
                    for score_name, score_values in score_dict.items():
                        history_score_values = test_history_score.get(
                            score_name, [])
                        history_score_values.append(score_values)
                        test_history_score[score_name] = history_score_values
                    test_history_scores[dataset_name] = test_history_score

                if opt.train_ml:
                    train_ml_history_losses.append(copy.copy(train_ml_losses))
                    train_ml_losses = []
                if opt.train_rl:
                    train_rl_history_losses.append(copy.copy(train_rl_losses))
                    train_rl_losses = []
                '''
                Iterate each dataset (including a merged 'all_datasets') and plot learning curves
                '''
                for dataset_name in opt.test_dataset_names + ['all_datasets']:
                    valid_history_score = valid_history_scores[dataset_name]
                    test_history_score = test_history_scores[dataset_name]
                    curve_names = []
                    scores_for_plot = []
                    if opt.train_ml:
                        scores_for_plot += [train_ml_history_losses]
                        curve_names += ['Training ML Error']

                    if opt.train_rl:
                        scores_for_plot += [train_rl_history_losses]
                        curve_names += ['Training RL Reward']

                    scores_for_plot += [
                        valid_history_score[name]
                        for name in opt.report_score_names
                    ]
                    curve_names += [
                        'Valid-' + name for name in opt.report_score_names
                    ]
                    scores_for_plot += [
                        test_history_score[name]
                        for name in opt.report_score_names
                    ]
                    curve_names += [
                        'Test-' + name for name in opt.report_score_names
                    ]

                    scores_for_plot = [np.asarray(s) for s in scores_for_plot]
                    '''
                    Plot the learning curve
                    '''
                    plot_learning_curve_and_write_csv(
                        scores=scores_for_plot,
                        curve_names=curve_names,
                        checkpoint_names=checkpoint_names,
                        title='Training Validation & Test of %s' %
                        dataset_name,
                        save_path=opt.plot_path +
                        '/[epoch=%d,batch=%d,total_batch=%d].%s.learning_curve'
                        % (epoch, batch_i, total_batch, dataset_name))
                '''
                determine whether to early-stop training (now: whether the f-score increased; previously: whether the validation error decreased)
                opt.report_score_names[0] is 'f_score@5_exact'
                '''
                valid_loss = np.average(valid_history_scores['all_datasets'][
                    opt.report_score_names[0]][-1])
                is_best_loss = valid_loss > best_loss
                rate_of_change = float(valid_loss - best_loss) / float(
                    best_loss) if float(best_loss) > 0 else 0.0

                # the validation score did not increase
                if rate_of_change <= 0:
                    stop_increasing += 1
                else:
                    stop_increasing = 0

                if is_best_loss:
                    logging.info(
                        'Validation: update best loss (%.4f --> %.4f), rate of change (ROC)=%.2f'
                        % (best_loss, valid_loss, rate_of_change * 100))
                else:
                    logging.info(
                        'Validation: best loss is not updated for %d times (%.4f --> %.4f), rate of change (ROC)=%.2f'
                        % (stop_increasing, best_loss, valid_loss,
                           rate_of_change * 100))

                logging.info(
                    'Current test scores (over %d datasets: %s)\n' %
                    (len(opt.test_dataset_names), str(opt.test_dataset_names)))
                for report_score_name in opt.report_score_names:
                    test_loss = np.average(test_history_scores['all_datasets']
                                           [report_score_name][-1])
                    logging.info('\t\t %s = %.4f' %
                                 (report_score_name, test_loss))

                best_loss = max(valid_loss, best_loss)
                '''
                Save checkpoints: one is stored every opt.save_model_every batches, plus whenever validation performance improves
                '''
                checkpoint_names.append('epoch=%d-batch=%d-total_batch=%d' %
                                        (epoch, batch_i, total_batch))

                if total_batch > 1 and (
                        total_batch % opt.save_model_every == 0 or
                        is_best_loss):  # epoch >= opt.start_checkpoint_at and
                    # Save the checkpoint
                    logging.info('Saving checkpoint to: %s' % os.path.join(
                        opt.model_path,
                        '%s.epoch=%d.batch=%d.total_batch=%d.error=%f' %
                        (opt.exp, epoch, batch_i, total_batch, valid_loss) +
                        '.model'))
                    torch.save(
                        model.state_dict(),
                        open(
                            os.path.join(
                                opt.model_path,
                                '%s.epoch=%d.batch=%d.total_batch=%d' %
                                (opt.exp, epoch, batch_i, total_batch) +
                                '.model'), 'wb'))
                    torch.save((epoch, total_batch, best_loss, stop_increasing,
                                checkpoint_names, train_ml_history_losses,
                                train_rl_history_losses, valid_history_scores,
                                test_history_scores),
                               open(
                                   os.path.join(
                                       opt.model_path,
                                       '%s.epoch=%d.batch=%d.total_batch=%d' %
                                       (opt.exp, epoch, batch_i, total_batch) +
                                       '.state'), 'wb'))

                if stop_increasing >= opt.early_stop_tolerance:
                    logging.info(
                        'Validation score has not increased for %d epochs, stopping training early'
                        % stop_increasing)
                    early_stop_flag = True
                    break

                logging.info('*' * 50)
    def train(self, save_dir=None):
        criterion = Dynamic_Cross_Entropy_Loss()
        padded_q_word, q_word_lengths, padded_q_dep, q_dep_lengths, padded_rel_words, rel_word_lengths, \
        batch_rel_ids, padded_cons_words, cons_word_lengths, padded_cons_id, cons_id_lengths, batch_prior_weights, \
        batch_labels = None, None, None, None, None, None, None, None, None, None, None, None, None
        epoch_loss_history = [float("inf"), ]
        sys.stderr.write('Max Epoch: {}\n'.format(self.max_epoch))
        for epoch in range(self.max_epoch):
            self.scheduler.step(epoch)
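            # note: this passes the epoch index to the scheduler, as older PyTorch versions
            # allow; newer versions expect scheduler.step() to be called without arguments
            # after optimizer.step()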
            progbar = Progbar(len(self.data_loader), file=sys.stderr)
            prog_idx = 0
            # random.sample over the full index range already returns a shuffled permutation
            shuffled_indices = random.sample(range(len(self.data_loader)), len(self.data_loader))
            for curr_index in shuffled_indices:
                data_dict = self.data_loader.get_one_batch(curr_index)
                padded_q_word, q_word_lengths, padded_q_dep, q_dep_lengths, padded_rel_words, rel_word_lengths, batch_rel_ids, \
                padded_cons_words,cons_word_lengths, padded_cons_id, cons_id_lengths, batch_prior_weights, batch_labels \
                    = self.unpack_data_dict(data_dict)
                # print "curr_idx: {}".format(curr_index)
                # print "padded_q_word: {}".format(padded_q_word)
                self.model.zero_grad()
                self.optimizer.zero_grad()
                # train
                if epoch <= self.pooling_threshold:
                    # during early epochs, freeze the pretrained rows of the question-word
                    # embedding. requires_grad cannot be set on a slice of a Parameter (and the
                    # original attribute name was misspelled), so the freeze is applied by
                    # zeroing the corresponding gradient rows right before the optimizer step.
                    frozen_emb_from = 50 if self.data_loader.use_entity_type else 4
                else:
                    frozen_emb_from = None

                '''don't need VAR(use_constraint) because model has been initialized'''
                out = None # placeholder
                if not self.data_loader.cpu_data:
                    out = self.model(padded_q_word_seq=padded_q_word, q_word_lengths=q_word_lengths,
                                     padded_q_dep_seq=padded_q_dep, q_dep_lengths=q_dep_lengths,
                                     padded_rel_words_seq=padded_rel_words, rel_word_lengths=rel_word_lengths,
                                     rel_ids=batch_rel_ids, padded_constraint_words_seq=padded_cons_words,
                                     constraint_word_lengths=cons_word_lengths, padded_constraint_ids=padded_cons_id,
                                     constraint_id_lengths=cons_id_lengths, pooling=self.pooling_criterion(epoch))
                else:
                    out = self.model(padded_q_word_seq=cuda_wrapper(padded_q_word), q_word_lengths=q_word_lengths,
                                     padded_q_dep_seq=cuda_wrapper(padded_q_dep), q_dep_lengths=q_dep_lengths,
                                     padded_rel_words_seq=cuda_wrapper(padded_rel_words),
                                     rel_word_lengths=rel_word_lengths, rel_ids=cuda_wrapper(batch_rel_ids),
                                     padded_constraint_words_seq=cuda_wrapper(padded_cons_words),
                                     constraint_word_lengths=cons_word_lengths,
                                     padded_constraint_ids=cuda_wrapper(padded_cons_id),
                                     constraint_id_lengths=cons_id_lengths, pooling=self.pooling_criterion(epoch))
                loss = None # placeholder
                if self.use_prior_weights:
                    loss = criterion.forward(out, cuda_wrapper(batch_labels.long()), cuda_wrapper(batch_prior_weights))
                else:
                    loss = criterion.forward(out, cuda_wrapper(batch_labels.long()), None)
                loss.backward()
                # print "epoch: {}, iter: {}, loss: {}".format(epoch, prog_idx, loss.item())
                if frozen_emb_from is not None and self.model.q_encoder.q_word_emb.weight.grad is not None:
                    # zero the gradient rows of the frozen (pretrained) embedding entries
                    self.model.q_encoder.q_word_emb.weight.grad[frozen_emb_from:] = 0
                self.optimizer.step()
                if self.use_constraint:
                    self.model.query_graph_encoder.constraint_word_emb.weight.requires_grad = True
                    self.model.query_graph_encoder.constraint_id_emb.weight.requires_grad = True
                progbar.update(prog_idx + 1, [("loss", loss.item())])
                prog_idx += 1
            #epoch_loss = self.eval()
            #epoch_loss_history.append(epoch_loss)
            #sys.stderr.write("Epoch: {}, Loss: {}\n".format(epoch, epoch_loss))

            #print "Epoch: {}, Loss: {}".format(epoch, epoch_loss)
            if epoch == self.max_epoch - 1 or epoch % 3 == 2:
                epoch_loss = self.eval()
                sys.stderr.write("Epoch: {}, Loss: {}\n".format(epoch, epoch_loss))
            if save_dir is not None:
                check_point = {
                    #'loss': epoch_loss,
                    'state_dict': self.model.state_dict()
                }
                torch.save(check_point, os.path.join(save_dir, str(epoch)))
    def train(self, X_train, X_mask, Y_train, Y_mask, input, output, verbose, optimizer):

        train_set_x = theano.shared(np.asarray(X_train, dtype="int32"), borrow=True)
        train_set_y = theano.shared(np.asarray(Y_train, dtype="int32"), borrow=True)

        mask_set_x = theano.shared(np.asarray(X_mask, dtype="float32"), borrow=True)
        mask_set_y = theano.shared(np.asarray(Y_mask, dtype="float32"), borrow=True)

        index = T.lscalar("index")  # index to a case
        lr = T.scalar("lr", dtype=theano.config.floatX)
        mom = T.scalar("mom", dtype=theano.config.floatX)  # momentum
        n_ex = T.lscalar("n_ex")
        sindex = T.lscalar("sindex")  # index to a case

        ### batch

        batch_start = index * self.n_batch
        batch_stop = T.minimum(n_ex, (index + 1) * self.n_batch)
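        # the last minibatch may contain fewer than n_batch examples, hence the T.minimum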

        effective_batch_size = batch_stop - batch_start

        get_batch_size = theano.function(inputs=[index, n_ex], outputs=effective_batch_size)

        cost = self.loss(self.y, self.y_mask) + self.L1_reg * self.L1

        updates = eval(optimizer)(self.params, cost, mom, lr)
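        # eval(optimizer) looks up the update-rule function by its string name (e.g. "SGD",
        # "RMSprop") in the enclosing scope; a dict of optimizer functions would be safer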

        """    
        compute_val_error = theano.function(inputs = [index,n_ex ],
                                              outputs = self.loss(self.y,self.y_mask),
                                              givens = {
                                                  self.x: train_set_x[:,batch_start:batch_stop],
                                                  self.y: train_set_y[:,batch_start:batch_stop],
                                                  self.x_mask: mask_set_x[:,batch_start:batch_stop],
                                                  self.y_mask: mask_set_y[:,batch_start:batch_stop]  
                                                    },
                                              mode = mode)    
        """
        train_model = theano.function(
            inputs=[index, lr, mom, n_ex],
            outputs=[cost, self.loss(self.y, self.y_mask)],
            updates=updates,
            givens={
                self.x: train_set_x[:, batch_start:batch_stop],
                self.y: train_set_y[:, batch_start:batch_stop],
                self.x_mask: mask_set_x[:, batch_start:batch_stop],
                self.y_mask: mask_set_y[:, batch_start:batch_stop],
            },
            mode=mode,
            on_unused_input="ignore",
        )

        ###############
        # TRAIN MODEL #
        ###############
        print "Training model ..."
        epoch = 0
        n_train = train_set_x.get_value(borrow=True).shape[1]
        n_train_batches = int(np.ceil(1.0 * n_train / self.n_batch))

        if optimizer != "SGD":
            self.learning_rate_decay = 1

        while epoch < self.n_epochs:
            epoch = epoch + 1
            if verbose == 1:
                progbar = Progbar(n_train_batches)
            train_losses = []
            train_batch_sizes = []
            for idx in xrange(n_train_batches):

                effective_momentum = (
                    self.final_momentum
                    if (epoch + len(self.errors)) > self.momentum_switchover
                    else self.initial_momentum
                )
                cost = train_model(idx, self.lr, effective_momentum, n_train)

                train_losses.append(cost[1])
                train_batch_sizes.append(get_batch_size(idx, n_train))

                if verbose == 1:
                    progbar.update(idx + 1)

            this_train_loss = np.average(train_losses, weights=train_batch_sizes)

            self.errors.append(this_train_loss)

            print ("epoch %i, train loss %f " "lr: %f" % (epoch, this_train_loss, self.lr))

            ### automatically saving snapshot ...
            if np.mod(epoch, self.snapshot) == 0:
                if epoch != n_train_batches:
                    self.save()

            ### generating sample..
            if np.mod(epoch, self.sample_Freq) == 0:
                print "Generating a sample..."

                i = np.random.randint(1, n_train)

                test = X_train[:, i]

                truth = Y_train[:, i]

                guess = self.gen_sample(test, X_mask[:, i])

                print "Input: ", " ".join(input.sequences_to_text(test))

                print "Truth: ", " ".join(output.sequences_to_text(truth))

                print "Sample: ", " ".join(output.sequences_to_text(guess[1]))

            """
            # compute loss on validation set
            if np.mod(epoch,self.val_Freq)==0:

                val_losses = [compute_val_error(i, n_train)
                                for i in xrange(n_train_batches)]
                val_batch_sizes = [get_batch_size(i, n_train)
                                     for i in xrange(n_train_batches)]
                this_val_loss = np.average(val_losses,
                                         weights=val_batch_sizes)                     
            """

            self.lr *= self.learning_rate_decay
    args = parser.parse_args()

    img_filenames = [
        filename for filename in os.listdir(args.img_path)
        if os.path.splitext(filename)[-1].lower() == args.img_extension
    ]
    num_img_filenames = len(img_filenames)
    img_features = []
    progress = Progbar(num_img_filenames)

    base_model = VGG16(weights="imagenet")
    model = Model(input=base_model.input,
                  output=base_model.get_layer("fc2").output)
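    # the truncated model outputs the 4096-d activations of VGG16's "fc2" layer;
    # the variable below is named fc7_features after the equivalent Caffe layer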

    for num_filename, filename in enumerate(img_filenames, 1):
        img = image.load_img(os.path.join(args.img_path, filename),
                             target_size=(224, 224))
        x = image.img_to_array(img)
        x = np.expand_dims(x, axis=0)
        x = preprocess_input(x)
        fc7_features = model.predict(x).squeeze()
        img_features.append(fc7_features)
        progress.update(num_filename)

    img_features = np.array(img_features)

    with open(args.img_names_filename, "w") as out_file:
        json.dump(img_filenames, out_file)

    np.save(args.img_features_filename, img_features)
Exemple #21
0
def evaluate_beam_search(generator,
                         data_loader,
                         opt,
                         title='',
                         epoch=1,
                         save_path=None):
    logging = config.init_logging(title, save_path + '/%s.log' % title)
    progbar = Progbar(logger=logging,
                      title=title,
                      target=len(data_loader) / opt.beam_batch,
                      batch_size=opt.beam_batch,
                      total_examples=len(data_loader) / opt.beam_batch)

    example_idx = 0
    score_dict = {}  # e.g. {'precision@5': [], 'recall@5': [], 'f1score@5': [], 'precision@10': [], 'recall@10': [], 'f1score@10': []}
    if opt.report_file:
        f = open(opt.report_file, 'w')

    for i, batch in enumerate(data_loader):
        # if i > 0:
        #     break

        one2many_batch = batch
        src_list, src_len, trg_list, _, trg_copy_target_list, src_oov_map_list, \
        oov_list, src_str_list, trg_str_list = one2many_batch

        if torch.cuda.is_available() and opt.use_gpu:
            src_list = src_list.cuda()
            src_oov_map_list = src_oov_map_list.cuda()

        pred_seq_list = generator.beam_search(src_list, src_len,
                                              src_oov_map_list, oov_list,
                                              opt.word2id)
        # print(len(pred_seq_list[0]), type(pred_seq_list[0]))
        '''
        process each example in current batch
        '''

        for src, src_str, trg, trg_str_seqs, trg_copy, pred_seq, oov \
                in zip(src_list, src_str_list, trg_list, trg_str_list,
                       trg_copy_target_list, pred_seq_list, oov_list):
            # logging.info('======================  %d =========================' % (example_idx))

            # print(trg_str_seqs)
            # print(src_str)
            # print(src, 'src')
            # print(trg_copy, 'trg_copy')
            # print(pred_seq, 'pred_seq')
            # print(oov, 'oov')
            trg_str_is_present = if_present_duplicate_phrase(
                src_str, trg_str_seqs)
            # TODO
            # print(trg_str_is_present)
            # 1st filtering
            pred_is_valid, processed_pred_seqs, processed_pred_str_seqs, \
            processed_pred_score = process_predseqs(pred_seq, oov, opt.id2word, opt)
            # print(pred_is_valid, 'pred_is_valid')
            # print(processed_pred_seqs, 'precessed_pred_seqs')
            # print(processed_pred_str_seqs, len(processed_pred_str_seqs))
            # print(processed_pred_score, 'processed_pred_score')
            # print(len(processed_pred_str_seqs))
            # 2nd filtering: optionally filter out phrases that don't appear in the source text, and keep unique ones after stemming
            # print(opt.must_appear_in_src)
            if opt.must_appear_in_src is True:
                pred_is_present = if_present_duplicate_phrase(
                    src_str, processed_pred_str_seqs)
                trg_str_seqs = np.asarray(trg_str_seqs)[trg_str_is_present]
            else:
                pred_is_present = [True] * len(processed_pred_str_seqs)
                # print(pred_is_present)
            valid_and_present = np.asarray(pred_is_valid) * np.asarray(
                pred_is_present)
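            # elementwise AND of the validity and presence masks (via a product of boolean arrays)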

            match_list = get_match_result(true_seqs=trg_str_seqs,
                                          pred_seqs=processed_pred_str_seqs)
            # print(match_list, 'match', len(match_list))
            '''
            Evaluate predictions w.r.t different filterings and metrics
            '''
            num_oneword_range = [-1, 1]
            topk_range = [5, 10]
            score_names = ['precision', 'recall', 'f_score']
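            # presumably num_oneword_seq = -1 keeps all predictions while 1 restricts how many
            # one-word phrases are kept; scores are reported at top-5 and top-10 cutoffs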
            processed_pred_seqs = np.asarray(
                processed_pred_seqs)[valid_and_present]
            processed_pred_str_seqs = np.asarray(
                processed_pred_str_seqs)[valid_and_present]
            processed_pred_score = np.asarray(
                processed_pred_score)[valid_and_present]
            try:
                data = {}
                data['src_str'] = src_str
                data['trg_str_seqs'] = trg_str_seqs
                data['pred'] = list(processed_pred_str_seqs)
                f.write(json.dumps(data) + '\n')
            except:
                pass
            for num_oneword_seq in num_oneword_range:
                # 3rd round filtering (one-word phrases)
                filtered_pred_seq, filtered_pred_str_seqs, filtered_pred_score = \
                    post_process_predseqs((processed_pred_seqs, processed_pred_str_seqs,
                                           processed_pred_score), num_oneword_seq)

                match_list = get_match_result(true_seqs=trg_str_seqs,
                                              pred_seqs=filtered_pred_str_seqs)

                assert len(filtered_pred_seq) == len(
                    filtered_pred_str_seqs) == len(filtered_pred_score) == len(
                        match_list)

                for topk in topk_range:
                    results = evaluate(match_list,
                                       filtered_pred_seq,
                                       trg_str_seqs,
                                       topk=topk)
                    for k, v in zip(score_names, results):
                        if '%s@%d#oneword=%d' % (
                                k, topk, num_oneword_seq) not in score_dict:
                            score_dict['%s@%d#oneword=%d' %
                                       (k, topk, num_oneword_seq)] = []
                        score_dict['%s@%d#oneword=%d' %
                                   (k, topk, num_oneword_seq)].append(v)

        if example_idx % 10 == 0:
            print('#(precision@5#oneword=-1)=%d, avg=%f' %
                  (len(score_dict['precision@5#oneword=-1']),
                   np.average(score_dict['precision@5#oneword=-1'])))
            print('#(precision@10#oneword=-1)=%d, avg=%f' %
                  (len(score_dict['precision@10#oneword=-1']),
                   np.average(score_dict['precision@10#oneword=-1'])))

            print('#(recall@5#oneword=-1)=%d, avg=%f' %
                  (len(score_dict['recall@5#oneword=-1']),
                   np.average(score_dict['recall@5#oneword=-1'])))
            print('#(recall@10#oneword=-1)=%d, avg=%f' %
                  (len(score_dict['recall@10#oneword=-1']),
                   np.average(score_dict['recall@10#oneword=-1'])))

            x, y = np.average(score_dict['f_score@5#oneword=-1']), np.average(
                score_dict['f_score@10#oneword=-1'])
            print('#(f_score@5#oneword=-1)=%d, avg=%f' %
                  (len(score_dict['f_score@5#oneword=-1']), x))
            print('#(f_score@10#oneword=-1)=%d, avg=%f' %
                  (len(score_dict['f_score@10#oneword=-1']), y))
            progbar.update(epoch, example_idx, [('f_score@5#oneword=-1', x),
                                                ('f_score@10#oneword=-1', y)])
            print('*' * 50)
        example_idx += 1

        # exit(0)
    print('#(f_score@5#oneword=-1)=%d, avg=%f' %
          (len(score_dict['f_score@5#oneword=-1']),
           np.average(score_dict['f_score@5#oneword=-1'])))
    print('#(f_score@10#oneword=-1)=%d, avg=%f' %
          (len(score_dict['f_score@10#oneword=-1']),
           np.average(score_dict['f_score@10#oneword=-1'])))
    # print('#(f_score@5#oneword=1)=%d, avg=%f' % (len(score_dict['f_score@5#oneword=1']), np.average(score_dict['f_score@5#oneword=1'])))
    # print('#(f_score@10#oneword=1)=%d, avg=%f' % (len(score_dict['f_score@10#oneword=1']), np.average(score_dict['f_score@10#oneword=1'])))

    if save_path:
        # export scores: each row holds precision, recall and f-score for one way of filtering predictions (i.e. how many one-word predictions are kept)
        with open(save_path + os.path.sep + title + '_result.csv',
                  'w') as result_csv:
            csv_lines = []
            for num_oneword_seq in num_oneword_range:
                for topk in topk_range:
                    csv_line = '#oneword=%d,@%d' % (num_oneword_seq, topk)
                    for k in score_names:
                        csv_line += ',%f' % np.average(
                            score_dict['%s@%d#oneword=%d' %
                                       (k, topk, num_oneword_seq)])
                    csv_lines.append(csv_line + '\n')

            result_csv.writelines(csv_lines)

    return score_dict
Exemple #22
0
         batch_ix += BATCH_SIZE
         if batch_ix >= train_x.shape[0]:
             batch_ix = 0
         model_probs = rbm.gibbs_sampler(batch_x, k, return_probs=True)
         model_sample = rbm.sample(model_probs) if SAMPLE else model_probs
         train_loss = cross_entropy(batch_x, model_probs)
         avg_train_loss += train_loss
         rbm.backward(batch_x, model_sample)
         optimizer.step()
         # Clean Up
         optimizer.zero_grad()
     avg_train_loss /= steps
     val_sample = rbm.gibbs_sampler(val_x, k, return_probs=True)
     val_loss = cross_entropy(val_x, val_sample)
     bar.update(epoch + 1,
                values=[("train_loss", avg_train_loss),
                        ("val_loss", val_loss)])
     if (epoch + 1) % args.plot_after == 0:
         # index = range(train_x.shape[0])
         # random.shuffle(index)
         # random_train = train_x[index, :][:16,:]
         random_train = np.random.uniform(0, 1., (100, 784))
         sample = rbm.gibbs_sampler(random_train,
                                    args.gibbs_steps,
                                    return_probs=True)
         plot(sample,
              "Plots/Image_Epoch_{}.png".format(format(epoch + 1,
                                                       '03')), epoch + 1)
 # Save the weights
 save_filename = "Models/weights_hdim_%d_epochs_%d_val_%.4f_k_%d.pkl" % (
     args.n_hidden, args.n_epochs, val_loss, k)
Exemple #23
0
        #forward, loss, backward, step
        output = encoder(images=images, rotation=rotation, mode=1)

        loss_1 = noise_contrastive_estimator(representations,
                                             output[1],
                                             index,
                                             memory,
                                             negative_nb=negative_nb)
        loss_2 = noise_contrastive_estimator(representations,
                                             output[0],
                                             index,
                                             memory,
                                             negative_nb=negative_nb)
        loss = loss_weight * loss_1 + (1 - loss_weight) * loss_2

        loss.backward()
        optimizer.step()

        #update representation memory
        memory.update(index, output[0].detach())

        # update metric and bar
        train_loss.update(loss.item(), images.shape[0])
        bar.update(step, values=[('train_loss', train_loss.return_avg())])

        #save model if improved
        checkpoint.save_model(encoder, optimizer, train_loss.return_avg(),
                              epoch)

    logger.update(epoch, train_loss.return_avg())
Exemple #24
0
def train(config, checkpoint_dir):

    print('Start Training the Model!')

    # generate data loader
    if config.longterm is True:
        config.output_window_size = 100
    # from choose_dataset import DatasetChooser
    choose = DatasetChooser(config)
    train_dataset, bone_length = choose(train=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True)
    test_dataset, _ = choose(train=False)
    test_loader = DataLoader(test_dataset,
                             batch_size=config.batch_size,
                             shuffle=True)
    prediction_dataset, bone_length = choose(prediction=True)
    x_test, y_test, dec_in_test = prediction_dataset

    device = torch.device(
        "cuda:" +
        str(config.device_ids[0]) if torch.cuda.is_available() else "cpu")
    print('Device {} will be used to save parameters'.format(device))
    torch.cuda.manual_seed_all(112858)
    net = choose_net(config)
    net.to(device)
    print('Total param number:' + str(sum(p.numel()
                                          for p in net.parameters())))
    print('Encoder param number:' +
          str(sum(p.numel() for p in net.encoder_cell.parameters())))
    print('Decoder param number:' +
          str(sum(p.numel() for p in net.decoder.parameters())))
    if torch.cuda.device_count() > 1:
        print("{} GPUs are usable!".format(str(torch.cuda.device_count())))
    net = torch.nn.DataParallel(net, device_ids=config.device_ids)

    if config.restore is True:
        dir = utils.get_file_list(checkpoint_dir)
        print('Load model from:' + checkpoint_dir + dir[-1])
        net.load_state_dict(
            torch.load(checkpoint_dir + dir[-1], map_location='cuda:0'))

    optimizer = optim.Adam(net.parameters(), lr=config.learning_rate)

    # Save model
    if not (os.path.exists(checkpoint_dir)):
        os.makedirs(checkpoint_dir)

    best_error = float('inf')
    best_error_list = None

    # logging: define the history arrays
    trainloss_list = []
    validloss_list = []
    Error_list = []

    for epoch in range(config.max_epoch):
        print("At epoch:{}".format(str(epoch + 1)))
        prog = Progbar(target=config.training_size)
        prog_valid = Progbar(target=config.validation_size)

        # Train
        #with torch.autograd.set_detect_anomaly(True):
        for it in range(config.training_size):
            for i, data in enumerate(train_loader, 0):
                encoder_inputs = data['encoder_inputs'].float().to(
                    device)  ## the first t-1 frames
                decoder_inputs = data['decoder_inputs'].float().to(
                    device)  ## frames t-1 to t-1+output_window_size
                decoder_outputs = data['decoder_outputs'].float().to(
                    device)  ## frames from t onwards
                prediction = net(encoder_inputs, decoder_inputs, train=True)
                loss = Loss(prediction, decoder_outputs, bone_length, config)
                net.zero_grad()
                loss.backward()
                _ = torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()

            prog.update(it + 1, [("Training Loss", loss.item())])

        trainloss_list.append(loss.item())

        # valid
        with torch.no_grad():
            for it in range(config.validation_size):
                for j in range(3):
                    for i, data in enumerate(test_loader, 0):
                        if j == 0 and i == 0:
                            encoder_inputs = data['encoder_inputs'].float().to(
                                device)
                            decoder_inputs = data['decoder_inputs'].float().to(
                                device)
                            decoder_outputs = data['decoder_outputs'].float(
                            ).to(device)
                        else:
                            encoder_inputs = torch.cat([
                                data['encoder_inputs'].float().to(device),
                                encoder_inputs
                            ],
                                                       dim=0)
                            decoder_inputs = torch.cat([
                                data['decoder_inputs'].float().to(device),
                                decoder_inputs
                            ],
                                                       dim=0)
                            decoder_outputs = torch.cat([
                                data['decoder_outputs'].float().to(device),
                                decoder_outputs
                            ],
                                                        dim=0)
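                # the nested loops above accumulate batches from three passes over the test
                # loader into one large evaluation batch before computing the validation loss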

                prediction = net(encoder_inputs, decoder_inputs, train=True)
                loss = Loss(prediction, decoder_outputs, bone_length, config)
                prog_valid.update(it + 1, [("Testing Loss", loss.item())])
            validloss_list.append(loss.item())

        #Test prediction
        actions = list(x_test.keys())
        y_predict = {}
        with torch.no_grad():
            for act in actions:
                x_test_ = torch.from_numpy(x_test[act]).float().to(device)
                dec_in_test_ = torch.from_numpy(
                    dec_in_test[act]).float().to(device)
                pred = net(x_test_, dec_in_test_, train=False)
                pred = pred.cpu().numpy()
                y_predict[act] = pred

        error_actions = 0.0
        for act in actions:
            if config.datatype == 'lie':
                mean_error, _ = utils.mean_euler_error(
                    config, act, y_predict[act],
                    y_test[act][:, :config.output_window_size, :])
                error = mean_error[[1, 3, 7, 9]]

            if config.datatype == 'smpl':
                mean_error, _ = utils.mean_euler_error(
                    config, act, y_predict[act],
                    y_test[act][:, :config.output_window_size, :])
                error = mean_error[[1, 3, 7, 9]]
            # error_actions += error.mean()

            error_actions += mean_error.mean()

        error_actions /= len(actions)

        Error_list.append(error_actions)

        if error_actions < best_error:
            print(error_actions)
            print(best_error)
            best_error_list = error
            best_error = error_actions
            torch.save(net.state_dict(),
                       checkpoint_dir + 'Epoch_' + str(epoch + 1) + '.pth')
        print('Current best:' + str(round(best_error_list[0], 2)) + ' ' +
              str(round(best_error_list[1], 2)) + ' ' +
              str(round(best_error_list[2], 2)) + ' ' +
              str(round(best_error_list[3], 2)))

    if not (os.path.exists('./log')):
        os.makedirs('./log')
    time_now = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    np.savez('./log/' + time_now + '.npz',
             trainloss_list=trainloss_list,
             validloss_list=validloss_list,
             Error_list=Error_list)
    parser.add_argument("--preprocessed_dataset_filename",
                        type=str,
                        required=True)
    args = parser.parse_args()

    nlp = en_core_web_sm.load()
    num_lines = len([1 for line in open(args.dataset_filename)])

    images = {}

    with open(args.dataset_filename) as in_file:
        dataset = json.load(in_file)

        for num_image, image in enumerate(dataset["images"], 1):
            images[image["id"]] = image["file_name"]

        print("Found {} images".format(len(images)))
        with open(args.preprocessed_dataset_filename, mode="w") as out_file:
            writer = csv.writer(out_file, delimiter="\t")

            progress = Progbar(len(dataset["annotations"]))
            for num_annotation, annotation in enumerate(
                    dataset["annotations"], 1):
                caption = annotation["caption"]
                image = images[annotation["image_id"]]
                label = "yes" if annotation["foil_word"] == "ORIG" else "no"
                # caption_tokens = [token.lower_ for token in nlp(caption)]
                # writer.writerow([label, " ".join(caption_tokens), image, caption])
                writer.writerow([label, image, caption])
                progress.update(num_annotation)
Exemple #26
0
def run_simultrans(model,
                   options_file=None,
                   config=None,
                   id=None,
                   remote=False):

    WORK = config['workspace']

    # check hidden folders
    paths = [
        '.policy', '.pretrained', '.log', '.config', '.images', '.translate'
    ]
    for p in paths:
        p = WORK + p
        if not os.path.exists(p):
            os.mkdir(p)

    if id is not None:
        fcon = WORK + '.config/{}.conf'.format(id)
        if os.path.exists(fcon):
            print 'load config files'
            policy, config = pkl.load(open(fcon, 'r'))

    # ============================================================================== #
    # load model model_options
    # ============================================================================== #
    _model = model.split('/')[-1]

    if options_file is not None:
        with open(options_file, 'rb') as f:
            options = pkl.load(f)
    else:
        with open('%s.pkl' % model, 'rb') as f:
            options = pkl.load(f)

    print 'merge configuration into options'
    for w in config:
        # if (w in options) and (config[w] is not None):
        options[w] = config[w]

    print 'load options...'
    for w, p in sorted(options.items(), key=lambda x: x[0]):
        print '{}: {}'.format(w, p)

    # load detail settings from option file:
    dictionary, dictionary_target = options['dictionaries']

    # load source dictionary and invert
    with open(dictionary, 'rb') as f:
        word_dict = pkl.load(f)
    word_idict = dict()
    for kk, vv in word_dict.iteritems():
        word_idict[vv] = kk
    word_idict[0] = '<eos>'
    word_idict[1] = 'UNK'

    # load target dictionary and invert
    with open(dictionary_target, 'rb') as f:
        word_dict_trg = pkl.load(f)
    word_idict_trg = dict()
    for kk, vv in word_dict_trg.iteritems():
        word_idict_trg[vv] = kk
    word_idict_trg[0] = '<eos>'
    word_idict_trg[1] = 'UNK'

    options['pre'] = config['pre']

    # ========================================================================= #
    # Build a Simultaneous Translator
    # ========================================================================= #

    # allocate model parameters
    params = init_params(options)
    params = load_params(model, params)
    tparams = init_tparams(params)

    # print 'build the model for computing cost (full source sentence).'
    trng, use_noise, \
    _x, _x_mask, _y, _y_mask, \
    opt_ret, \
    cost, f_cost = build_model(tparams, options)
    print 'done'

    # functions for sampler
    f_sim_ctx, f_sim_init, f_sim_next = build_simultaneous_sampler(
        tparams, options, trng)

    # function for finetune the underlying model
    if options['finetune']:
        ff_init, ff_cost, ff_update = build_simultaneous_model(tparams,
                                                               options,
                                                               rl=True)
        funcs = [
            f_sim_ctx, f_sim_init, f_sim_next, f_cost, ff_init, ff_cost,
            ff_update
        ]

    else:
        funcs = [f_sim_ctx, f_sim_init, f_sim_next, f_cost]

    # build a res-predictor
    if options['predict']:
        params_act = get_actor('gru')[0](options,
                                         prefix='pdt',
                                         nin=options['dim'])
        pass

    # check the ID:
    options['base'] = _model
    agent = Policy(trng,
                   options,
                   n_in=options['readout_dim'] +
                   1 if options['coverage'] else options['readout_dim'],
                   n_out=3 if config['forget'] else 2,
                   recurrent=options['recurrent'],
                   id=id)

    # make the dataset ready for training & validation
    trainIter = TextIterator(options['datasets'][0],
                             options['datasets'][1],
                             options['dictionaries'][0],
                             options['dictionaries'][1],
                             n_words_source=options['n_words_src'],
                             n_words_target=options['n_words'],
                             batch_size=config['batchsize'],
                             maxlen=options['maxlen'])

    train_num = trainIter.num

    validIter = TextIterator(options['valid_datasets'][0],
                             options['valid_datasets'][1],
                             options['dictionaries'][0],
                             options['dictionaries'][1],
                             n_words_source=options['n_words_src'],
                             n_words_target=options['n_words'],
                             batch_size=20,
                             cache=10,
                             maxlen=1000000)

    valid_num = validIter.num
    print 'training set {} lines / validation set {} lines'.format(
        train_num, valid_num)
    print 'use the reward function {}'.format(chr(config['Rtype'] + 65))

    # ========================================================================== #
    # Main Loop: Run
    # ========================================================================== #
    print 'Start Simultaneous Translator...'
    monitor = None
    if remote:
        monitor = Monitor(root='http://localhost:9000')

    # freqs
    save_freq = 200
    sample_freq = 10
    valid_freq = 200
    valid_size = 200
    display_freq = 50
    finetune_freq = 5

    history, last_it = agent.load()
    action_space = ['W', 'C', 'F']
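    # presumably index 0 = 'W' (wait / read a source word), 1 = 'C' (commit / write a target
    # word) and 2 = 'F' (forget); the third action is only used when config['forget'] is set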
    Log_avg = {}
    time0 = timer()

    pipe = OrderedDict()
    for key in ['x', 'x_mask', 'y', 'y_mask', 'c_mask']:
        pipe[key] = []

    def _translate(src,
                   trg,
                   samples=None,
                   train=False,
                   greedy=False,
                   show=False,
                   full=False):
        time0 = time.time()
        if full:
            options1 = copy.copy(options)
            options1['upper'] = True
        else:
            options1 = options

        ret = simultaneous_decoding(funcs, agent, options1, src, trg,
                                    word_idict_trg, samples, greedy, train)

        if show:
            info = ret[1]
            values = [(w, float(info[w])) for w in info if w != 'advantages']
            print ' , '.join(['{}={:.3f}'.format(k, f) for k, f in values]),
            print '...{}s'.format(time.time() - time0)

        return ret

    for it, (srcs,
             trgs) in enumerate(trainIter):  # only one sentence each iteration
        if it < last_it:  # go over the scanned lines.
            continue

        # for validation
        # doing the whole validation!!
        reference = []
        system = []

        if it % valid_freq == (valid_freq - 1):
            print 'start validation'

            collections = [[], [], [], [], []]
            probar_v = Progbar(valid_num / 20 + 1)
            for ij, (srcs, trgs) in enumerate(validIter):

                statistics = _translate(srcs,
                                        trgs,
                                        samples=1,
                                        train=False,
                                        greedy=True)

                quality, delay, reward = zip(*statistics['track'])
                reference += statistics['Ref']
                system += statistics['Sys']

                # compute the average consecutive waiting length
                def _consective(action):
                    waits = []
                    temp = 0
                    for a in action:
                        if a == 0:
                            temp += 1
                        elif temp > 0:
                            waits += [temp]
                            temp = 0

                    if temp > 0:
                        waits += [temp]

                    mean = numpy.mean(waits)
                    gec = numpy.max(
                        waits)  # numpy.prod(waits) ** (1./len(waits))
                    return mean, gec

                def _max_length(action):
                    _cur = 0
                    _end = 0
                    _max = 0
                    for it, a in enumerate(action):
                        if a == 0:
                            _cur += 1
                        elif a == 2:
                            _end += 1

                        temp = _cur - _end
                        if temp > _max:
                            _max = temp
                    return _max

                maxlen = [
                    _max_length(action) for action in statistics['action']
                ]
                means, gecs = zip(*(_consective(action)
                                    for action in statistics['action']))

                collections[0] += quality
                collections[1] += delay
                collections[2] += means
                collections[3] += gecs
                collections[4] += maxlen

                values = [('quality', numpy.mean(quality)),
                          ('delay', numpy.mean(delay)),
                          ('wait_mean', numpy.mean(means)),
                          ('wait_max', numpy.mean(gecs)),
                          ('max_len', numpy.mean(maxlen))]
                probar_v.update(ij + 1, values=values)

            validIter.reset()
            valid_bleu, valid_delay, valid_wait, valid_wait_gec, valid_mx = [
                numpy.mean(a) for a in collections
            ]
            print 'Iter = {}: AVG BLEU = {}, DELAY = {}, WAIT(MEAN) = {}, WAIT(MAX) = {}, MaxLen={}'.format(
                it, valid_bleu, valid_delay, valid_wait, valid_wait_gec,
                valid_mx)

            print 'Compute the Corpus BLEU={} (greedy)'.format(
                corpus_bleu(reference, system))

            with open(WORK + '.translate/test.txt', 'w') as fout:
                for sys in system:
                    fout.write('{}\n'.format(' '.join(sys)))

            with open(WORK + '.translate/ref.txt', 'w') as fout:
                for ref in reference:
                    fout.write('{}\n'.format(' '.join(ref[0])))

            history += [collections]
            print 'done'

        if options['upper']:
            print 'done'
            import sys
            sys.exit(-1)

        # training set sentence tuning
        new_srcs, new_trgs = [], []
        for src, trg in zip(srcs, trgs):
            if len(src) <= options['s0']:
                continue  # ignore when the source sentence is less than sidx.
            else:
                new_srcs += [src]
                new_trgs += [trg]

        if len(new_srcs) == 0:
            continue

        srcs, trgs = new_srcs, new_trgs
        statistics, info = _translate(srcs, trgs, train=True, show=True)

        if it % sample_freq == 0:

            # obtain the translation results
            samples = _bpe2words(
                _seqs2words(statistics['sample'], word_idict_trg,
                            statistics['action'], 1))
            sources = _bpe2words(
                _seqs2words(statistics['SWord'], word_idict,
                            statistics['action'], 0))
            targets = _bpe2words(
                _seqs2words(statistics['TWord'], word_idict_trg))

            # obtain the delay (normalized)
            # delays = _action2delay(srcs[0], statistics['action'])

            c = 0
            for j in xrange(len(samples)):

                if statistics['seq_info'][j][0] == 0:
                    if c < (config['sample'] / 2.):
                        c += 1
                        continue

                    print '--Iter: {}'.format(it)
                    print 'source: ', sources[j]
                    print 'sample: ', samples[j]
                    print 'target: ', targets[j]
                    print 'quality:', statistics['track'][j][0]
                    print 'delay:', statistics['track'][j][1]
                    print 'reward:', statistics['track'][j][2]
                    break

        # NaN detector
        #for w in info:
        #    if numpy.isnan(info[w]) or numpy.isinf(info[w]):
        #        raise RuntimeError, 'NaN/INF is detected!! {} : ID={}'.format(w, id)

        # remote display
        if remote:
            logs = {
                'R': info['R'],
                'Q': info['Q'],
                'D': info['D'],
                'P': float(info['P'])
            }
            if 'a_cost' in info:
                logs['A'] = info['a_cost']

            print logs
            for w in logs:
                Log_avg[w] = Log_avg.get(w, 0) + logs[w]

            if it % display_freq == (display_freq - 1):
                for w in Log_avg:
                    Log_avg[w] /= display_freq

                monitor.display(it + 1, Log_avg)
                Log_avg = dict()

        # save the history & model
        history += [info]
        if it % save_freq == 0:
            agent.save(history, it)
                        type=str,
                        required=True)
    args = parser.parse_args()

    nlp = en_core_web_sm.load()
    num_lines = len([1 for line in open(args.dataset_filename)])

    with open(args.dataset_filename) as in_file:
        reader = csv.reader(in_file, delimiter="\t")

        with open(args.preprocessed_dataset_filename, mode="w") as out_file:
            writer = csv.writer(out_file, delimiter="\t")

            progress = Progbar(num_lines)
            for row_number, row in enumerate(reader, 1):
                progress.update(row_number)
                label = row[0].strip()
                premise = row[1].strip()
                hypothesis = row[2].strip()
                premise_tokens = [token.lower_ for token in nlp(premise)]
                hypothesis_tokens = [token.lower_ for token in nlp(hypothesis)]
                if len(row) == 3:
                    writer.writerow([
                        label, " ".join(premise_tokens),
                        " ".join(hypothesis_tokens), premise, hypothesis
                    ])
                elif len(row) == 4:
                    image_filename = row[3].strip()
                    writer.writerow([
                        label, " ".join(premise_tokens),
                        " ".join(hypothesis_tokens), image_filename, premise,
        for att in f.readlines():
            attributes.append(att.split(",")[0].lower().strip())

    gpu_id = 0
    caffe.set_device(gpu_id)
    caffe.set_mode_gpu()
    net = None
    cfg_from_file(args.cfg_filename)
    net = caffe.Net(args.net_def_filename,
                    caffe.TEST,
                    weights=args.net_weights_filename)

    img_filenames = os.listdir(args.img_path)
    num_img_filenames = len(img_filenames)

    progress = Progbar(num_img_filenames)
    bottom_up_features = {}

    for filename_index, filename in enumerate(img_filenames):
        full_filename = os.path.join(args.img_path, filename)
        results = get_detections_from_im(net,
                                         full_filename,
                                         conf_thresh=0.2,
                                         min_num_boxes=args.num_boxes,
                                         max_num_boxes=args.num_boxes)
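        # get_detections_from_im (assumed to come from the bottom-up-attention codebase)
        # returns pooled region features for args.num_boxes detected boxes per image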
        bottom_up_features[filename] = results
        progress.update(filename_index)

    with open(args.features_filename, mode="wb") as out_file:
        pickle.dump(bottom_up_features, out_file)
            for indexes in batch(batches_indexes, args.batch_size):
                batch_premises = train_premises[indexes]
                batch_hypotheses = train_hypotheses[indexes]
                batch_labels = train_labels[indexes]
                batch_img_names = [train_img_names[i] for i in indexes]
                batch_img_features = image_reader.get_features(batch_img_names)

                loss, _ = session.run([loss_function, train_step], feed_dict={
                    premise_input: batch_premises,
                    hypothesis_input: batch_hypotheses,
                    img_features_input: batch_img_features,
                    label_input: batch_labels,
                    dropout_input: args.dropout_ratio
                })
                progress.update(batch_index, [("Loss", loss)])
                epoch_loss += loss
                batch_index += 1
            print("Current mean training loss: {}\n".format(epoch_loss / num_batches))

            print("-- Validating model")
            dev_num_examples = dev_labels.shape[0]
            dev_batches_indexes = np.arange(dev_num_examples)
            dev_num_correct = 0

            for indexes in batch(dev_batches_indexes, args.batch_size):
                dev_batch_premises = dev_premises[indexes]
                dev_batch_hypotheses = dev_hypotheses[indexes]
                dev_batch_labels = dev_labels[indexes]
                dev_batch_img_names = [dev_img_names[i] for i in indexes]
                dev_batch_img_features = image_reader.get_features(dev_batch_img_names)
def train_model(model, optimizer, criterion, training_data_loader, validation_data_loader, opt):
    logging.info('======================  Checking GPU Availability  =========================')
    if torch.cuda.is_available():
        if isinstance(opt.gpuid, int):
            opt.gpuid = [opt.gpuid]
        logging.info('Running on GPU! devices=%s' % str(opt.gpuid))
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=opt.gpuid)
        criterion.cuda()
    else:
        logging.info('Running on CPU!')

    logging.info('======================  Start Training  =========================')

    train_history_losses = []
    valid_history_losses = []
    best_loss = sys.float_info.max

    train_losses = []
    total_batch = 0
    stop_increasing = 0
    early_stop_flag = False

    for epoch in range(opt.start_epoch , opt.epochs):
        if early_stop_flag:
            break

        progbar = Progbar(title='Training', target=len(training_data_loader), batch_size=opt.batch_size,
                          total_examples=len(training_data_loader.dataset))
        model.train()

        for batch_i, batch in enumerate(training_data_loader):
            batch_i += 1 # for the aesthetics of printing
            total_batch += 1
            src = batch.src
            trg = batch.trg

            # print("src size - ",src.size())
            # print("target size - ",trg.size())
            if torch.cuda.is_available():
                # .cuda() is not in-place, so reassign the returned tensors
                src = src.cuda()
                trg = trg.cuda()

            optimizer.zero_grad()
            decoder_logits, _, _ = model.forward(src, trg, must_teacher_forcing=False)

            start_time = time.time()

            # remove the 1st word in trg to let predictions and real goal match
            loss = criterion(
                decoder_logits.contiguous().view(-1, opt.vocab_size),
                trg[:, 1:].contiguous().view(-1)
            )
            print("--loss calculation- %s seconds ---" % (time.time() - start_time))

            start_time = time.time()
            loss.backward()
            print("--backward- %s seconds ---" % (time.time() - start_time))

            if opt.max_grad_norm > 0:
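                # clip_grad_norm returns the total gradient norm computed before clipping;
                # the norm is recomputed below to log the effect of the clipping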
                pre_norm = torch.nn.utils.clip_grad_norm(model.parameters(), opt.max_grad_norm)
                after_norm = (sum([p.grad.data.norm(2) ** 2 for p in model.parameters() if p.grad is not None])) ** (1.0 / 2)
                logging.info('clip grad (%f -> %f)' % (pre_norm, after_norm))

            optimizer.step()

            train_losses.append(loss.data[0])
            perplexity = np.math.exp(loss.data[0])

            progbar.update(epoch, batch_i, [('train_loss', loss.data[0]), ('perplexity', perplexity)])

            if batch_i > 1 and batch_i % opt.report_every == 0:
                logging.info('======================  %d  =========================' % (batch_i))

                logging.info('Epoch : %d Minibatch : %d, Loss=%.5f, PPL=%.5f' % (epoch, batch_i, np.mean(loss.data[0]), perplexity))
                sampled_size = 2
                logging.info('Printing predictions on %d sampled examples by greedy search' % sampled_size)

                # softmax logits to get probabilities (batch_size, trg_len, vocab_size)
                # decoder_probs = torch.nn.functional.softmax(decoder_logits.view(trg.size(0) * trg.size(1), -1)).view(*trg.size(), -1)

                if torch.cuda.is_available():
                    src = src.data.cpu().numpy()
                    decoder_logits = decoder_logits.data.cpu().numpy()
                    max_words_pred = decoder_logits.argmax(axis=-1)
                    trg = trg.data.cpu().numpy()
                else:
                    src = src.data.numpy()
                    decoder_logits = decoder_logits.data.numpy()
                    max_words_pred = decoder_logits.argmax(axis=-1)
                    trg = trg.data.numpy()

                sampled_trg_idx = np.random.random_integers(low=0, high=len(trg) - 1, size=sampled_size)
                src             = src[sampled_trg_idx]
                max_words_pred  = [max_words_pred[i] for i in sampled_trg_idx]
                decoder_logits  = decoder_logits[sampled_trg_idx]
                trg = [trg[i][1:] for i in sampled_trg_idx] # the real target has removed the starting <BOS>

                for i, (src_wi, pred_wi, real_wi) in enumerate(zip(src, max_words_pred, trg)):
                    nll_prob = -np.sum(np.log2([decoder_logits[i][l][pred_wi[l]] for l in range(len(real_wi))]))
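                    # decoder_logits holds raw logits here (the softmax above is commented out),
                    # so nll_prob is only a rough proxy for the negative log-likelihood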
                    sentence_source = [opt.id2word[x] for x in src_wi]
                    sentence_pred   = [opt.id2word[x] for x in pred_wi]
                    sentence_real   = [opt.id2word[x] for x in real_wi]

                    logging.info('==================================================')
                    logging.info('Source: %s '          % (' '.join(sentence_source)))
                    logging.info('\t\tPred : %s (%.4f)' % (' '.join(sentence_pred), nll_prob))
                    logging.info('\t\tReal : %s '       % (' '.join(sentence_real)))

            if total_batch > 1 and total_batch % opt.run_valid_every == 0:
                logging.info('*' * 50)
                logging.info('Run validation test @Epoch=%d,#(Total batch)=%d' % (epoch, total_batch))
                valid_losses = _valid(validation_data_loader, model, criterion, optimizer, epoch, opt, is_train=False)

                train_history_losses.append(copy.copy(train_losses))
                valid_history_losses.append(valid_losses)
                train_losses = []

                # Plot the learning curve
                plot_learning_curve(train_history_losses, valid_history_losses, 'Training and Validation',
                                    curve1_name='Training Error', curve2_name='Validation Error',
                                    save_path=opt.exp_path + '/[epoch=%d,batch=%d,total_batch=%d]train_valid_curve.png' % (epoch, batch_i, total_batch))

                '''
                determine whether to stop training early
                '''
                valid_loss = np.average(valid_history_losses[-1])
                is_best_loss = valid_loss < best_loss
                rate_of_change = float(valid_loss - best_loss) / float(best_loss)

                # store a checkpoint when validation improves or at the periodic save interval
                if total_batch > 1 and epoch >= opt.start_checkpoint_at and (total_batch % opt.save_model_every == 0 or is_best_loss):
                    # Save the checkpoint
                    logging.info('Saving checkpoint to: %s' % os.path.join(opt.save_path, '%s.epoch=%d.batch=%d.total_batch=%d' % (opt.exp, epoch, batch_i, total_batch) + '.model'))
                    torch.save(
                        model.state_dict(),
                        open(os.path.join(opt.save_path, '%s.epoch=%d.batch=%d.total_batch=%d' % (opt.exp, epoch, batch_i, total_batch) + '.model'), 'wb')
                    )

                # valid error doesn't decrease
                if rate_of_change >= 0:
                    stop_increasing += 1
                else:
                    stop_increasing = 0

                if is_best_loss:
                    logging.info('Validation: update best loss (%.4f --> %.4f), rate of change (ROC)=%.2f' % (
                        best_loss, valid_loss, rate_of_change * 100))
                else:
                    logging.info('Validation: best loss has not been updated for %d checks (%.4f --> %.4f), rate of change (ROC)=%.2f' % (
                        stop_increasing, best_loss, valid_loss, rate_of_change * 100))

                best_loss = min(valid_loss, best_loss)
                if stop_increasing >= opt.early_stop_tolerance:
                    logging.info('Validation loss has not improved for %d checks, stopping training early' % stop_increasing)
                    early_stop_flag = True
                    break
                logging.info('*' * 50)
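The early-stopping bookkeeping above (best_loss, stop_increasing, early_stop_flag) can be read as a small state machine. Below is a minimal sketch under the same rule, where a validation check counts against the tolerance whenever the loss fails to improve on the best value seen so far; the EarlyStopper name and should_stop method are illustrative and not part of the original code.

class EarlyStopper(object):
    """Track the best validation loss and count non-improving checks."""

    def __init__(self, tolerance):
        self.tolerance = tolerance
        self.best_loss = float('inf')
        self.bad_checks = 0

    def should_stop(self, valid_loss):
        if valid_loss < self.best_loss:
            self.best_loss = valid_loss
            self.bad_checks = 0   # improvement resets the counter
        else:
            self.bad_checks += 1  # no improvement at this validation check
        return self.bad_checks >= self.tolerance

# usage sketch (hypothetical wiring into the loop above):
#   stopper = EarlyStopper(opt.early_stop_tolerance)
#   if stopper.should_stop(np.average(valid_losses)):
#       break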
Example #31
0
        file_name = MODEL_DIR + options['SAVE_PREFIX']
        model = MF.BiLSTM_Model(options)
        n_epochs = 10
        _mode = args.partial_mode if hasattr(args, 'partial_mode') else 'em'
        model.fit(X=train_X,
                  y=train_y,
                  val_split=0.9,
                  shuffle=True,
                  n_epochs=n_epochs,
                  save_best=True,
                  save_prefix=file_name,
                  X_unlabeled=unlabeled_X,
                  y_unlabeled=unlabeled_y,
                  mode=_mode)
        bar.update(ix - start_ix + 1)
else:
    predictions = []
    model = MF.BiLSTM_Model(options)
    for ix in xrange(len(posts)):
        BASE_DIR = options['BASE_DIR']
        MODEL_DIR = BASE_DIR + 'MODEL_' + str(ix) + '/'
        file_name = MODEL_DIR + options['SAVE_PREFIX']
        best_model_file = ut.get_best_model_file(file_name)
        model.load_model(best_model_file)
        test_X, test_y = ut.data_generator([posts[ix]], options)
        prediction = model.predict(test_X, mode=args.inf_mode)
        prediction = [
            options['IX_2_CLASSES'][x]
            if x in options['IX_2_CLASSES'] else 'O' for x in prediction
        ]
Example #32
0
def run_simultrans(model,
                   options_file=None,
                   config=None,
                   policy=None,
                   id=None,
                   remote=False):
    # check environments
    check_env()
    if id is not None:
        fcon = WORK + '.config/{}.conf'.format(id)
        if os.path.exists(fcon):
            print 'load config files'
            policy, config = pkl.load(open(fcon, 'r'))

    # ============================================================================== #
    # load the model and its model_options
    # ============================================================================== #
    _model = model
    model = WORK + '.pretrained/{}'.format(model)

    if options_file is not None:
        with open(options_file, 'rb') as f:
            options = pkl.load(f)
    else:
        with open('%s.pkl' % model, 'rb') as f:
            options = pkl.load(f)

    print 'load options...'
    for w, p in sorted(options.items(), key=lambda x: x[0]):
        print '{}: {}'.format(w, p)

    # load detail settings from option file:
    dictionary, dictionary_target = options['dictionaries']

    def _iter(fname):
        with open(fname, 'r') as f:
            for line in f:
                words = line.strip().split()
                x = map(lambda w: word_dict[w] if w in word_dict else 1, words)
                x = map(lambda ii: ii if ii < options['n_words'] else 1, x)
                x += [0]
                yield x

    def _check_length(fname):
        f = open(fname, 'r')
        count = 0
        for _ in f:
            count += 1
        f.close()

        return count

    # load source dictionary and invert
    with open(dictionary, 'rb') as f:
        word_dict = pkl.load(f)
    word_idict = dict()
    for kk, vv in word_dict.iteritems():
        word_idict[vv] = kk
    word_idict[0] = '<eos>'
    word_idict[1] = 'UNK'

    # load target dictionary and invert
    with open(dictionary_target, 'rb') as f:
        word_dict_trg = pkl.load(f)
    word_idict_trg = dict()
    for kk, vv in word_dict_trg.iteritems():
        word_idict_trg[vv] = kk
    word_idict_trg[0] = '<eos>'
    word_idict_trg[1] = 'UNK'

    ## use additional input for the policy network
    options['pre'] = config['pre']

    # ================================================================================= #
    # Build a Simultaneous Translator
    # ================================================================================= #

    # allocate model parameters
    params = init_params(options)
    params = load_params(model, params)
    tparams = init_tparams(params)

    # print 'build the model for computing cost (full source sentence).'
    trng, use_noise, \
    _x, _x_mask, _y, _y_mask, \
    opt_ret, \
    cost, f_cost = build_model(tparams, options)
    print 'done'

    # functions for sampler
    f_sim_ctx, f_sim_init, f_sim_next = build_simultaneous_sampler(
        tparams, options, trng)

    # function for finetune
    if config['finetune'] != 'nope':
        f_fine_init, f_fine_cost, f_fine_update = build_fine(
            tparams,
            options,
            fullmodel=True if config['finetune'] == 'full' else False)

    def _translate(src,
                   trg,
                   train=False,
                   samples=config['sample'],
                   greedy=False):
        ret = simultaneous_decoding(
            f_sim_ctx,
            f_sim_init,
            f_sim_next,
            f_cost,
            _policy,
            src,
            trg,
            word_idict_trg,
            step=config['step'],
            peek=config['peek'],
            sidx=config['s0'],
            n_samples=samples,
            reward_config={
                'target': config['target'],
                'gamma': config['gamma'],
                'Rtype': config['Rtype'],
                'maxsrc': config['maxsrc'],
                'greedy': greedy,
                'upper': config['upper']
            },
            train=train,
            use_forget=config['forget'],
            use_newinput=config['pre'],
            use_coverage=config['coverage'],
            on_groundtruth=0 if config['finetune'] == 'nope' else 10)

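        # NOTE: debugging left in place: the decoding result is printed and the process
        # exits here, so the return statements below are never reached.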
        print ret
        import sys
        sys.exit(-1)

        return ret

        # if not train:
        #     sample, score, actions, R, tracks, attentions = ret
        #     return sample, score, actions, R, tracks
        # else:
        #     sample, score, actions, R, info, pipe_t = ret
        #     return sample, score, actions, R, info, pipe_t

    # check the ID:
    policy['base'] = _model
    _policy = Policy(trng,
                     options,
                     policy,
                     config,
                     n_in=options['readout_dim'] +
                     1 if config['coverage'] else options['readout_dim'],
                     n_out=3 if config['forget'] else 2,
                     recurrent=policy['recurrent'],
                     id=id)

    # make the dataset ready for training & validation
    # train_    = options['datasets'][0]
    # train_num = _check_length
    trainIter = TextIterator(options['datasets'][0],
                             options['datasets'][1],
                             options['dictionaries'][0],
                             options['dictionaries'][1],
                             n_words_source=options['n_words_src'],
                             n_words_target=options['n_words'],
                             batch_size=config['batchsize'],
                             maxlen=options['maxlen'])

    train_num = trainIter.num

    validIter = TextIterator(options['valid_datasets'][0],
                             options['valid_datasets'][1],
                             options['dictionaries'][0],
                             options['dictionaries'][1],
                             n_words_source=options['n_words_src'],
                             n_words_target=options['n_words'],
                             batch_size=1,
                             cache=1,
                             maxlen=1000000)

    valid_num = validIter.num

    valid_ = options['valid_datasets'][0]
    valid_num = _check_length(valid_)
    print 'training set {} lines / validation set {} lines'.format(
        train_num, valid_num)
    print 'use the reward function {}'.format(chr(config['Rtype'] + 65))

    # ================================================================================= #
    # Main Loop: Run
    # ================================================================================= #
    print 'Start Simultaneous Translator...'
    probar = Progbar(train_num / config['batchsize'], with_history=False)
    monitor = None
    if remote:
        monitor = Monitor(root='http://localhost:9000')

    # freqs
    save_freq = 200
    sample_freq = 10
    valid_freq = 200
    valid_size = 200
    display_freq = 50
    finetune_freq = 5

    history, last_it = _policy.load()
    action_space = ['W', 'C', 'F']
    Log_avg = {}
    time0 = timer()
    pipe = PIPE(['x', 'x_mask', 'y', 'y_mask', 'c_mask'])

    for it, (srcs,
             trgs) in enumerate(trainIter):  # only one sentence each iteration
        if it < last_it:  # skip lines that have already been processed.
            continue

        # for validation
        # doing the whole validation!!
        reference = []
        system = []

        reference2 = []
        system2 = []

        if it % valid_freq == 0:
            print 'start validation'

            collections = [[], [], [], [], []]
            probar_v = Progbar(valid_num / 64 + 1)
            for ij, (srcs, trgs) in enumerate(validIter):

                # new_srcs, new_trgs = [], []

                # for src, trg in zip(srcs, trgs):
                #     if len(src) < config['s0']:
                #         continue  # ignore when the source sentence is less than sidx. we don't use the policy\
                #     else:
                #         new_srcs += [src]
                #         new_trgs += [trg]

                # if len(new_srcs) == 0:
                #     continue
                # srcs, trgs = new_srcs, new_trgs

                statistics = _translate(srcs,
                                        trgs,
                                        train=False,
                                        samples=1,
                                        greedy=True)

                quality, delay, reward = zip(*statistics['track'])
                reference += statistics['Ref']
                system += statistics['Sys']

                # print ' '.join(reference[-1][0])
                # print ' '.join(system[-1])

                # compute the average consecutive waiting length
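                # e.g. actions [W, W, C, W, C] -> wait runs [2, 1]: mean = 1.5, max = 2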
                def _consective(action):
                    waits = []
                    temp = 0
                    for a in action:
                        if a == 0:
                            temp += 1
                        elif temp > 0:
                            waits += [temp]
                            temp = 0

                    if temp > 0:
                        waits += [temp]

                    mean = numpy.mean(waits)
                    gec = numpy.max(
                        waits)  # numpy.prod(waits) ** (1./len(waits))
                    return mean, gec

                def _max_length(action):
                    _cur = 0
                    _end = 0
                    _max = 0
                    for it, a in enumerate(action):
                        if a == 0:
                            _cur += 1
                        elif a == 2:
                            _end += 1

                        temp = _cur - _end
                        if temp > _max:
                            _max = temp
                    return _max

                maxlen = [
                    _max_length(action) for action in statistics['action']
                ]
                means, gecs = zip(*(_consective(action)
                                    for action in statistics['action']))

                collections[0] += quality
                collections[1] += delay
                collections[2] += means
                collections[3] += gecs
                collections[4] += maxlen

                values = [('quality', numpy.mean(quality)),
                          ('delay', numpy.mean(delay)),
                          ('wait_mean', numpy.mean(means)),
                          ('wait_max', numpy.mean(gecs)),
                          ('max_len', numpy.mean(maxlen))]
                probar_v.update(ij + 1, values=values)

            validIter.reset()
            valid_bleu, valid_delay, valid_wait, valid_wait_gec, valid_mx = [
                numpy.mean(a) for a in collections
            ]
            print 'Iter = {}: AVG BLEU = {}, DELAY = {}, WAIT(MEAN) = {}, WAIT(MAX) = {}, MaxLen={}'.format(
                it, valid_bleu, valid_delay, valid_wait, valid_wait_gec,
                valid_mx)

            print 'Compute the Corpus BLEU={} (greedy)'.format(
                corpus_bleu(reference, system))

            with open(WORK + '.translate/test.txt', 'w') as fout:
                for sys in system:
                    fout.write('{}\n'.format(' '.join(sys)))

            with open(WORK + '.translate/ref.txt', 'w') as fout:
                for ref in reference:
                    fout.write('{}\n'.format(' '.join(ref[0])))

        if config['upper']:
            print 'done'
            import sys
            sys.exit(-1)

        # filter training sentences that are too short for the policy
        new_srcs, new_trgs = [], []
        for src, trg in zip(srcs, trgs):
            if len(src) <= config['s0']:
                continue  # skip source sentences shorter than sidx, where the policy is not used
            else:
                new_srcs += [src]
                new_trgs += [trg]

        if len(new_srcs) == 0:
            continue

        srcs, trgs = new_srcs, new_trgs
        try:
            statistics, info, pipe_t = _translate(srcs, trgs, train=True)
        except Exception:
            print 'translated an empty sentence (bug), skipping this batch.'
            continue

        # samples, scores, actions, rewards, info, pipe_t = _translate(srcs, trgs, train=True)
        # print pipe_t

        if config['finetune'] != 'nope':

            for idx, act in enumerate(pipe_t['action']):
                _start = 0
                _end = 0
                _mask = [0 for _ in srcs[0]]
                _cmask = []

                pipe.messages['x'] += srcs
                pipe.messages['y'] += [pipe_t['sample'][idx]]

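                # rebuild the source-visibility mask from the action sequence:
                #   W (0) reveals the next source token, F (2) forgets the oldest
                #   revealed token, C (1) commits a target word and snapshots the mask.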
                for a in act:
                    # print _start, _end
                    if a == 0:
                        _mask[_start] = 1
                        _start += 1
                    elif a == 2:
                        _mask[_end] = 0
                        _end += 1
                    else:
                        _cmask.append(list(_mask))  # copy the mask so later updates do not mutate earlier rows
                # print numpy.asarray(_cmask).shape

                pipe.messages['c_mask'].append(_cmask)

            if it % finetune_freq == (finetune_freq - 1):
                num = len(pipe.messages['x'])
                max_x = max([len(v) for v in pipe.messages['x']])
                max_y = max([len(v) for v in pipe.messages['y']])

                xx, xx_mask = _padding(pipe.messages['x'],
                                       shape=(max_x, num),
                                       return_mask=True,
                                       dtype='int64')
                yy, yy_mask = _padding(pipe.messages['y'],
                                       shape=(max_y, num),
                                       return_mask=True,
                                       dtype='int64')
                cc_mask = _padding(pipe.messages['c_mask'],
                                   shape=(max_y, num,
                                          max_x)).transpose([0, 2, 1])

                # fine-tune the EncDec of translation
                if config['finetune'] == 'full':
                    cost = f_fine_cost(xx, xx_mask, yy, yy_mask, cc_mask)
                elif config['finetune'] == 'decoder':
                    cost = f_fine_cost(xx, xx_mask, yy, yy_mask, cc_mask)
                else:
                    raise NotImplementedError

                print '\nIter={} || cost = {}'.format(it, cost[0])
                f_fine_update(0.00001)
                pipe.reset()

        if it % sample_freq == 0:

            print '\nModel:{} has been trained for {} hours'.format(
                _policy.id, (timer() - time0) / 3600.)
            print 'source: ', _bpe2words(_seqs2words([srcs[0]], word_idict))[0]
            print 'target: ', _bpe2words(_seqs2words([trgs[0]],
                                                     word_idict_trg))[0]

            # obtain the translation results
            samples = _bpe2words(
                _seqs2words(statistics['sample'], word_idict_trg))

            # obtain the delay (normalized)
            # delays = _action2delay(srcs[0], statistics['action'])

            c = 0
            for j in xrange(len(samples)):

                if statistics['secs'][j][0] == 0:
                    if c < 5:
                        c += 1

                    print '---ID: {}'.format(_policy.id)
                    print 'sample: ', samples[j]
                    # print 'action: ', ','.join(
                    #     ['{}({})'.format(action_space[t], f)
                    #      for t, f in
                    #          zip(statistics['action'][j], statistics['forgotten'][j])])

                    print 'action: ', ','.join([
                        '{}'.format(action_space[t])
                        for t in statistics['action'][j]
                    ])

                    print 'quality:', statistics['track'][j][0]
                    print 'delay:', statistics['track'][j][1]
                    # print 'score:', statistics['score'][j]
                    break

        values = [(w, info[w]) for w in info]
        probar.update(it + 1, values=values)

        # NaN detector
        for w in info:
            if numpy.isnan(info[w]) or numpy.isinf(info[w]):
                raise RuntimeError, 'NaN/INF is detected!! {} : ID={}'.format(
                    w, id)

        # remote display
        if remote:
            logs = {
                'R': info['R'],
                'Q': info['Q'],
                'D': info['D'],
                'P': float(info['P'])
            }
            # print logs
            for w in logs:
                Log_avg[w] = Log_avg.get(w, 0) + logs[w]

            if it % display_freq == (display_freq - 1):
                for w in Log_avg:
                    Log_avg[w] /= display_freq

                monitor.display(it + 1, Log_avg)
                Log_avg = dict()

        # save the history & model
        history += [info]
        if it % save_freq == 0:
            _policy.save(history, it)
Example #33
0
for epoch in xrange(N_EPOCHS):
    steps = (train_x.shape[0] //
             BATCH_SIZE) if train_x.shape[0] % BATCH_SIZE == 0 else (
                 train_x.shape[0] // BATCH_SIZE) + 1
    entropy_train_loss = 0.
    mse_train_loss = 0.
    for ix in xrange(steps):
        batch_x = train_x[ix:ix + BATCH_SIZE]
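        # Bernoulli corruption for the denoising autoencoder: each input dimension
        # is kept with probability NOISE and zeroed out otherwise.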
        mask = np.random.binomial(1, NOISE, batch_x.shape)
        input_x = batch_x * mask
        reconstruction = autoencoder(input_x)
        loss = cross_entropy(batch_x, reconstruction)
        entropy_train_loss += loss
        mse_train_loss += mse_loss(batch_x, reconstruction)
        gradient = cross_entropy.grad()
        # gradient = mse_loss.grad()
        autoencoder.backward(gradient)
        optimizer.step()
        # Cleanup
        optimizer.zero_grad()
    entropy_train_loss /= steps
    mse_train_loss /= steps
    val_preds = autoencoder(val_x)
    entropy_val_loss = cross_entropy(val_x, val_preds)
    mse_val_loss = mse_loss(val_x, val_preds)
    bar.update(epoch + 1,
               values=[("train_entropy", entropy_train_loss),
                       ("val_entropy", entropy_val_loss),
                       ("train_mse", mse_train_loss),
                       ("val_mse", mse_val_loss)])
Example #34
0
def evaluate_beam_search(generator,
                         data_loader,
                         opt,
                         title='',
                         epoch=1,
                         save_path=None):
    logging = config.init_logging(title, save_path + '/%s.log' % title)
    progbar = Progbar(logger=logging,
                      title=title,
                      target=len(data_loader) / opt.beam_batch,
                      batch_size=opt.beam_batch,
                      total_examples=len(data_loader) / opt.beam_batch)

    beam_batch_idx = 0
    score_dict = defaultdict(
        list
    )  # {'precision@5':[],'recall@5':[],'f1score@5':[], 'precision@10':[],'recall@10':[],'f1score@10':[]}

    sample_idx = 0
    for i, batch in enumerate(data_loader):
        beam_batch_idx += 1

        src_list, src_len, trg_list, _, _, src_oov_map_list, oov_list, query_lists, query_len, src_str_list, trg_str_list = batch

        if torch.cuda.is_available() and opt.use_gpu:
            src_list = src_list.cuda()
            src_oov_map_list = src_oov_map_list.cuda()
            query_lists = query_lists.cuda()

        pred_seq_list = generator.beam_search(src_list, src_len,
                                              src_oov_map_list, oov_list,
                                              opt.word2id, query_lists,
                                              query_len)

        for src, src_str, trg, trg_str_seqs, pred_seq, oov in zip(
                src_list, src_str_list, trg_list, trg_str_list, pred_seq_list,
                oov_list):
            # logging.info('======================  %d =========================' % (beam_batch_idx))

            pred_is_valid, processed_pred_seqs, processed_pred_str_seqs, processed_pred_score = process_predseqs(
                pred_seq, oov, opt.id2word, opt)

            # 2nd filtering: filter out phrases that do not appear in the source text, and keep unique ones after stemming

            pred_is_present = [True] * len(processed_pred_str_seqs)

            valid_and_present = np.asarray(pred_is_valid) * np.asarray(
                pred_is_present)
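            # elementwise AND of the two boolean filters, used below to index the prediction arrays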
            '''
            Evaluate predictions w.r.t different filterings and metrics
            '''
            num_oneword_seq = -1  # -1,1
            topk_range = [5, 10]  #5,10
            score_names = ['precision', 'recall', 'f_score']

            processed_pred_seqs = np.asarray(
                processed_pred_seqs)[valid_and_present]
            processed_pred_str_seqs = np.asarray(
                processed_pred_str_seqs)[valid_and_present]
            processed_pred_score = np.asarray(
                processed_pred_score)[valid_and_present]

            # 3rd round filtering (one-word phrases)
            filtered_pred_seq, filtered_pred_str_seqs, filtered_pred_score = post_process_predseqs(
                (processed_pred_seqs, processed_pred_str_seqs,
                 processed_pred_score), num_oneword_seq)

            match_list = get_match_result(true_seqs=trg_str_seqs,
                                          pred_seqs=filtered_pred_str_seqs,
                                          type='exact')
            # logging_result(src_str,trg_str_seqs,filtered_pred_str_seqs,match_list)

            assert len(filtered_pred_seq) == len(
                filtered_pred_str_seqs) == len(filtered_pred_score) == len(
                    match_list)

            for topk in topk_range:
                results = evaluate(match_list,
                                   filtered_pred_seq,
                                   trg_str_seqs,
                                   topk=topk)
                for k, v in zip(score_names, results):
                    if '%s@%d#oneword=%d' % (
                            k, topk, num_oneword_seq) not in score_dict:
                        score_dict['%s@%d#oneword=%d' %
                                   (k, topk, num_oneword_seq)] = []
                    score_dict['%s@%d#oneword=%d' %
                               (k, topk, num_oneword_seq)].append(v)

        if beam_batch_idx % 10 == 0:

            # print('#(precision@5#oneword=1)=%d, avg=%f' % (len(score_dict['precision@5#oneword=1']), np.average(score_dict['precision@5#oneword=1'])))
            # print('#(precision@10#oneword=1)=%d, avg=%f' % (len(score_dict['precision@10#oneword=1']), np.average(score_dict['precision@10#oneword=1'])))

            # print('#(recall@5#oneword=1)=%d, avg=%f' % (len(score_dict['recall@5#oneword=1']), np.average(score_dict['recall@5#oneword=1'])))
            # print('#(recall@10#oneword=1)=%d, avg=%f' % (len(score_dict['recall@10#oneword=1']), np.average(score_dict['recall@10#oneword=1'])))
            metric5 = 'f_score@5#oneword=' + str(num_oneword_seq)
            metric10 = 'f_score@10#oneword=' + str(num_oneword_seq)

            x, y = np.average(score_dict[metric5]), np.average(
                score_dict[metric10])
            print(metric5, x)
            print(metric10, y)
            # print('#(f_score@5#oneword=1)=%d, avg=%f' % (len(score_dict['f_score@5#oneword=1']), x))
            # print('#(f_score@10#oneword=1)=%d, avg=%f' % (len(score_dict['f_score@10#oneword=1']), y))

            progbar.update(epoch, beam_batch_idx, [(metric5, x),
                                                   (metric10, y)])

            print('*' * 50)
        '''
        process each example in current batch
        '''

    print(metric5, x)
    print(metric10, y)

    # if save_path:
    #     # export scores. Each row is scores (precision, recall and f-score) of different way of filtering predictions (how many one-word predictions to keep)
    #     with open(save_path + os.path.sep + title + '_result.csv', 'w') as result_csv:
    #         csv_lines = []

    #         for topk in topk_range:
    #             csv_line = '#oneword=%d,@%d' % (num_oneword_seq, topk)
    #             for k in score_names:
    #                 csv_line += ',%f' % np.average(score_dict['%s@%d#oneword=%d' % (k, topk, num_oneword_seq)])
    #             csv_lines.append(csv_line + '\n')

    #         result_csv.writelines(csv_lines)

    return score_dict
Example #35
0
    def train(self, n_iters):
        eval_step = 10

        with tf.Session() as sess:

            self.train_iterator_handle = sess.run(
                self.train_iterator.string_handle())
            self.val_iterator_handle = sess.run(
                self.val_iterator.string_handle())
            self.train_eval_iterator_handle = sess.run(
                self.train_eval_iterator.string_handle())
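            # feedable iterator handles: the same input ops read from whichever
            # dataset handle is fed through self.handle at run time.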

            sess.run(tf.global_variables_initializer())
            # writer = tf.summary.FileWriter(
            #    'graphs/attention1', sess.graph)
            initial_step = self.gstep.eval()
            sess.run(self.val_iterator.initializer)
            sess.run(self.train_eval_iterator.initializer)

            variables = tf.trainable_variables()
            num_vars = np.sum(
                [np.prod(v.get_shape().as_list()) for v in variables])

            print("Number of variables in models: {}".format(num_vars))
            for epoch in range(n_iters):
                print("epoch #", epoch)
                num_batches = int(67978.0 / self.batch_size)
                progress = Progbar(target=num_batches)
                sess.run(self.train_iterator.initializer)
                index = 0
                total_loss = 0
                progress.update(index, [("training loss", total_loss)])
                while True:
                    index += 1
                    try:
                        total_loss, opt = sess.run(
                            [self.total_loss, self.opt],
                            feed_dict={
                                self.handle: self.train_iterator_handle,
                                self.keep_prob: 0.75
                            })  # , options=options, run_metadata=run_metadata)
                        progress.update(index, [("training loss", total_loss)])

                    except tf.errors.OutOfRangeError:
                        break
                print('evaluation on 500 training elements:')
                preds, contexts, answers = sess.run(
                    [self.preds, self.contexts, self.answers],
                    feed_dict={
                        self.handle: self.train_eval_iterator_handle,
                        self.keep_prob: 1.0
                    })
                predictions = []
                ground_truths = []
                for i in range(len(preds)):
                    predictions.append(
                        convert_indices_to_text(self.vocabulary, contexts[i],
                                                preds[i, 0], preds[i, 1]))
                    ground_truths.append(
                        convert_indices_to_text(self.vocabulary, contexts[i],
                                                answers[i, 0], answers[i, 1]))
                print(evaluate(predictions, ground_truths))
                print('evaluation on 500 validation elements:')
                preds, contexts, answers = sess.run(
                    [self.preds, self.contexts, self.answers],
                    feed_dict={
                        self.handle: self.val_iterator_handle,
                        self.keep_prob: 1.0
                    })
                predictions = []
                ground_truths = []
                for i in range(len(preds)):
                    predictions.append(
                        convert_indices_to_text(self.vocabulary, contexts[i],
                                                preds[i, 0], preds[i, 1]))
                    ground_truths.append(
                        convert_indices_to_text(self.vocabulary, contexts[i],
                                                answers[i, 0], answers[i, 1]))
                print(evaluate(predictions, ground_truths))
                predictions = []
                ground_truths = []
Example #36
0
def evaluate_beam_search(generator, data_loader, opt, title='', epoch=1, save_path=None):
    logging = config.init_logging(title, save_path + '/%s.log' % title)
    progbar = Progbar(logger=logging, title=title, target=len(data_loader.dataset.examples), batch_size=data_loader.batch_size,
                      total_examples=len(data_loader.dataset.examples))

    example_idx = 0
    score_dict = {}  # {'precision@5':[],'recall@5':[],'f1score@5':[], 'precision@10':[],'recall@10':[],'f1score@10':[]}

    for i, batch in enumerate(data_loader):
        # if i > 3:
        #     break

        one2many_batch, one2one_batch = batch
        src_list, src_len, trg_list, _, trg_copy_target_list, src_oov_map_list, oov_list, src_str_list, trg_str_list = one2many_batch

        if torch.cuda.is_available():
            src_list = src_list.cuda()
            src_oov_map_list = src_oov_map_list.cuda()

        print("batch size - %s" % str(src_list.size(0)))
        # print("src size - %s" % str(src_list.size()))
        # print("target size - %s" % len(trg_copy_target_list))

        pred_seq_list = generator.beam_search(src_list, src_len, src_oov_map_list, oov_list, opt.word2id)

        '''
        process each example in current batch
        '''
        for src, src_str, trg, trg_str_seqs, trg_copy, pred_seq, oov in zip(src_list, src_str_list, trg_list, trg_str_list, trg_copy_target_list, pred_seq_list, oov_list):
            # logging.info('======================  %d =========================' % (example_idx))
            print_out = ''
            print_out += '[Source][%d]: %s \n' % (len(src_str), ' '.join(src_str))
            # src = src.cpu().data.numpy() if torch.cuda.is_available() else src.data.numpy()
            # print_out += '\nSource Input: \n %s\n' % (' '.join([opt.id2word[x] for x in src[:len(src_str) + 5]]))
            # print_out += 'Real Target String [%d] \n\t\t%s \n' % (len(trg_str_seqs), trg_str_seqs)
            # print_out += 'Real Target Input:  \n\t\t%s \n' % str([[opt.id2word[x] for x in t] for t in trg])
            # print_out += 'Real Target Copy:   \n\t\t%s \n' % str([[opt.id2word[x] if x < opt.vocab_size else oov[x - opt.vocab_size] for x in t] for t in trg_copy])
            trg_str_is_present = if_present_duplicate_phrase(src_str, trg_str_seqs)
            print_out += '[GROUND-TRUTH] #(present)/#(all targets)=%d/%d\n' % (sum(trg_str_is_present), len(trg_str_is_present))
            print_out += '\n'.join(['\t\t[%s]' % ' '.join(phrase) if is_present else '\t\t%s' % ' '.join(phrase) for phrase, is_present in zip(trg_str_seqs, trg_str_is_present)])
            print_out += '\noov_list:   \n\t\t%s \n' % str(oov)

            # 1st filtering
            pred_is_valid, processed_pred_seqs, processed_pred_str_seqs, processed_pred_score = process_predseqs(pred_seq, oov, opt.id2word, opt)
            # 2nd filtering: if enabled, filter out phrases that do not appear in the source text, and keep unique ones after stemming
            if opt.must_appear_in_src:
                pred_is_present = if_present_duplicate_phrase(src_str, processed_pred_str_seqs)
                trg_str_seqs = np.asarray(trg_str_seqs)[trg_str_is_present]
            else:
                pred_is_present = [True] * len(processed_pred_str_seqs)

            valid_and_present = np.asarray(pred_is_valid) * np.asarray(pred_is_present)
            match_list = get_match_result(true_seqs=trg_str_seqs, pred_seqs=processed_pred_str_seqs)
            print_out += '[PREDICTION] #(valid)=%d, #(present)=%d, #(retained&present)=%d, #(all)=%d\n' % (sum(pred_is_valid), sum(pred_is_present), sum(valid_and_present), len(pred_seq))
            print_out += ''
            '''
            Print and export predictions
            '''
            preds_out = ''

            for p_id, (seq, word, score, match, is_valid, is_present) in enumerate(
                    zip(processed_pred_seqs, processed_pred_str_seqs, processed_pred_score, match_list, pred_is_valid, pred_is_present)):
                # if p_id > 5:
                #     break

                preds_out += '%s\n' % (' '.join(word))
                if is_present:
                    print_phrase = '[%s]' % ' '.join(word)
                else:
                    print_phrase = ' '.join(word)

                if is_valid:
                    print_phrase = '*%s' % print_phrase

                if match == 1.0:
                    correct_str = '[correct!]'
                else:
                    correct_str = ''
                if any([t >= opt.vocab_size for t in seq.sentence]):
                    copy_str = '[copied!]'
                else:
                    copy_str = ''

                # print_out += '\t\t[%.4f]\t%s \t %s %s%s\n' % (-score, print_phrase, str(seq.sentence), correct_str, copy_str)

            '''
            Evaluate predictions w.r.t different filterings and metrics
            '''
            num_oneword_range = [-1, 1]
            topk_range = [5, 10]
            score_names = ['precision', 'recall', 'f_score']

            processed_pred_seqs = np.asarray(processed_pred_seqs)[valid_and_present]
            processed_pred_str_seqs = np.asarray(processed_pred_str_seqs)[valid_and_present]
            processed_pred_score = np.asarray(processed_pred_score)[valid_and_present]

            for num_oneword_seq in num_oneword_range:
                # 3rd round filtering (one-word phrases)
                filtered_pred_seq, filtered_pred_str_seqs, filtered_pred_score = post_process_predseqs((processed_pred_seqs, processed_pred_str_seqs, processed_pred_score), num_oneword_seq)

                match_list = get_match_result(true_seqs=trg_str_seqs, pred_seqs=filtered_pred_str_seqs)

                assert len(filtered_pred_seq) == len(filtered_pred_str_seqs) == len(filtered_pred_score) == len(match_list)

                for topk in topk_range:
                    results = evaluate(match_list, filtered_pred_seq, trg_str_seqs, topk=topk)
                    for k, v in zip(score_names, results):
                        if '%s@%d#oneword=%d' % (k, topk, num_oneword_seq) not in score_dict:
                            score_dict['%s@%d#oneword=%d' % (k, topk, num_oneword_seq)] = []
                        score_dict['%s@%d#oneword=%d' % (k, topk, num_oneword_seq)].append(v)

                        print_out += '\t%s@%d#oneword=%d = %f\n' % (k, topk, num_oneword_seq, v)

            # logging.info(print_out)

            if save_path:
                if not os.path.exists(os.path.join(save_path, title + '_detail')):
                    os.makedirs(os.path.join(save_path, title + '_detail'))
                with open(os.path.join(save_path, title + '_detail', str(example_idx) + '_print.txt'), 'w') as f_:
                    f_.write(print_out)
                with open(os.path.join(save_path, title + '_detail', str(example_idx) + '_prediction.txt'), 'w') as f_:
                    f_.write(preds_out)

            progbar.update(epoch, example_idx, [('f_score@5#oneword=-1', np.average(score_dict['f_score@5#oneword=-1'])), ('f_score@10#oneword=-1', np.average(score_dict['f_score@10#oneword=-1']))])

            example_idx += 1

    print('#(f_score@5#oneword=-1)=%d, sum=%f' % (len(score_dict['f_score@5#oneword=-1']), sum(score_dict['f_score@5#oneword=-1'])))
    print('#(f_score@10#oneword=-1)=%d, sum=%f' % (len(score_dict['f_score@10#oneword=-1']), sum(score_dict['f_score@10#oneword=-1'])))
    print('#(f_score@5#oneword=1)=%d, sum=%f' % (len(score_dict['f_score@5#oneword=1']), sum(score_dict['f_score@5#oneword=1'])))
    print('#(f_score@10#oneword=1)=%d, sum=%f' % (len(score_dict['f_score@10#oneword=1']), sum(score_dict['f_score@10#oneword=1'])))

    if save_path:
        # export scores: each row holds the precision, recall and f-score for one way of filtering predictions (i.e. how many one-word predictions are kept)
        with open(save_path + os.path.sep + title + '_result.csv', 'w') as result_csv:
            csv_lines = []
            for num_oneword_seq in num_oneword_range:
                for topk in topk_range:
                    csv_line = '#oneword=%d,@%d' % (num_oneword_seq, topk)
                    for k in score_names:
                        csv_line += ',%f' % np.average(score_dict['%s@%d#oneword=%d' % (k, topk, num_oneword_seq)])
                    csv_lines.append(csv_line + '\n')

            result_csv.writelines(csv_lines)

    # precision, recall, f_score = macro_averaged_score(precisionlist=score_dict['precision'], recalllist=score_dict['recall'])
    # logging.info("Macro@5\n\t\tprecision %.4f\n\t\tmacro recall %.4f\n\t\tmacro fscore %.4f " % (np.average(score_dict['precision@5']), np.average(score_dict['recall@5']), np.average(score_dict['f1score@5'])))
    # logging.info("Macro@10\n\t\tprecision %.4f\n\t\tmacro recall %.4f\n\t\tmacro fscore %.4f " % (np.average(score_dict['precision@10']), np.average(score_dict['recall@10']), np.average(score_dict['f1score@10'])))
    # precision, recall, f_score = evaluate(true_seqs=target_all, pred_seqs=prediction_all, topn=5)
    # logging.info("micro precision %.4f , micro recall %.4f, micro fscore %.4f " % (precision, recall, f_score))

    return score_dict