if k in vars(opt) and getattr(opt, k) is not None:
            assert vars(opt)[k] == vars(
                infos['opt'])[k], k + ' option not consistent'
        else:
            vars(opt).update({k: vars(infos['opt'])[k]
                              })  # copy over options from model

# Vocabulary stored in the loaded `infos` dict.
vocab = infos['vocab']  # ix -> word mapping

# This script only supports exactly five captions per image.
# NOTE(review): `assert` is stripped under `python -O`.
assert opt.seq_per_img == 5

# Fall back to a weight of 1 when either loss-weight option is missing
# from `opt`.
opt.vse_loss_weight = vars(opt).get('vse_loss_weight', 1)
opt.caption_loss_weight = vars(opt).get('caption_loss_weight', 1)

# Setup the model
model = models.JointModel(opt)
utils.load_state_dict(model, torch.load(opt.model))
if opt.initialize_retrieval is not None:
    # Overlay a second checkpoint on top of the first: only entries whose
    # key contains 'vse' are taken from `opt.initialize_retrieval`
    # (presumably the retrieval sub-module -- confirm against JointModel).
    # The warning text is printed 100 times in a single print call.
    print("Make sure the vse opt are the same !!!!!\n" * 100)
    utils.load_state_dict(
        model, {
            k: v
            for k, v in torch.load(opt.initialize_retrieval).items()
            if 'vse' in k
        })
# Evaluation runs on the GPU with the model in eval mode.
model.cuda()
model.eval()

# Create the Data Loader instance
# Only the "no image folder given" case is visible here; `loader` is left
# unbound by these lines when `opt.image_folder` is non-empty.
if len(opt.image_folder) == 0:
    loader = DataLoader(opt)
