Example #1
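A complete training loop for an image-captioning model: it restores state from a previous run, decays the learning rate and the scheduled-sampling probability each epoch, optionally switches to self-critical training, and periodically evaluates on the validation split, checkpointing the best model.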
def train(opt):
    # Decide which features (fc / att) to use before anything else
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.vocab = loader.get_vocab()

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.start_from.split('/')[-1]+'.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.start_from.split('/')[-1]+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.start_from.split('/')[-1]+'.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)
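    # LossWrapper computes the loss inside the wrapped module, so under
    # DataParallel the loss is evaluated on each GPU and only scalars are gathered.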

    epoch_done = True
    # Ensure the model is in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))


    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # create checkpoint_path if it doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
                utils.pickle_dump(histories, f)
    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate  ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If starting self-critical training
                sc_flag = opt.self_critical_after != -1 and epoch >= opt.self_critical_after
                init_scorer(opt.cached_tokens)  # this variant initializes the scorer either way

                epoch_done = False
                    
            start = time.time()
            # Load data from the train split
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'], data['sents_mask']]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks, sents_mask = tmp
            box_inds = None
                
            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, box_inds, epoch, sents_mask)

            loss = model_out['loss'].mean()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean().item()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            
            # make evaluation on validation set, and save model

            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                                'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Example #2
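A minimal tutorial snippet: build the model and its loss wrapper, fetch one training batch, and run a single forward pass.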
import torch
import torch.nn as nn

import tutor_opts
from tutor_loader import DataLoader

import models
from misc.loss_wrapper import LossWrapper



opt = tutor_opts.parse_opt()
loader = DataLoader(opt)
opt.vocab_size = loader.vocab_size
opt.seq_length = loader.seq_length
opt.vocab = loader.get_vocab()

model = models.setup(opt).cuda()
del opt.vocab  # drop the vocab from opt once the model is built (presumably to keep opt small when pickled)
dp_model = torch.nn.DataParallel(model)
lw_model = LossWrapper(model, opt)
dp_lw_model = torch.nn.DataParallel(lw_model)
dp_lw_model.train()

data = loader.get_batch('train')
tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
tmp = [_ if _ is None else _.cuda() for _ in tmp]
fc_feats, att_feats, labels, masks, att_masks = tmp

model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag=False)
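
# The wrapper returns a dict of outputs; as in the full training loops above,
# the scalar training loss can be recovered from it:
loss = model_out['loss'].mean()
print('loss:', loss.item())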
Example #3
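A cleaner variant of the loop in Example #1: loader state goes through state_dict()/load_state_dict(), histories live in a defaultdict, checkpointing is delegated to utils.save_checkpoint, and an additional structure-loss phase is supported.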
def train(opt):

    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if any) and check that the models are compatible
    if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')):
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
        with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'), 'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    # Load pretrained weights:
    if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, 'model.pth')):
        model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth')))
    
    # Wrap the generation model with the loss function (used for training).
    # This allows the loss to be computed separately on each GPU.
    lw_model = LossWrapper(model, opt)
    # Wrap with dataparallel
    dp_model = torch.nn.DataParallel(model)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For backward compatibility
    if 'iterators' in infos:
        infos['loader_state_dict'] = {split: {'index_list': infos['split_ix'][split], 'iter_counter': infos['iterators'][split]} for split in ['train', 'val', 'test']}
    loader.load_state_dict(infos['loader_state_dict'])
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None
    if opt.noamopt:
        optimizer._step = iteration
    # Flag indicating the end of an epoch.
    # It is always set to True at the beginning so the learning rate etc. get initialized.
    epoch_done = True
    # Ensure the model is in training mode
    dp_lw_model.train()

    # Start training
    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate  ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False
                
                # If start structure loss training
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                epoch_done = False
                    
            start = time.time()
            # Load data from the train split
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp
            
            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag)

            loss = model_out['loss'].mean()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if struc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
            elif not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                tb_summary_writer.add_scalar('train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr, iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    tb_summary_writer.add_scalar('avg_reward', model_out['reward'].mean(), iteration)
                elif struc_flag:
                    tb_summary_writer.add_scalar('lm_loss', model_out['lm_loss'].mean().item(), iteration)
                    tb_summary_writer.add_scalar('struc_loss', model_out['struc_loss'].mean().item(), iteration)
                    tb_summary_writer.add_scalar('reward', model_out['reward'].mean().item(), iteration)

                histories['loss_history'][iteration] = train_loss if not sc_flag else model_out['reward'].mean().item()
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()
            
            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                                'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)
                histories['val_result_history'][iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(opt, model, infos, optimizer, append=str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt, model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Example #4
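A semi-supervised variant that trains several captioning models alongside EMA (mean-teacher) copies: an ensemble of the EMA models decodes pseudo-captions for unpaired images (knowledge distillation), and pseudo visual features are optimized by gradient descent so unpaired captions can also supervise the models.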
def train(opt):

    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if any) and check that the models are compatible
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme
                ), "Command line argument and saved model disagree on '%s' " % checkme
    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'),
                  'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    multi_models_list = []
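    # The first opt.number_of_models entries hold the trained (student) models.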
    for order in range(opt.number_of_models):
        multi_models_list.append(models.setup(opt).cuda())
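    # The next opt.number_of_models entries are EMA (teacher) copies of the same
    # architecture; their parameters are detached below and updated only by the
    # exponential moving average after each optimizer step.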
    for order in range(opt.number_of_models):
        multi_models_list.append(models.setup(opt).cuda())
    for order in range(opt.number_of_models, 2 * opt.number_of_models):
        for param in multi_models_list[order].parameters():
            param.detach_()
    for order in range(opt.number_of_models):
        for param, param_ema in zip(
                multi_models_list[order].parameters(),
                multi_models_list[order + opt.number_of_models].parameters()):
            param_ema.data = param.data.clone()
    multi_models = nn.ModuleList(multi_models_list)
    del opt.vocab
    # Load pretrained weights:
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'model.pth')):
        multi_models.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Wrap the generation models with the loss function (used for training).
    # This allows the loss to be computed separately on each GPU.
    lw_models = nn.ModuleList([
        LossWrapper(multi_models[index], opt)
        for index in range(opt.number_of_models)
    ])
    kdlw_models = nn.ModuleList([
        KDLossWrapper(multi_models[index], opt)
        for index in range(opt.number_of_models)
    ])
    lw_models_ema = nn.ModuleList([
        LossWrapper(multi_models[opt.number_of_models + index], opt)
        for index in range(opt.number_of_models)
    ])
    kdlw_models_ema = nn.ModuleList([
        KDLossWrapper(multi_models[opt.number_of_models + index], opt)
        for index in range(opt.number_of_models)
    ])
    # Wrap with dataparallel
    dp_models = nn.ModuleList([
        torch.nn.DataParallel(multi_models[index])
        for index in range(opt.number_of_models)
    ])
    dp_lw_models = nn.ModuleList([
        torch.nn.DataParallel(lw_models[index])
        for index in range(opt.number_of_models)
    ])
    dp_kdlw_models = nn.ModuleList([
        torch.nn.DataParallel(kdlw_models[index])
        for index in range(opt.number_of_models)
    ])
    dp_models_ema = nn.ModuleList([
        torch.nn.DataParallel(multi_models[opt.number_of_models + index])
        for index in range(opt.number_of_models)
    ])
    dp_lw_models_ema = nn.ModuleList([
        torch.nn.DataParallel(lw_models_ema[index])
        for index in range(opt.number_of_models)
    ])
    dp_kdlw_models_ema = nn.ModuleList([
        torch.nn.DataParallel(kdlw_models_ema[index])
        for index in range(opt.number_of_models)
    ])

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'bert', 'm2transformer'
        ], 'noamopt only works with transformer-based models'
        optimizer = utils.get_std_opt(multi_models,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(multi_models.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(multi_models.parameters(), opt)
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))


    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For backward compatibility
    if 'iterators' in infos:
        infos['loader_state_dict'] = {
            split: {
                'index_list': infos['split_ix'][split],
                'iter_counter': infos['iterators'][split]
            }
            for split in [
                'paired_train', 'unpaired_images_train',
                'unpaired_captions_train', 'train', 'val', 'test'
            ]
        }
    loader.load_state_dict(infos['loader_state_dict'])
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None
    if opt.noamopt:
        optimizer._step = iteration
    # Flag indicating the end of an epoch.
    # It is always set to True at the beginning so the learning rate etc. get initialized.
    epoch_done = True
    # Ensure the models are in training mode
    dp_lw_models.train()
    dp_kdlw_models.train()
    dp_lw_models_ema.train()
    dp_kdlw_models_ema.train()

    # Build the ensemble of the EMA (teacher) models; it is used below to decode
    # pseudo-caption beams for the unpaired images.
    model_ensemble = AttEnsemble(multi_models_list[opt.number_of_models:2 *
                                                   opt.number_of_models],
                                 weights=None)
    model_ensemble.cuda()

    # Start training
    try:
        while True:
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    for index in range(opt.number_of_models):
                        multi_models[index].ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                # If start structure loss training
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                if epoch >= opt.paired_train_epoch:
                    opt.current_lambda_x = opt.hyper_parameter_lambda_x * \
                                         (epoch - (opt.paired_train_epoch - 1)) /\
                                         (opt.max_epochs - opt.paired_train_epoch)
                    opt.current_lambda_y = opt.hyper_parameter_lambda_y * \
                                           (epoch - (opt.paired_train_epoch - 1)) / \
                                           (opt.max_epochs - opt.paired_train_epoch)

                epoch_done = False

            start = time.time()
            # Load data from the train split
            if epoch < opt.language_pretrain_epoch:
                data = loader.get_batch('unpaired_captions_train')
            elif epoch < opt.paired_train_epoch:
                data = loader.get_batch('paired_train')
            else:
                data = loader.get_batch('paired_train')
                unpaired_data = loader.get_batch('unpaired_images_train')
                unpaired_caption = loader.get_batch('unpaired_captions_train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()
            if epoch < opt.language_pretrain_epoch:
                tmp = [
                    data['fc_feats'] * 0, data['att_feats'] * 0,
                    data['labels'], data['masks'], data['att_masks']
                ]
            elif epoch < opt.paired_train_epoch:
                tmp = [
                    data['fc_feats'], data['att_feats'], data['labels'],
                    data['masks'], data['att_masks']
                ]
            else:
                tmp = [
                    data['fc_feats'], data['att_feats'], data['labels'],
                    data['masks'], data['att_masks']
                ]
                unpaired_tmp = [
                    unpaired_data['fc_feats'], unpaired_data['att_feats'],
                    unpaired_data['labels'], unpaired_data['masks'],
                    unpaired_data['att_masks']
                ]
                unpaired_caption_tmp = [
                    unpaired_caption['fc_feats'] * 0,
                    unpaired_caption['att_feats'] * 0,
                    unpaired_caption['labels'], unpaired_caption['masks'],
                    unpaired_caption['att_masks']
                ]

            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            if epoch >= opt.paired_train_epoch:
                unpaired_tmp = [
                    _ if _ is None else _.cuda() for _ in unpaired_tmp
                ]
                unpaired_fc_feats, unpaired_att_feats, unpaired_labels, unpaired_masks, unpaired_att_masks = unpaired_tmp
                unpaired_caption_tmp = [
                    _ if _ is None else _.cuda() for _ in unpaired_caption_tmp
                ]
                unpaired_caption_fc_feats, unpaired_caption_att_feats, unpaired_caption_labels, unpaired_caption_masks, unpaired_caption_att_masks = unpaired_caption_tmp
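                # Pseudo visual features for the unpaired captions: the fc/att
                # features are repeated (the factor 5 presumably matches captions
                # per image) and the fc features replaced with scaled Gaussian
                # noise, to be optimized by gradient descent further below.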
                unpaired_caption_fc_feats = unpaired_caption_fc_feats.repeat(
                    5, 1)
                unpaired_caption_fc_feats = opt.std_pseudo_visual_feature * torch.randn_like(
                    unpaired_caption_fc_feats)
                unpaired_caption_att_feats = unpaired_caption_att_feats.repeat(
                    5, 1, 1)
                unpaired_caption_fc_feats.requires_grad = True
                unpaired_caption_att_feats.requires_grad = True
                unpaired_caption_labels = unpaired_caption_labels.reshape(
                    unpaired_caption_fc_feats.shape[0], -1)
                unpaired_caption_masks = unpaired_caption_masks.reshape(
                    unpaired_caption_fc_feats.shape[0], -1)

            optimizer.zero_grad()
            if epoch < opt.language_pretrain_epoch:
                language_loss = 0
                model_outs_list = []
                for index in range(opt.number_of_models):
                    model_out = dp_lw_models[index](
                        fc_feats, att_feats, labels, masks,
                        att_masks, data['gts'],
                        torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
                    model_outs_list.append(model_out)
                    language_loss += model_out['loss'].mean()

                loss = language_loss
            elif epoch < opt.paired_train_epoch:
                language_loss = 0
                model_outs_list = []
                for index in range(opt.number_of_models):
                    model_out = dp_lw_models[index](
                        fc_feats, att_feats, labels, masks,
                        att_masks, data['gts'],
                        torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
                    model_outs_list.append(model_out)
                    language_loss += model_out['loss'].mean()

                loss = language_loss
            else:
                language_loss = 0
                model_outs_list = []
                for index in range(opt.number_of_models):
                    model_out = dp_lw_models[index](
                        fc_feats, att_feats, labels, masks,
                        att_masks, data['gts'],
                        torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
                    model_outs_list.append(model_out)
                    language_loss += model_out['loss'].mean()
                loss = language_loss

                model_ensemble.eval()
                eval_kwargs = dict()
                eval_kwargs.update(vars(opt))

                with torch.no_grad():
                    seq, seq_logprobs = model_ensemble(unpaired_fc_feats,
                                                       unpaired_att_feats,
                                                       unpaired_att_masks,
                                                       opt=eval_kwargs,
                                                       mode='sample')

                model_ensemble.train()

                model_ensemble_sudo_labels = labels.new_zeros(
                    (opt.batch_size, opt.beam_size,
                     eval_kwargs['max_length'] + 2))
                model_ensemble_sudo_log_prob = masks.new_zeros(
                    (opt.batch_size,
                     opt.beam_size, eval_kwargs['max_length'] + 2,
                     len(loader.get_vocab()) + 1))
                model_ensemble_sum_log_prob = masks.new_zeros(
                    (opt.batch_size, opt.beam_size))

                # Copy every beam hypothesis (token sequence, per-token log-probs,
                # and accumulated log-prob) from the ensemble's beam search into
                # the pseudo-label buffers above.
                for batch_index in range(opt.batch_size):
                    for beam_index in range(opt.beam_size):
                        pred = model_ensemble.done_beams[batch_index][
                            beam_index]['seq']
                        log_prob = model_ensemble.done_beams[batch_index][
                            beam_index]['logps']
                        model_ensemble_sudo_labels[batch_index, beam_index,
                                                   1:pred.shape[0] + 1] = pred
                        model_ensemble_sudo_log_prob[batch_index, beam_index,
                                                     1:pred.shape[0] +
                                                     1] = log_prob
                        model_ensemble_sum_log_prob[batch_index][
                            beam_index] = model_ensemble.done_beams[
                                batch_index][beam_index]['p']

                data_ensemble_sudo_gts = list()
                for data_ensemble_sudo_gts_index in range(
                        model_ensemble_sudo_labels.shape[0]):
                    data_ensemble_sudo_gts.append(model_ensemble_sudo_labels[
                        data_ensemble_sudo_gts_index, :,
                        1:-1].data.cpu().numpy())

                distilling_loss = 0
                # Randomly pick one student model to learn from the pseudo-labels
                who_to_study = random.randint(0, opt.number_of_models - 1)

                model_out = dp_kdlw_models[who_to_study](
                    unpaired_fc_feats, unpaired_att_feats,
                    model_ensemble_sudo_labels, model_ensemble_sudo_log_prob,
                    att_masks, data_ensemble_sudo_gts,
                    torch.arange(0, len(data_ensemble_sudo_gts)), sc_flag,
                    struc_flag, model_ensemble_sum_log_prob)
                distilling_loss += model_out['loss'].mean()
                loss += opt.number_of_models * opt.current_lambda_x * distilling_loss

                ###################################################################
                # use unlabelled captions
                simple_sgd = utils.gradient_descent_adagrad(
                    unpaired_caption_fc_feats, stepsize=1)
                gts_tmp = unpaired_caption['gts']
                new_gts = []
                for ii in range(len(gts_tmp)):
                    for jj in range(gts_tmp[ii].shape[0]):
                        new_gts.append(gts_tmp[ii][jj])
                unpaired_caption['gts'] = new_gts
                for itr in range(opt.inner_iteration):
                    unlabelled_caption_model_out = dp_lw_models_ema[
                        itr % opt.number_of_models](
                            unpaired_caption_fc_feats,
                            unpaired_caption_att_feats,
                            unpaired_caption_labels, unpaired_caption_masks,
                            unpaired_caption_att_masks,
                            unpaired_caption['gts'],
                            torch.arange(0, len(unpaired_caption['gts'])),
                            sc_flag, struc_flag)
                    unlabelled_caption_loss = unlabelled_caption_model_out[
                        'loss'].mean()
                    unlabelled_caption_loss.backward()
                    simple_sgd.update(unpaired_caption_fc_feats)

                unpaired_caption_fc_feats.requires_grad = False
                unpaired_caption_att_feats.requires_grad = False
                unlabelled_caption_model_out = dp_lw_models[who_to_study](
                    unpaired_caption_fc_feats, unpaired_caption_att_feats,
                    unpaired_caption_labels, unpaired_caption_masks,
                    unpaired_caption_att_masks, unpaired_caption['gts'],
                    torch.arange(0, len(unpaired_caption['gts'])), sc_flag,
                    struc_flag)
                unlabelled_caption_loss = unlabelled_caption_model_out[
                    'loss'].mean()
                loss += opt.number_of_models * opt.current_lambda_y * unlabelled_caption_loss

            loss.backward()
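            # opt.grad_clip_mode ('norm' or 'value') selects between
            # torch.nn.utils.clip_grad_norm_ and torch.nn.utils.clip_grad_value_.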
            if opt.grad_clip_value != 0:
                getattr(torch.nn.utils, 'clip_grad_%s_' %
                        (opt.grad_clip_mode))(multi_models.parameters(),
                                              opt.grad_clip_value)
            optimizer.step()

            # EMA update: each teacher copy tracks its student with momentum
            # opt.alpha (param_ema <- alpha * param_ema + (1 - alpha) * param).
            for order in range(opt.number_of_models):
                for param, param_ema in zip(
                        multi_models_list[order].parameters(),
                        multi_models_list[order +
                                          opt.number_of_models].parameters()):
                    param_ema.data = opt.alpha * param_ema.data + (
                        1 - opt.alpha) * param.data

            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if struc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss/opt.number_of_models, sum([model_outs_list[index]['lm_loss'].mean().item() for index in range(opt.number_of_models)])/opt.number_of_models,
                            sum([model_outs_list[index]['struc_loss'].mean().item() for index in range(opt.number_of_models)])/opt.number_of_models,
                            end - start))
            elif not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, language_loss.item()/opt.number_of_models, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, sum([model_outs_list[index]['reward'].mean().item() for index in range(opt.number_of_models)])/opt.number_of_models, end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                for index in range(opt.number_of_models):
                    model_id = 'model_{}'.format(index)
                    tb_summary_writer.add_scalars('language_loss', {
                        model_id:
                        model_outs_list[index]['loss'].mean().item()
                    }, iteration)
                if epoch >= opt.paired_train_epoch:
                    tb_summary_writer.add_scalar('distilling_loss',
                                                 distilling_loss.item(),
                                                 iteration)
                    tb_summary_writer.add_scalar(
                        'unlabelled_caption_loss',
                        unlabelled_caption_loss.item(), iteration)
                    tb_summary_writer.add_scalar('hyper_parameter_lambda_x',
                                                 opt.current_lambda_x,
                                                 iteration)
                    tb_summary_writer.add_scalar('hyper_parameter_lambda_y',
                                                 opt.current_lambda_y,
                                                 iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob',
                                             multi_models[0].ss_prob,
                                             iteration)
                if sc_flag:
                    for index in range(opt.number_of_models):
                        model_id = 'model_{}'.format(index)
                        tb_summary_writer.add_scalars(
                            'avg_reward', {
                                model_id:
                                model_outs_list[index]['reward'].mean().item()
                            }, iteration)
                elif struc_flag:
                    for index in range(opt.number_of_models):
                        model_id = 'model_{}'.format(index)
                        tb_summary_writer.add_scalars(
                            'lm_loss', {
                                model_id:
                                model_outs_list[index]
                                ['lm_loss'].mean().item()
                            }, iteration)
                        tb_summary_writer.add_scalars(
                            'struc_loss', {
                                model_id:
                                model_outs_list[index]
                                ['struc_loss'].mean().item()
                            }, iteration)
                        tb_summary_writer.add_scalars(
                            'reward', {
                                model_id:
                                model_outs_list[index]['reward'].mean().item()
                            }, iteration)
                        tb_summary_writer.add_scalars(
                            'reward_var', {
                                model_id:
                                model_outs_list[index]['reward'].var(1).mean()
                            }, iteration)

                histories['loss_history'][
                    iteration] = train_loss if not sc_flag else sum([
                        model_outs_list[index]['reward'].mean().item()
                        for index in range(opt.number_of_models)
                    ]) / opt.number_of_models
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = multi_models[
                    0].ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0 and not opt.save_every_epoch and epoch >= opt.paired_train_epoch) or \
                (epoch_done and opt.save_every_epoch and epoch >= opt.paired_train_epoch):
                # Build the evaluation ensemble from the EMA (teacher) models
                model = AttEnsemble(multi_models_list[opt.number_of_models:2 *
                                                      opt.number_of_models],
                                    weights=None)
                model.seq_length = opt.max_length
                model.cuda()
                model.eval()
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                # eval_kwargs['beam_size'] = 5
                # eval_kwargs['verbose_beam'] = 1
                # eval_kwargs['verbose_loss'] = 1
                # val_loss, predictions, lang_stats = eval_utils.eval_split(
                #     dp_model, lw_model.crit, loader, eval_kwargs)
                with torch.no_grad():
                    val_loss, predictions, lang_stats = eval_utils.eval_split(
                        model, lw_models[0].crit, loader, eval_kwargs)
                model.train()

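                # Plateau scheduler: step on -CIDEr so that a higher CIDEr registers as
                # improvement; fall back to validation loss when language stats are missing.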
                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss,
                                             iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)
                histories['val_result_history'][iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, multi_models, infos, optimizer,
                                      histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(
                        opt,
                        multi_models,
                        infos,
                        optimizer,
                        append=str(epoch)
                        if opt.save_every_epoch else str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt,
                                          multi_models,
                                          infos,
                                          optimizer,
                                          append='best')

            # if epoch_done and epoch == opt.paired_train_epoch:
            #     utils.save_checkpoint(opt, multi_models, infos, optimizer, histories)
            #     if opt.save_history_ckpt:
            #         utils.save_checkpoint(opt, multi_models, infos, optimizer,
            #                               append=str(epoch) if opt.save_every_epoch else str(iteration))
            #     cmd = 'cp -r ' + 'log_' + opt.id + ' ' + 'log_' + opt.id + '_backup'
            #     os.system(cmd)

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, multi_models, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Example #5
0
def train(opt):
    print("=================Training Information==============")
    print("start from {}".format(opt.start_from))
    print("box from {}".format(opt.input_box_dir))
    print("input json {}".format(opt.input_json))
    print("attributes from {}".format(opt.input_att_dir))
    print("features from {}".format(opt.input_fc_dir))
    print("batch size ={}".format(opt.batch_size))
    print("#GPU={}".format(torch.cuda.device_count()))
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    acc_steps = getattr(opt, 'acc_steps', 1)
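    # acc_steps > 1 enables gradient accumulation: several micro-batches are
    # backpropagated before a single optimizer step (see the training loop below).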
    name_append = opt.name_append
    if len(name_append) > 0 and name_append[0] != '-':
        name_append = '_' + name_append

    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.write_summary = write_summary
    if opt.write_summary:
        print("write summary to {}".format(opt.checkpoint_path))
        tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}

    if opt.start_from is not None:
        # open old infos and check if models are compatible
        infos_path = os.path.join(opt.start_from,
                                  'infos' + name_append + '.pkl')
        print("Load model information {}".format(infos_path))
        with open(infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + name_append + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)
    else:  # start from scratch
        print("Initialize training process from all begining")
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    # sanity check that the saved model name carries the correct epoch index
    if opt.name_append.isdigit() and int(opt.name_append) < 100:
        assert int(
            opt.name_append
        ) == epoch, "mismatch between the saved model index and the recorded epoch number"
        epoch += 1

    print("==================start from epoch {}================".format(epoch))
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    # pdb.set_trace()
    loader.iterators = infos.get('iterators', loader.iterators)
    start_Img_idx = loader.iterators['train']
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)  # wrap loss into model
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'aoa'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer_path = os.path.join(opt.start_from,
                                      'optimizer' + name_append + '.pth')
        if os.path.isfile(optimizer_path):
            print("Loading optimizer............")
            optimizer.load_state_dict(torch.load(optimizer_path))

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '_' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("Save model state to {}".format(checkpoint_path))

        optimizer_path = os.path.join(opt.checkpoint_path,
                                      'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        print("Save model optimizer to {}".format(optimizer_path))

        with open(
                os.path.join(opt.checkpoint_path,
                             'infos' + '%s.pkl' % (append)), 'wb') as f:
            utils.pickle_dump(infos, f)
            print("Save training information to {}".format(
                os.path.join(opt.checkpoint_path,
                             'infos' + '%s.pkl' % (append))))

        if histories:
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + '%s.pkl' % (append)),
                    'wb') as f:
                utils.pickle_dump(histories, f)
                print("Save training historyes to {}".format(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '%s.pkl' % (append))))

    try:
        while True:
            # pdb.set_trace()
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
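                    # Step decay: lr = learning_rate * decay_rate ** ((epoch - decay_start) // decay_every)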
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
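                # ss_prob grows linearly (increase_prob per increase_every epochs), capped at max_prob.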
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False
            print("{}th Epoch Training starts now!".format(epoch))
            with tqdm(total=len(loader.split_ix['train']),
                      initial=start_Img_idx) as pbar:
                for i in range(start_Img_idx, len(loader.split_ix['train']),
                               opt.batch_size):
                    # import ipdb; ipdb.set_trace()
                    start = time.time()
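                    # Linear warmup: ramp the LR up to learning_rate over the first noamopt_warmup iterations.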
                    if (opt.use_warmup
                            == 1) and (iteration < opt.noamopt_warmup):
                        opt.current_lr = opt.learning_rate * (
                            iteration + 1) / opt.noamopt_warmup
                        utils.set_lr(optimizer, opt.current_lr)
                    # Load data from train split (0)
                    data = loader.get_batch('train')
                    # print('Read data:', time.time() - start)

                    if (iteration % acc_steps == 0):
                        optimizer.zero_grad()

                    torch.cuda.synchronize()
                    start = time.time()
                    tmp = [
                        data['fc_feats'], data['att_feats'], data['labels'],
                        data['masks'], data['att_masks']
                    ]
                    tmp = [_ if _ is None else _.cuda() for _ in tmp]
                    fc_feats, att_feats, labels, masks, att_masks = tmp

                    model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                            att_masks, data['gts'],
                                            torch.arange(0, len(data['gts'])),
                                            sc_flag)

                    loss = model_out['loss'].mean()
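                    # Scale by 1/acc_steps so gradients summed over acc_steps backward passes
                    # match a single large-batch gradient.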
                    loss_sp = loss / acc_steps

                    loss_sp.backward()
                    if ((iteration + 1) % acc_steps == 0):
                        utils.clip_gradient(optimizer, opt.grad_clip)
                        optimizer.step()
                    torch.cuda.synchronize()
                    train_loss = loss.item()
                    end = time.time()
                    # if not sc_flag:
                    #     print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    #         .format(iteration, epoch, train_loss, end - start))
                    # else:
                    #     print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                    #         .format(iteration, epoch, model_out['reward'].mean(), end - start))
                    if not sc_flag:
                        pbar.set_description(
                            "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch, train_loss, end - start))
                    else:
                        pbar.set_description(
                            "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch,
                                    model_out['reward'].mean(), end - start))

                    # Update the iteration and epoch
                    iteration += 1
                    pbar.update(opt.batch_size)
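                    # bounds['wrapped'] turns True once the loader has cycled through the
                    # train split, i.e. an epoch boundary.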
                    if data['bounds']['wrapped']:
                        # save after each epoch
                        save_checkpoint(model,
                                        infos,
                                        optimizer,
                                        append=str(epoch))
                        epoch += 1
                        # infos['epoch'] = epoch
                        epoch_done = True

                    # Write validation result into summary
                    if (iteration % opt.losses_log_every
                            == 0) and opt.write_summary:
                        add_summary_value(tb_summary_writer, 'loss/train_loss',
                                          train_loss, iteration)
                        if opt.noamopt:
                            opt.current_lr = optimizer.rate()
                        elif opt.reduce_on_plateau:
                            opt.current_lr = optimizer.current_lr
                        add_summary_value(tb_summary_writer,
                                          'hyperparam/learning_rate',
                                          opt.current_lr, iteration)
                        add_summary_value(
                            tb_summary_writer,
                            'hyperparam/scheduled_sampling_prob',
                            model.ss_prob, iteration)
                        if sc_flag:
                            add_summary_value(tb_summary_writer, 'avg_reward',
                                              model_out['reward'].mean(),
                                              iteration)

                        loss_history[
                            iteration] = train_loss if not sc_flag else model_out[
                                'reward'].mean()
                        lr_history[iteration] = opt.current_lr
                        ss_prob_history[iteration] = model.ss_prob

                    # update infos
                    infos['iter'] = iteration
                    infos['epoch'] = epoch
                    infos['iterators'] = loader.iterators
                    infos['split_ix'] = loader.split_ix

                    # make evaluation on validation set, and save model
                    # TODO modify it to evaluate by each epoch
                    # ipdb.set_trace()
                    if (iteration % opt.save_checkpoint_every
                            == 0) and eval_ and epoch > 20:
                        model_path = os.path.join(
                            opt.checkpoint_path,
                            'model_itr%s.pth' % (iteration))
                        eval_kwargs = {
                            'split': 'val',
                            'dataset': opt.input_json,
                            'model': model_path
                        }
                        eval_kwargs.update(vars(opt))
                        val_loss, predictions, lang_stats = eval_utils.eval_split(
                            dp_model, lw_model.crit, loader, eval_kwargs)

                        if opt.reduce_on_plateau:
                            if 'CIDEr' in lang_stats:
                                optimizer.scheduler_step(-lang_stats['CIDEr'])
                            else:
                                optimizer.scheduler_step(val_loss)

                        # Write validation result into summary
                        if opt.write_summary:
                            add_summary_value(tb_summary_writer,
                                              'loss/validation loss', val_loss,
                                              iteration)

                            if lang_stats is not None:
                                bleu_dict = {}
                                for k, v in lang_stats.items():
                                    if 'Bleu' in k:
                                        bleu_dict[k] = v
                                if len(bleu_dict) > 0:
                                    tb_summary_writer.add_scalars(
                                        'val/Bleu', bleu_dict, epoch)

                                for k, v in lang_stats.items():
                                    if 'Bleu' not in k:
                                        add_summary_value(
                                            tb_summary_writer, 'val/' + k, v,
                                            iteration)
                        val_result_history[iteration] = {
                            'loss': val_loss,
                            'lang_stats': lang_stats,
                            'predictions': predictions
                        }

                        # Save model if is improving on validation result
                        if opt.language_eval == 1:
                            current_score = lang_stats['CIDEr']
                        else:
                            current_score = -val_loss

                        best_flag = False

                        if best_val_score is None or current_score > best_val_score:
                            best_val_score = current_score
                            best_flag = True

                        # Dump miscellaneous information
                        infos['best_val_score'] = best_val_score
                        histories['val_result_history'] = val_result_history
                        histories['loss_history'] = loss_history
                        histories['lr_history'] = lr_history
                        histories['ss_prob_history'] = ss_prob_history

                        # save_checkpoint(model, infos, optimizer, histories, append=str(iteration))
                        save_checkpoint(model, infos, optimizer, histories)
                        # if opt.save_history_ckpt:
                        #     save_checkpoint(model, infos, optimizer, append=str(iteration))

                        if best_flag:
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append='best')
                            print(
                                "update best model at {} iteration--{} epoch".
                                format(iteration, epoch))

                    start_Img_idx = 0
                    # if epoch_done: # go through the set, start a new epoch loop
                    #     break
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                print("epoch {} break all".format(epoch))
                save_checkpoint(model, infos, optimizer)
                tb_summary_writer.close()
                print("============{} Training Done !==============".format(
                    'Refine' if opt.use_test or opt.use_val else ''))
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        # save_checkpoint already prepends '_' to a non-empty append
        save_checkpoint(model, infos, optimizer, append='interrupt')
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Example #6
0
def train(opt):
    print(opt)

    # To reproduce training results
    init_seed()
    # Image Preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt, transform=transform)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '-best.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[
                    checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    opt.vocab = loader.get_vocab()
    if torch.cuda.is_available():
        model = models.setup(opt).cuda()
    else:
        model = models.setup(opt)
    del opt.vocab
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)
    #fgm = FGM(model)

    cnn_model = ResnetBackbone()
    if torch.cuda.is_available():
        cnn_model = cnn_model.cuda()
    if opt.start_from is not None:
        model_dict = cnn_model.state_dict()
        predict_dict = torch.load(os.path.join(opt.start_from, 'cnn_model-best.pth'))
        model_dict = {k: predict_dict["module."+k] for k, _ in model_dict.items() if "module."+ k in predict_dict}
        cnn_model.load_state_dict(model_dict)
    cnn_model = torch.nn.DataParallel(cnn_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from, "optimizer-best.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer-best.pth')))

    def save_checkpoint(model, cnn_model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        #Transformer model
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        #CNN model
        checkpoint_path = os.path.join(opt.checkpoint_path, 'cnn_model%s.pth' % (append))
        if not os.path.exists(checkpoint_path):
            torch.save(cnn_model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
                utils.pickle_dump(histories, f)

    cnn_after = 3
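    # Keep the CNN frozen for the first cnn_after epochs (or always, if opt.fix_cnn);
    # afterwards fine-tune its upper layers with a separate Adam optimizer.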
    try:
        while True:
            if epoch_done:
                if opt.fix_cnn or epoch < cnn_after:
                    for p in cnn_model.parameters():
                        p.requires_grad = False
                    cnn_model.eval()
                    cnn_optimizer = None
                else:
                    for p in cnn_model.parameters():
                        p.requires_grad = True
                    # Fix the first few layers:
                    for module in cnn_model._modules['module']._modules['resnet_conv'][:5]._modules.values():
                        for p in module.parameters():
                            p.requires_grad = False
                    cnn_model.train()
                    # Constructing CNN parameters for optimization, only fine-tuning higher layers
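                    # A much smaller LR (2e-6) is used once self-critical training starts,
                    # since policy-gradient updates are noisier; otherwise 5e-5.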
                    cnn_optimizer = torch.optim.Adam(
                        (filter(lambda p: p.requires_grad, cnn_model.parameters())),
                        lr=2e-6 if (opt.self_critical_after != -1 and epoch >= opt.self_critical_after) else 5e-5, betas=(0.8, 0.999))

                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            if iteration % opt.losses_log_every == 0:
                print('Read data:', time.time() - start)

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start = time.time()

            if torch.cuda.is_available():
                data['att_feats'] = cnn_model(data['att_feats'].cuda())
            else:
                data['att_feats'] = cnn_model(data['att_feats'])
            data['att_feats'] = repeat_feat(data['att_feats'])
            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
            if torch.cuda.is_available():
                tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            optimizer.zero_grad()
            if cnn_optimizer is not None:
                cnn_optimizer.zero_grad()

            # if epoch >= cnn_after:
            #     att_feats.register_hook(save_grad("att_feats"))
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag)

            loss = model_out['loss'].mean()

            loss.backward()

            #loss.backward(retain_graph=True)

            # adversarial training
            #fgm.attack(emb_name='model.tgt_embed.0.lut.weight')
            #adv_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'],
            #                      torch.arange(0, len(data['gts'])), sc_flag)

            #adv_loss = adv_out['loss'].mean()
            #adv_loss.backward()
            #fgm.restore(emb_name="model.tgt_embed.0.lut.weight")


            # utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            if cnn_optimizer is not None:
                cnn_optimizer.step()
            train_loss = loss.item()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            end = time.time()
            if not sc_flag and iteration % opt.losses_log_every == 0:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, end - start))
            elif iteration % opt.losses_log_every == 0:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                               'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                eval_kwargs["cnn_model"] = cnn_model
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, cnn_model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, cnn_model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, cnn_model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, cnn_model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Example #7
0
def train(opt):
    acc_steps = getattr(opt, 'acc_steps', 1)

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.ix_to_word = loader.ix_to_word

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
                utils.pickle_dump(histories, f)

    try:
        while True:
            sys.stdout.flush()
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)
                    print('Learning Rate: ', opt.current_lr)
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False
                epoch_done = False

            data = loader.get_batch('train')
            if (iteration % acc_steps == 0):
                optimizer.zero_grad()

            torch.cuda.synchronize()
            start = time.time()
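            # This variant also consumes 3D-CNN video features (c3d_feats / c3d_masks),
            # presumably C3D clip features, alongside the usual fc/att features.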
            tmp = [data['fc_feats'], data['att_feats'], data['c3d_feats'], data['labels'], data['masks'], data['att_masks'], data['c3d_masks']]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, c3d_feats, labels, masks, att_masks, c3d_masks = tmp

            model_out = dp_lw_model(fc_feats, att_feats, c3d_feats, labels, masks, att_masks, c3d_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag)

            loss = model_out['loss'].mean()
            loss_sp = loss / acc_steps

            loss_sp.backward()
            if ((iteration + 1) % acc_steps == 0):
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
            torch.cuda.synchronize()
            train_loss = loss.item()
            end = time.time()
            # (the original 'if iteration % 1 == 0' guard is always true)
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}".format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), reward1 = {:.3f}, reward2 = {:.3f}, reward3 = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}".format(iteration, epoch, model_out['reward_layer1'].mean(), model_out['reward_layer2'].mean(), model_out['reward_layer3'].mean(), train_loss, end - start))

            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
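                    # The loss wrapper returns one reward per layer (presumably one per
                    # decoder layer); log each separately.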
                    add_summary_value(tb_summary_writer, 'reward1', model_out['reward_layer1'].mean(), iteration)
                    add_summary_value(tb_summary_writer, 'reward2', model_out['reward_layer2'].mean(), iteration)
                    add_summary_value(tb_summary_writer, 'reward3', model_out['reward_layer3'].mean(), iteration)

                loss_history[iteration] = train_loss
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': opt.val_split, 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, lw_model.crit, loader, eval_kwargs)
                print('Summary Epoch {} Iteration {}: CIDEr: {} BLEU-4: {}'.format(epoch, iteration, lang_stats['CIDEr'], lang_stats['Bleu_4']))

                if opt.reduce_on_plateau:
                    if opt.reward_metric == 'cider':
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    elif opt.reward_metric == 'bleu':
                        optimizer.scheduler_step(-lang_stats['Bleu_4'])
                    elif opt.reward_metric == 'meteor':
                        optimizer.scheduler_step(-lang_stats['METEOR'])
                    elif opt.reward_metric == 'rouge':
                        optimizer.scheduler_step(-lang_stats['ROUGE_L'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    if opt.reward_metric == 'cider':
                        current_score = lang_stats['CIDEr']
                    elif opt.reward_metric == 'bleu':
                        current_score = lang_stats['Bleu_4']
                    elif opt.reward_metric == 'meteor':
                        current_score = lang_stats['METEOR']
                    elif opt.reward_metric == 'rouge':
                        current_score = lang_stats['ROUGE_L']
                    else:
                        # fall back to CIDEr so current_score is always defined
                        current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
def train(opt):
    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if there is) and check if models are compatible
    if opt.checkpoint_path is not None and os.path.isfile(
            os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl')):
        with open(
                os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'),
                'rb') as f:
            infos = utils.pickle_load(f)
            print('infos load success')
    infos['opt'] = opt

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    # Load pretrained weights:
    if opt.checkpoint_path is not None and os.path.isfile(
            os.path.join(opt.checkpoint_path, 'model.pth')):
        model.load_state_dict(
            torch.load(os.path.join(opt.checkpoint_path, 'model.pth')))
        print('model load success')

    # Wrap generation model with loss function(used for training)
    # This allows loss function computed separately on each machine
    lw_model = LossWrapper(model, opt)
    # Wrap with dataparallel
    dp_model = torch.nn.DataParallel(model)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    ##########################
    #  Build optimizer
    ##########################
    optimizer = utils.ReduceLROnPlateau(optim.Adam(model.parameters(),
                                                   opt.learning_rate),
                                        factor=0.5,
                                        patience=3)
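    # ReduceLROnPlateau wraps Adam and halves the LR after `patience` evaluations
    # without improvement in the monitored score.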
    # Load the optimizer
    if opt.checkpoint_path is not None and os.path.isfile(
            os.path.join(opt.checkpoint_path, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.checkpoint_path, 'optimizer.pth')))

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    best_val_score = infos.get('best_val_score', None)
    print('iter {}, epoch {}, best_val_score {}'.format(
        iteration, epoch, best_val_score))

    print(sorted(dict(set(vars(opt).items())).items(), key=lambda x: x[0]))
    # Start training
    if opt.self_critical:
        init_scorer(opt.cached_tokens)
    # Assure in training mode
    dp_lw_model.train()
    try:
        while True:
            # Stop if reaching max_epoch
            if epoch >= opt.max_epochs:
                break

            # Load data from train split (0)
            data = loader.get_batch('train')

            torch.cuda.synchronize()

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                    att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])))

            loss = model_out['loss'].mean()

            loss.backward()
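            # Element-wise value clipping to [-0.1, 0.1] (value clipping, not norm clipping).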
            torch.nn.utils.clip_grad_value_(model.parameters(), 0.1)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1

            # Write the training loss summary
            if iteration % opt.losses_log_every == 0:
                tb_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                if opt.self_critical:
                    tb_summary_writer.add_scalar('avg_reward',
                                                 model_out['reward'].mean(),
                                                 iteration)

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch

            # make evaluation on validation set, and save model
            if iteration % opt.save_checkpoint_every == 0:
                tb_summary_writer.add_scalar('epoch', epoch, iteration)
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                _, _, lang_stats = eval_utils.eval_split(
                    dp_model, loader, eval_kwargs)

                optimizer.scheduler_step(-lang_stats['CIDEr'])
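                # CIDEr is negated because the plateau scheduler treats lower values as better.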
                # Write validation result into summary
                for k, v in lang_stats.items():
                    tb_summary_writer.add_scalar(k, v, iteration)

                # Save model if is improving on validation result
                current_score = lang_stats['CIDEr']

                best_flag = False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, model, infos, optimizer)
                if best_flag:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          append='best')

    except (RuntimeError, KeyboardInterrupt):
        pass