def eval_split(model, crit, loader, eval_kwargs=None):
    """Run ``model`` over one split of ``loader``; return loss and captions.

    Args:
        model: captioning model. Called as ``model(fc_feats, att_feats,
            labels, att_masks)`` to obtain the output fed to ``crit`` and
            with ``mode='sample'`` to decode captions.
        crit: criterion called as ``crit(output, labels[:, 1:],
            masks[:, 1:])``; its result must support ``.item()``.
        loader: data loader providing ``reset_iterator``, ``get_batch``,
            ``get_vocab``, ``batch_size``, ``seq_per_img`` and
            ``ix_to_word``.
        eval_kwargs: optional dict of evaluation options. Keys read here:
            ``verbose``, ``verbose_loss``, ``num_images`` (or
            ``val_images_use``), ``split``, ``language_eval``, ``rank``,
            ``vsepp``, ``dump_path``, ``dump_images``, ``image_root``,
            ``id`` and ``data``. The dict is also forwarded unchanged to
            the model as ``opt`` when sampling.

    Returns:
        A tuple ``(mean_loss, predictions, lang_stats)``. ``mean_loss`` is
        the summed batch loss divided by the number of evaluated batches
        (plus 1e-8, so it is 0 when no loss was computed). ``predictions``
        is a list of dicts with ``image_id`` and ``caption`` (plus
        ``file_name`` when ``dump_path`` is 1). ``lang_stats`` is the result
        of ``language_eval`` or ``None`` when ``language_eval`` is not 1.

    The model is put in eval mode for the duration of the call and switched
    back to train mode before returning.
    """
    # Local import: only the optional ``dump_images`` branch needs it.
    import shutil

    # Avoid a shared mutable default argument.
    if eval_kwargs is None:
        eval_kwargs = {}

    verbose = eval_kwargs.get('verbose', True)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images',
                                 eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    use_rank = eval_kwargs.get('rank', 0)

    if use_rank:
        # Hard-coded location of the retrieval model used for ranking.
        infos_path = 'log_fc_con/infos_vse_fc_con-best.pkl'  # 'log_fc_con_discsplit/infos_vse_fc_con_discsplit-best.pkl'
        model_path = 'log_fc_con/model_vse-best.pth'  # 'log_fc_con_discsplit/model_vse-best.pth'
        # Pickle data must be read in binary mode (required on Python 3).
        # NOTE: unpickling executes arbitrary code; only load trusted files.
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)

        rank_model = models.JointModel(infos['opt'])
        utils.load_state_dict(rank_model, torch.load(model_path))
        rank_model.cuda()
        rank_model.eval()
        print('success loaded retrieval model !')

    # Make sure in the evaluation mode
    model.eval()

    loader.reset_iterator(split)

    n = 0
    loss = 0
    loss_sum = 0
    loss_evals = 1e-8  # avoids division by zero when no loss is computed
    predictions = []
    seqs = []
    while True:
        data = loader.get_batch(split)
        n = n + loader.batch_size
        sys.stdout.flush()
        if data.get('labels', None) is not None and verbose_loss:
            # forward the model to get loss
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [
                torch.from_numpy(arr).cuda() if arr is not None else arr
                for arr in tmp
            ]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            with torch.no_grad():
                loss = crit(model(fc_feats, att_feats, labels, att_masks),
                            labels[:, 1:], masks[:, 1:]).item()
            loss_sum = loss_sum + loss
            loss_evals = loss_evals + 1

        # forward the model to also get generated samples for each image
        # Only leave one feature for each image, in case duplicate sample
        first_of_each = np.arange(loader.batch_size) * loader.seq_per_img
        tmp = [
            data['fc_feats'][first_of_each],
            data['att_feats'][first_of_each],
            data['att_masks'][first_of_each]
            if data['att_masks'] is not None else None
        ]
        tmp = [
            torch.from_numpy(arr).cuda() if arr is not None else arr
            for arr in tmp
        ]
        fc_feats, att_feats, att_masks = tmp
        with torch.no_grad():
            seq = model(fc_feats,
                        att_feats,
                        att_masks,
                        opt=eval_kwargs,
                        mode='sample')[0].data

        sents = utils.decode_sequence(loader.get_vocab(), seq)

        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                src = os.path.join(eval_kwargs['image_root'],
                                   data['infos'][k]['file_path'])
                dst = 'vis/imgs/img' + str(len(predictions)) + '.jpg'
                print('cp "' + src + '" ' + dst)
                # Copy without going through a shell so that file names
                # containing quotes or shell metacharacters are safe. A
                # failed copy stays non-fatal, as it was with `cp`.
                try:
                    shutil.copy(src, dst)
                except (IOError, OSError) as err:
                    print('could not copy %s: %s' % (src, err))

            if verbose:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        # if we wrapped around the split or used up val imgs budget then bail
        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if use_rank:
            seqs.append(padding(seq, 30))
        if num_images != -1:
            ix1 = min(ix1, num_images)
        # Drop predictions that overshoot the image budget / split size.
        for _ in range(n - ix1):
            predictions.pop()
        if n > ix1:
            # NOTE(review): this truncated `seq` is not used afterwards, and
            # `seqs` above already received the untruncated batch -- confirm
            # whether the ranking input is meant to be trimmed as well.
            seq = seq[:(ix1 - n) * loader.seq_per_img]

        if verbose:
            print('evaluating validation preformance... %d/%d (%f)' %
                  (ix0 - 1, ix1, loss))

        if data['bounds']['wrapped']:
            break
        if num_images >= 0 and n >= num_images:
            break

    if use_rank:
        seqs = torch.cat(seqs, 0).contiguous()
        seqs = change_seq(seqs, loader.ix_to_word)

    if eval_kwargs.get('vsepp', 0):
        from eval_vsepp import evalrank_vsepp
        from eval_utils_pair import get_transform
        import torchvision.transforms as transforms
        from PIL import Image

        # NOTE(review): `seqs` is only a tensor when `rank` is enabled too;
        # otherwise it is still the empty list created above -- confirm that
        # `vsepp` is never requested without `rank`.
        imgids = [p['image_id'] for p in predictions]
        seqs = seqs[:num_images]

        transform = get_transform('COCO', 'val', None)
        imgs = []
        for i, imgid in enumerate(imgids):
            img_path = '../imgcap/data/raw_images/val2014/COCO_val2014_' + str(
                imgid).zfill(12) + '.jpg'
            if i % 100 == 0:
                print('load %d images' % i)
            image = Image.open(img_path).convert('RGB')
            image = transform(image)
            imgs.append(image.unsqueeze(0))
        imgs = torch.cat(imgs, 0).contiguous()
        # Count of positive token ids per row, plus one.
        lengths = torch.sum((seqs > 0), 1) + 1
        lengths = lengths.cpu()
        with torch.no_grad():
            evalrank_vsepp(imgs, loader.ix_to_word, seqs, lengths)

    lang_stats = None
    if lang_eval == 1:
        lang_stats = language_eval(eval_kwargs.get('data', 'coco'),
                                   predictions, eval_kwargs['id'], split)

    if use_rank:
        # The return value is not used by this function.
        evalrank(rank_model, loader, seqs, eval_kwargs)

    # Switch back to training mode
    model.train()
    return loss_sum / loss_evals, predictions, lang_stats
# Example #3
# 0
def train(opt):
    """Train a ``models.JointModel`` with periodic validation and checkpoints.

    Args:
        opt: option namespace. It is read for all hyper-parameters and paths
            and is also mutated here (``use_att``, ``vocab_size``,
            ``seq_length``, ``current_lr``, ``ss_prob``).

    When ``opt.start_from`` is set, ``infos_<id>.pkl`` (and, if present,
    ``histories_<id>.pkl`` and ``optimizer.pth``) are loaded from that
    directory to resume training. Every ``opt.save_checkpoint_every``
    iterations the model is evaluated on the ``val`` split and the model,
    optimizer, infos and histories are written to ``opt.checkpoint_path``;
    separate ``-best`` copies are written when the caption score or the VSE
    score improves. The loop ends once ``opt.max_epochs`` is reached
    (``-1`` disables the limit).
    """
    opt.use_att = utils.if_use_att(opt)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Falsy `tf` short-circuits, so no writer is created in that case.
    tf_summary_writer = tf and SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        # Pickle files must be opened in binary mode (required on Python 3).
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same = [
            "caption_model", "rnn_type", "rnn_size", "num_layers"
        ]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(
                opt
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind the best scores: they are compared against at the first
    # checkpoint even when `load_best_score` is disabled.
    best_val_score = None
    best_val_score_vse = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_vse = infos.get('best_val_score_vse', None)

    model = models.JointModel(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, 'optimizer.pth')):
        state_dict = torch.load(os.path.join(opt.start_from, 'optimizer.pth'))
        # Only restore when the saved state covers the same number of
        # entries as the freshly built optimizer.
        if len(state_dict['state']) == len(optimizer.state_dict()['state']):
            optimizer.load_state_dict(state_dict)
        else:
            print(
                'Optimizer param group number not matched? There must be new parameters. Reinit the optimizer.'
            )

    init_scorer(opt.cached_tokens)
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.caption_generator.ss_prob = opt.ss_prob
            # Assign retrieval loss weight
            if epoch > opt.retrieval_reward_weight_decay_start and opt.retrieval_reward_weight_decay_start >= 0:
                frac = (epoch - opt.retrieval_reward_weight_decay_start
                        ) // opt.retrieval_reward_weight_decay_every
                model.retrieval_reward_weight = opt.retrieval_reward_weight * (
                    opt.retrieval_reward_weight_decay_rate**frac)
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['att_masks'],
            data['labels'], data['masks']
        ]
        tmp = utils.var_wrapper(tmp)
        fc_feats, att_feats, att_masks, labels, masks = tmp

        optimizer.zero_grad()

        loss = model(fc_feats, att_feats, att_masks, labels, masks, data)
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # NOTE(review): `loss.data[0]` presumably targets an old PyTorch
        # release; newer releases expect `loss.item()` -- confirm the
        # supported version before changing.
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
              .format(iteration, epoch, train_loss, end - start))
        prt_str = ""
        for k, v in model.loss().items():
            prt_str += "{} = {:.3f} ".format(k, v)
        print(prt_str)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                tf_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                for k, v in model.loss().items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tf_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.caption_generator.ss_prob,
                                             iteration)
                tf_summary_writer.add_scalar('retrieval_reward_weight',
                                             model.retrieval_reward_weight,
                                             iteration)
                tf_summary_writer.file_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.caption_generator.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            # `val_loss` is used as a dict of named losses/metrics below.
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                for k, v in val_loss.items():
                    tf_summary_writer.add_scalar('validation ' + k, v,
                                                 iteration)
                # `lang_stats` may be None (e.g. language eval disabled).
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_text(
                    'Captions',
                    '.\n\n'.join([p['caption'] for p in predictions[:100]]),
                    iteration)
                tf_summary_writer.file_writer.flush()

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['SPICE'] * 100
            else:
                current_score = -val_loss['loss_cap']
            current_score_vse = val_loss.get(opt.vse_eval_criterion, 0) * 100

            best_flag = False
            best_flag_vse = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            if best_val_score_vse is None or current_score_vse > best_val_score_vse:
                best_val_score_vse = current_score_vse
                best_flag_vse = True

            # Latest weights, plus a per-iteration snapshot.
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model-%d.pth' % (iteration))
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path,
                                          'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['best_val_score_vse'] = best_val_score_vse
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(
                        opt.checkpoint_path,
                        'infos_' + opt.id + '-%d.pkl' % (iteration)),
                    'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'),
                    'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)
            if best_flag_vse:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model_vse-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_vse_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
# Example #4
# 0
def train(opt):
    """Train a captioning model (plain or joint) with periodic checkpoints.

    Args:
        opt: option namespace. It is read for all hyper-parameters and paths
            and is also mutated here (``use_fc``, ``use_att``,
            ``att_feat_size``, ``vocab_size``, ``seq_length``,
            ``xpersonality``, ``model``, ``current_lr``, ``ss_prob``).

    ``opt.use_joint`` selects between ``models.setup(opt)`` (0) and
    ``models.JointModel(opt)`` (1). When ``opt.start_from`` is set, infos,
    histories, model weights and (if present) optimizer state are loaded
    from that directory. Every ``opt.save_checkpoint_every`` iterations the
    model is evaluated on the ``val`` split and a checkpoint is written;
    ``best`` / ``vse-best`` copies are written when the respective score
    improves. On ``RuntimeError`` or ``KeyboardInterrupt`` a checkpoint is
    saved, the traceback is printed, and the function returns normally.
    """
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Falsy `tb` short-circuits, so no writer is created in that case.
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)
    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
        infos['pix_perss'] = loader.get_personality()
    infos['opt'] = opt
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    print("current epoch: ", epoch)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind both best scores: they are compared against at the first
    # checkpoint even when `load_best_score` is disabled.
    best_val_score = None
    best_val_score_vse = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_vse = infos.get('best_val_score_vse', None)

    # `opt.vocab` is attached only while the model is being built and is
    # deleted again right afterwards.
    opt.vocab = loader.get_vocab()
    opt.xpersonality = loader.get_personality()
    if opt.use_joint == 0:
        model = models.setup(opt).cuda()
    elif opt.use_joint == 1:
        model = models.JointModel(opt)
        model.cuda()
    del opt.vocab
    if opt.start_from is not None:
        opt.model = os.path.join(opt.start_from, 'model' + '.pth')
        model.load_state_dict(torch.load(opt.model))
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)
    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()
    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer([p for p in model.parameters() if p.requires_grad], opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer([p for p in model.parameters() if p.requires_grad], opt)
    # Load the optimizer state when resuming and a saved state exists;
    # otherwise the freshly built optimizer is used as is.
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        """Write model, optimizer, infos (and histories, if given) to disk.

        ``append`` is added to the file names as a ``-<append>`` suffix.
        """
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
                utils.pickle_dump(histories, f)

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False
                # Assign retrieval loss weight
                if epoch > opt.retrieval_reward_weight_decay_start and opt.retrieval_reward_weight_decay_start >= 0:
                    frac = (epoch - opt.retrieval_reward_weight_decay_start) // opt.retrieval_reward_weight_decay_every
                    model.retrieval_reward_weight = opt.retrieval_reward_weight * (opt.retrieval_reward_weight_decay_rate ** frac)
                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()
            # NOTE(review): anomaly detection is enabled for every step,
            # which is typically a debugging aid -- confirm it is intended
            # to stay on.
            with torch.autograd.set_detect_anomaly(True):
                tmp = [data['fc_feats'], data['att_feats'], data['densecap'], data['labels'], data['masks'], data['att_masks'], data['personality']]
                tmp = [_ if _ is None else _.cuda() for _ in tmp]
                fc_feats, att_feats, densecap, labels, masks, att_masks, personality = tmp
                optimizer.zero_grad()
                model_out = dp_lw_model(fc_feats, att_feats, densecap, labels, masks, att_masks, personality, data['gts'], torch.arange(0, len(data['gts'])), sc_flag)

                loss = model_out['loss'].mean()

                loss.backward()
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
                train_loss = loss.item()
                torch.cuda.synchronize()
                end = time.time()
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f},train_loss = {:.3f}"
                      .format(iteration, epoch, model_out['reward'].mean(), end - start, train_loss))

            if opt.use_joint == 1:
                # Start from an empty string before appending each loss term.
                prt_str = ""
                for k, v in model.loss().items():
                    prt_str += "{} = {:.3f} ".format(k, v)
                print(prt_str)

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                               'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    # `lang_stats` may be None (it is checked for None
                    # below as well), so guard the membership test.
                    if lang_stats is not None and 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    if opt.use_joint == 1:
                        current_score = lang_stats['SPICE'] * 100
                    elif opt.use_joint == 0:
                        current_score = lang_stats['CIDEr']  # could use SPICE
                else:
                    if opt.use_joint == 0:
                        current_score = - val_loss
                    elif opt.use_joint == 1:
                        current_score = - val_loss['loss_cap']
                if opt.use_joint == 1:
                    current_score_vse = val_loss.get(opt.vse_eval_criterion, 0) * 100

                best_flag = False
                best_flag_vse = False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if opt.use_joint == 1:
                    if best_val_score_vse is None or current_score_vse > best_val_score_vse:
                        best_val_score_vse = current_score_vse
                        best_flag_vse = True
                    infos['best_val_score_vse'] = best_val_score_vse
                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')
                if best_flag_vse:
                    save_checkpoint(model, infos, optimizer, append='vse-best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        # Best-effort rescue: persist the current state, then report the
        # traceback instead of re-raising.
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)