Example No. 1
def train(opt):

    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        print("opt.start_from: " + str(opt.start_from))
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # create model
    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    # load model
    if os.path.isfile("log_sc/model.pth"):
        model_path = "log_sc/model.pth"
        state_dict = torch.load(model_path)
        dp_model.load_state_dict(state_dict)

    dp_model.train()

    # create/load vector model
    vectorModel = models.setup_vectorModel().cuda()
    dp_vectorModel = torch.nn.DataParallel(vectorModel)

    # load vector model
    if os.path.isfile("log_sc/model_vec.pth"):
        model_vec_path = "log_sc/model_vec.pth"
        state_dict_vec = torch.load(model_vec_path)
        dp_vectorModel.load_state_dict(state_dict_vec)

    dp_vectorModel.train()

    optimizer = utils.build_optimizer(
        list(model.parameters()) + list(vectorModel.parameters()), opt)
    update_lr_flag = True

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # Loss function
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    vec_crit = nn.L1Loss()

    # create idxs for doc2vec vectors
    with open('paragraphs_image_ids.txt', 'r') as file:
        paragraph_image_ids = file.readlines()

    paragraph_image_ids = [int(i) for i in paragraph_image_ids]

    # select corresponding vectors
    with open('paragraphs_vectors.txt', 'r') as the_file:
        vectors = the_file.readlines()

    vectors_list = []
    for string in vectors:
        vectors_list.append([float(s) for s in string.split(' ')])

    vectors_list_np = np.asarray(vectors_list)
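    # Row i of vectors_list_np is the doc2vec embedding of the paragraph whose image id
    # is paragraph_image_ids[i]; each batch indexes into it by image id further below.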

    print("Starting training loop!")

    # Training loop
    while True:

        # Update learning rate once per epoch
        if update_lr_flag:

            # Assign the learning rate
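            # Step decay: lr = learning_rate * learning_rate_decay_rate ** floor(
            #     (epoch - learning_rate_decay_start) / learning_rate_decay_every).
            # E.g. base lr 5e-4, decay rate 0.8, decay every 3 epochs starting at epoch 0
            # gives 5e-4 * 0.8**2 = 3.2e-4 at epoch 7 (illustrative numbers only).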
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
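            # Scheduled sampling: ss_prob is the probability of feeding back the model's own
            # previous prediction instead of the ground-truth token while training. It grows
            # linearly by scheduled_sampling_increase_prob every
            # scheduled_sampling_increase_every epochs, capped at scheduled_sampling_max_prob.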
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
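            # Self-critical sequence training (SCST): after this epoch the model is trained on
            # sampled captions, rewarded (typically with CIDEr) relative to a greedy-decoding
            # baseline; see get_self_critical_reward below. init_scorer loads the cached
            # n-gram statistics the scorer needs.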
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # pad data['att_feats'] axis=1 to have length = 83
        def pad_along_axis(array, target_length, axis=0):
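            """Zero-pad `array` along `axis` up to `target_length`.

            Sketch of the intent: an att_feats array of shape (batch, 60, 2048) padded with
            target_length=83 on axis=1 comes back as (batch, 83, 2048); arrays already at or
            beyond the target length are returned as-is. (Example shapes are illustrative.)
            """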

            pad_size = target_length - array.shape[axis]
            axis_nb = len(array.shape)

            if pad_size < 0:
                return array

            npad = [(0, 0) for x in range(axis_nb)]
            npad[axis] = (0, pad_size)

            b = np.pad(array,
                       pad_width=npad,
                       mode='constant',
                       constant_values=0)

            return b

        data['att_feats'] = pad_along_axis(data['att_feats'], 83, axis=1)

        # Unpack data
        torch.cuda.synchronize()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        idx = []
        for element in data['infos']:
            idx.append(paragraph_image_ids.index(element['id']))

        batch_vectors = vectors_list_np[idx]

        # Forward pass and loss
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

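        # Auxiliary semantic loss: the vector model maps the image features to a
        # paragraph-level embedding, and an L1 loss against the precomputed doc2vec vector
        # for that image is added to the captioning loss with weight alpha_.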
        att_feats_reshaped = att_feats.permute(0, 2, 1).cuda()
        semantic_features = dp_vectorModel(att_feats_reshaped.cuda(),
                                           fc_feats)  # (10, 2048)
        batch_vectors = torch.from_numpy(
            batch_vectors).float().cuda()  # (10, 512)
        vec_loss = vec_crit(semantic_features, batch_vectors)
        alpha_ = 1
        loss = loss + (alpha_ * vec_loss)

        # Backward pass
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Print
        total_time = time.time() - start
        if iteration % opt.print_freq == 1:
            print('Read data:', data_time)
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, data_time, total_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Validate and save model (note: runs every iteration in this example)
        if True:

            # Evaluate model
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(dp_model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))

            # save vec model
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model_vec.pth')
            torch.save(dp_vectorModel.state_dict(), checkpoint_path)
            print("model_vec saved to {}".format(checkpoint_path))

            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format(
                    iteration, best_val_score)
                infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname)
                torch.save(dp_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))

                # best vec
                model_fname_vec = 'model-best-vec-i{:05d}-score{:.4f}.pth'.format(
                    iteration, best_val_score)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname_vec)
                torch.save(dp_vectorModel.state_dict(), checkpoint_path)
                print("model_vec saved to {}".format(checkpoint_path))

                with open(os.path.join(opt.checkpoint_path, infos_fname),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Example No. 2
def train(opt):
    # Deal with feature things before anything

    acc_steps = getattr(opt, 'acc_steps', 1)
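    # Gradient accumulation: gradients are summed over acc_steps mini-batches before each
    # optimizer step, so the effective batch size is roughly batch_size * acc_steps; the
    # loss is divided by acc_steps below to keep gradient magnitudes comparable.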

    loader = DataLoaderRaw(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    opt.vocab = loader.get_vocab()
    dp_model = models.setup(opt)

    model = dp_model.cuda()

    del opt.vocab
    dp_lw_model = LossWrapper(dp_model, opt)
    lw_model = dp_lw_model

    epoch_done = True
    # Ensure the model is in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'mngrcnn'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
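    # The three branches above: the Noam schedule (warmup then inverse-square-root decay, as
    # in "Attention Is All You Need") for transformer-style models; ReduceLROnPlateau, which
    # halves the lr when the validation score stops improving (factor=0.5, patience=3); or a
    # plain optimizer whose lr is decayed manually in the epoch loop below.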
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path,
                                      'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(
                os.path.join(opt.checkpoint_path,
                             'infos_' + opt.id + '%s.pkl' % (append)),
                'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '%s.pkl' % (append)),
                    'wb') as f:
                utils.pickle_dump(histories, f)

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    # set the decayed rate
                    utils.set_lr(optimizer, opt.current_lr)
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False

            start = time.time()
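            # Linear warmup (independent of the Noam schedule): the lr ramps from near 0 up
            # to the base learning_rate over the first noamopt_warmup iterations.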
            if (opt.use_warmup == 1) and (iteration < opt.noamopt_warmup):
                opt.current_lr = opt.learning_rate * \
                    (iteration+1) / opt.noamopt_warmup
                utils.set_lr(optimizer, opt.current_lr)
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            if (iteration % acc_steps == 0):
                optimizer.zero_grad()

            torch.cuda.synchronize()
            start = time.time()
            tmp = [data['attr'], data['img'], data['labels'], data['masks']]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            attrs, imgs, labels, masks = tmp

            model_out = dp_lw_model(attrs, imgs, labels, masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag)

            loss = model_out['loss'].mean()
            loss_sp = loss / acc_steps

            loss_sp.backward()
            if ((iteration + 1) % acc_steps == 0):
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
            torch.cuda.synchronize()
            train_loss = loss.item()
            end = time.time()
            if not sc_flag:
                print(
                    "{}: iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    .format(
                        time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                        iteration, epoch, train_loss, end - start))
            else:
                print(
                    "{}: iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                    .format(
                        time.strftime("%Y-%m-%d %H:%M:%S",
                                      time.localtime()), iteration, epoch,
                        model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                                  iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward',
                                      model_out['reward'].mean(), iteration)

                loss_history[
                    iteration] = train_loss if not sc_flag else model_out[
                        'reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss',
                                  val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model,
                                    infos,
                                    optimizer,
                                    append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Example No. 3
def train(opt):
    # on node-13 this line causes a bug
    from torch.utils.tensorboard import SummaryWriter

    ################################
    # Build dataloader
    ################################

    # the loader here still needs to be fixed so that data loading is correct;
    # opt needs to be modified here so everything else stays consistent

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if any) and check if the models are compatible
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')):
        raise Exception("not implemented")

    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'),
                  'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    model = TransformerLM(opt).cuda()  # only set up the language model
    del opt.vocab
    # Load pretrained weights:
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'model.pth')):
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Wrap the generation model with the loss function (used for training).
    # This allows the loss to be computed separately on each machine.
    lw_model = LossWrapper(model, opt)
    # Wrap with dataparallel
    dp_model = torch.nn.DataParallel(model)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For back compatibility
    if 'iterators' in infos:
        infos['loader_state_dict'] = {
            split: {
                'index_list': infos['split_ix'][split],
                'iter_counter': infos['iterators'][split]
            }
            for split in ['train', 'val', 'test']
        }
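        # Back-compat shim: older checkpoints stored per-split 'iterators' and 'split_ix';
        # they are repackaged here into the newer loader_state_dict format expected by
        # loader.load_state_dict.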
    loader.load_state_dict(infos['loader_state_dict'])
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    if opt.noamopt:
        optimizer._step = iteration
    # flag indicating finish of an epoch
    # Always set to True at the beginning to initialize the lr or etc.
    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    # Start training
    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                # If start structure loss training
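                # struc_flag enables an additional sequence-level "structure" loss that is
                # combined with the word-level loss inside LossWrapper; the exact objective
                # depends on how LossWrapper is configured.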
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                    att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag,
                                    struc_flag)

            loss = model_out['loss'].mean()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if struc_flag:
                print(
                    "iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(),
                            model_out['struc_loss'].mean().item(), end - start))
            elif not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                tb_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.ss_prob, iteration)
                if sc_flag:
                    tb_summary_writer.add_scalar('avg_reward',
                                                 model_out['reward'].mean(),
                                                 iteration)
                elif struc_flag:
                    tb_summary_writer.add_scalar(
                        'lm_loss', model_out['lm_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'struc_loss', model_out['struc_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'reward', model_out['reward'].mean().item(), iteration)

                histories['loss_history'][
                    iteration] = train_loss if not sc_flag else model_out[
                        'reward'].mean()
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if lang_stats is not None:
                        if 'CIDEr' in lang_stats:
                            optimizer.scheduler_step(-lang_stats['CIDEr'])
                        else:
                            optimizer.scheduler_step(val_loss)
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss,
                                             iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)
                histories['val_result_history'][iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          append=str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Example No. 4
def train(opt):
    print("=================Training Information==============")
    print("start from {}".format(opt.start_from))
    print("box from {}".format(opt.input_box_dir))
    print("input json {}".format(opt.input_json))
    print("attributes from {}".format(opt.input_att_dir))
    print("features from {}".format(opt.input_fc_dir))
    print("batch size ={}".format(opt.batch_size))
    print("#GPU={}".format(torch.cuda.device_count()))
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5
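        # Assumption: the +5 covers the per-region geometry appended to each attention
        # feature when bounding boxes are used (normalized box coordinates plus box size);
        # the feature pipeline itself is outside this snippet.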

    acc_steps = getattr(opt, 'acc_steps', 1)
    name_append = opt.name_append
    if len(name_append) > 0 and name_append[0] != '-':
        name_append = '_' + name_append

    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.write_summary = write_summary
    if opt.write_summary:
        print("write summary to {}".format(opt.checkpoint_path))
        tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}

    if opt.start_from is not None:
        # open old infos and check if models are compatible
        infos_path = os.path.join(opt.start_from,
                                  'infos' + name_append + '.pkl')
        print("Load model information {}".format(infos_path))
        with open(infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + name_append + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)
    else:  # start from scratch
        print("Initialize training process from all begining")
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    #  sanity check that the saved model name carries the correct epoch index
    if opt.name_append.isdigit() and int(opt.name_append) < 100:
        assert int(
            opt.name_append
        ) == epoch, "mismatch between the model index and the actual epoch number"
        epoch += 1

    print(
        "==================start from {} epoch================".format(epoch))
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    # pdb.set_trace()
    loader.iterators = infos.get('iterators', loader.iterators)
    start_Img_idx = loader.iterators['train']
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)  # wrap loss into model
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Ensure the model is in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'aoa'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer_path = os.path.join(opt.start_from,
                                      'optimizer' + name_append + '.pth')
        if os.path.isfile(optimizer_path):
            print("Loading optimizer............")
            optimizer.load_state_dict(torch.load(optimizer_path))

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '_' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("Save model state to {}".format(checkpoint_path))

        optimizer_path = os.path.join(opt.checkpoint_path,
                                      'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        print("Save model optimizer to {}".format(optimizer_path))

        with open(
                os.path.join(opt.checkpoint_path,
                             'infos' + '%s.pkl' % (append)), 'wb') as f:
            utils.pickle_dump(infos, f)
            print("Save training information to {}".format(
                os.path.join(opt.checkpoint_path,
                             'infos' + '%s.pkl' % (append))))

        if histories:
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + '%s.pkl' % (append)),
                    'wb') as f:
                utils.pickle_dump(histories, f)
                print("Save training historyes to {}".format(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '%s.pkl' % (append))))

    try:
        while True:
            # pdb.set_trace()
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False
            print("{}th Epoch Training starts now!".format(epoch))
            with tqdm(total=len(loader.split_ix['train']),
                      initial=start_Img_idx) as pbar:
                for i in range(start_Img_idx, len(loader.split_ix['train']),
                               opt.batch_size):
                    # import ipdb; ipdb.set_trace()
                    start = time.time()
                    if (opt.use_warmup
                            == 1) and (iteration < opt.noamopt_warmup):
                        opt.current_lr = opt.learning_rate * (
                            iteration + 1) / opt.noamopt_warmup
                        utils.set_lr(optimizer, opt.current_lr)
                    # Load data from train split (0)
                    data = loader.get_batch('train')
                    # print('Read data:', time.time() - start)

                    if (iteration % acc_steps == 0):
                        optimizer.zero_grad()

                    torch.cuda.synchronize()
                    start = time.time()
                    tmp = [
                        data['fc_feats'], data['att_feats'], data['labels'],
                        data['masks'], data['att_masks']
                    ]
                    tmp = [_ if _ is None else _.cuda() for _ in tmp]
                    fc_feats, att_feats, labels, masks, att_masks = tmp

                    model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                            att_masks, data['gts'],
                                            torch.arange(0, len(data['gts'])),
                                            sc_flag)

                    loss = model_out['loss'].mean()
                    loss_sp = loss / acc_steps

                    loss_sp.backward()
                    if ((iteration + 1) % acc_steps == 0):
                        utils.clip_gradient(optimizer, opt.grad_clip)
                        optimizer.step()
                    torch.cuda.synchronize()
                    train_loss = loss.item()
                    end = time.time()
                    # if not sc_flag:
                    #     print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    #         .format(iteration, epoch, train_loss, end - start))
                    # else:
                    #     print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                    #         .format(iteration, epoch, model_out['reward'].mean(), end - start))
                    if not sc_flag:
                        pbar.set_description(
                            "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch, train_loss, end - start))
                    else:
                        pbar.set_description(
                            "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch,
                                    model_out['reward'].mean(), end - start))

                    # Update the iteration and epoch
                    iteration += 1
                    pbar.update(opt.batch_size)
                    if data['bounds']['wrapped']:
                        # save after each epoch
                        save_checkpoint(model,
                                        infos,
                                        optimizer,
                                        append=str(epoch))
                        epoch += 1
                        # infos['epoch'] = epoch
                        epoch_done = True

                    # Write validation result into summary
                    if (iteration % opt.losses_log_every
                            == 0) and opt.write_summary:
                        add_summary_value(tb_summary_writer, 'loss/train_loss',
                                          train_loss, iteration)
                        if opt.noamopt:
                            opt.current_lr = optimizer.rate()
                        elif opt.reduce_on_plateau:
                            opt.current_lr = optimizer.current_lr
                        add_summary_value(tb_summary_writer,
                                          'hyperparam/learning_rate',
                                          opt.current_lr, iteration)
                        add_summary_value(
                            tb_summary_writer,
                            'hyperparam/scheduled_sampling_prob',
                            model.ss_prob, iteration)
                        if sc_flag:
                            add_summary_value(tb_summary_writer, 'avg_reward',
                                              model_out['reward'].mean(),
                                              iteration)

                        loss_history[
                            iteration] = train_loss if not sc_flag else model_out[
                                'reward'].mean()
                        lr_history[iteration] = opt.current_lr
                        ss_prob_history[iteration] = model.ss_prob

                    # update infos
                    infos['iter'] = iteration
                    infos['epoch'] = epoch
                    infos['iterators'] = loader.iterators
                    infos['split_ix'] = loader.split_ix

                    # make evaluation on validation set, and save model
                    # TODO modify it to evaluate by each epoch
                    # ipdb.set_trace()
                    if (iteration % opt.save_checkpoint_every
                            == 0) and eval_ and epoch > 20:
                        model_path = os.path.join(
                            opt.checkpoint_path,
                            'model_itr%s.pth' % (iteration))
                        eval_kwargs = {
                            'split': 'val',
                            'dataset': opt.input_json,
                            'model': model_path
                        }
                        eval_kwargs.update(vars(opt))
                        val_loss, predictions, lang_stats = eval_utils.eval_split(
                            dp_model, lw_model.crit, loader, eval_kwargs)

                        if opt.reduce_on_plateau:
                            if 'CIDEr' in lang_stats:
                                optimizer.scheduler_step(-lang_stats['CIDEr'])
                            else:
                                optimizer.scheduler_step(val_loss)

                        # Write validation result into summary
                        if opt.write_summary:
                            add_summary_value(tb_summary_writer,
                                              'loss/validation loss', val_loss,
                                              iteration)

                            if lang_stats is not None:
                                bleu_dict = {}
                                for k, v in lang_stats.items():
                                    if 'Bleu' in k:
                                        bleu_dict[k] = v
                                if len(bleu_dict) > 0:
                                    tb_summary_writer.add_scalars(
                                        'val/Bleu', bleu_dict, epoch)

                                for k, v in lang_stats.items():
                                    if 'Bleu' not in k:
                                        add_summary_value(
                                            tb_summary_writer, 'val/' + k, v,
                                            iteration)
                        val_result_history[iteration] = {
                            'loss': val_loss,
                            'lang_stats': lang_stats,
                            'predictions': predictions
                        }

                        # Save model if is improving on validation result
                        if opt.language_eval == 1:
                            current_score = lang_stats['CIDEr']
                        else:
                            current_score = -val_loss

                        best_flag = False

                        if best_val_score is None or current_score > best_val_score:
                            best_val_score = current_score
                            best_flag = True

                        # Dump miscellaneous information
                        infos['best_val_score'] = best_val_score
                        histories['val_result_history'] = val_result_history
                        histories['loss_history'] = loss_history
                        histories['lr_history'] = lr_history
                        histories['ss_prob_history'] = ss_prob_history

                        # save_checkpoint(model, infos, optimizer, histories, append=str(iteration))
                        save_checkpoint(model, infos, optimizer, histories)
                        # if opt.save_history_ckpt:
                        #     save_checkpoint(model, infos, optimizer, append=str(iteration))

                        if best_flag:
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append='best')
                            print(
                                "update best model at {} iteration--{} epoch".
                                format(iteration, epoch))

                    start_Img_idx = 0
                    # if epoch_done: # go through the set, start a new epoch loop
                    #     break
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                print("epoch {} break all".format(epoch))
                save_checkpoint(model, infos, optimizer)
                tb_summary_writer.close()
                print("============{} Training Done !==============".format(
                    'Refine' if opt.use_test or opt.use_val else ''))
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer, append='_interrupt')
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Example No. 5
def train(opt):

    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if any) and check that the models are compatible
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme
                ), "Command line argument and saved model disagree on '%s' " % checkme
    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'),
                  'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    USE_CUDA = torch.cuda.is_available()
    device = torch.device("cuda:0" if USE_CUDA else "cpu")
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    # Load pretrained weights:
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'model.pth')):
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Wrap the generation model with the loss function (used for training).
    # This lets the loss be computed separately on each GPU under DataParallel.
    lw_model = LossWrapper(model, opt)
    # Wrap with dataparallel
    dp_model = torch.nn.DataParallel(model)
    dp_lw_model = torch.nn.DataParallel(lw_model)
    dp_model.to(device)
    dp_lw_model.to(device)

    ##########################
    #  Build optimizer
    ##########################
    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For backward compatibility
    if 'iterators' in infos:
        infos['loader_state_dict'] = {
            split: {
                'index_list': infos['split_ix'][split],
                'iter_counter': infos['iterators'][split]
            }
            for split in ['train', 'val', 'test']
        }
    loader.load_state_dict(infos['loader_state_dict'])
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    # Flag indicating the end of an epoch.
    # Always set to True at the beginning so the learning rate etc. get initialized.
    epoch_done = True
    # Ensure the model is in training mode
    dp_lw_model.train()

    # Start training
    try:
        while True:
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

            if epoch_done:
                # Assign the learning rate
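                # i.e. current_lr = learning_rate *
                #      learning_rate_decay_rate ** ((epoch - learning_rate_decay_start) // learning_rate_decay_every)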
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If starting self-critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                # If starting structure-loss training
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')

            torch.cuda.synchronize()
            start = time.time()

            tmp = [
                data['semantic_feat'], data["semantic1_feat"],
                data['att_feats'], data["att1_feats"], data["box_feat"],
                data["box1_feat"], data['labels'], data['masks']
            ]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            semantic_feat, semantic1_feat, att_feats, att1_feats, box_feat, box1_feat, labels, masks = tmp

            optimizer.zero_grad()
            model_out = dp_lw_model(semantic_feat, semantic1_feat, att_feats,
                                    att1_feats, box_feat, box1_feat, labels,
                                    masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag,
                                    struc_flag)

            loss = model_out['loss'].mean()

            loss.backward()
            if opt.grad_clip_value != 0:
                getattr(torch.nn.utils, 'clip_grad_%s_' %
                        (opt.grad_clip_mode))(model.parameters(),
                                              opt.grad_clip_value)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if struc_flag and iteration % opt.losses_log_every == 0:
                print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, cider = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), model_out['cider'].mean().item(), end - start))
            elif not sc_flag and iteration % opt.losses_log_every == 0:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                if iteration % opt.losses_log_every == 0:
                    print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                        .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1

            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                tb_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.ss_prob, iteration)
                if sc_flag:
                    tb_summary_writer.add_scalar('avg_reward',
                                                 model_out['reward'].mean(),
                                                 iteration)
                elif struc_flag:
                    tb_summary_writer.add_scalar(
                        'lm_loss', model_out['lm_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'struc_loss', model_out['struc_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'reward', model_out['reward'].mean().item(), iteration)
                    tb_summary_writer.add_scalar(
                        'reward_var', model_out['reward'].var(1).mean(),
                        iteration)

                histories['loss_history'][
                    iteration] = train_loss if not sc_flag else model_out[
                        'reward'].mean()
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()

            # Evaluate on the validation set, and save the model
            if (iteration % opt.save_checkpoint_every == 0 and not opt.save_every_epoch) or \
                (epoch_done and opt.save_every_epoch):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss,
                                             iteration)

                histories['val_result_history'][iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if it is improving on the validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss
                print("val_loss = {:.3f}".format(val_loss))

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(
                        opt,
                        model,
                        infos,
                        optimizer,
                        append=str(epoch)
                        if opt.save_every_epoch else str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          append='best')

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
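
Example No. 5 clips gradients through a getattr dispatch, getattr(torch.nn.utils, 'clip_grad_%s_' % opt.grad_clip_mode). Below is a minimal sketch of the two torch.nn.utils calls that dispatch can resolve to, assuming grad_clip_mode is either 'norm' or 'value' (the two clip_grad_*_ functions that exist in torch.nn.utils):

import torch


def clip_gradients(parameters, grad_clip_mode, grad_clip_value):
    # Explicit form of the getattr dispatch used in the training loop above.
    if grad_clip_mode == 'norm':
        torch.nn.utils.clip_grad_norm_(parameters, grad_clip_value)
    elif grad_clip_mode == 'value':
        torch.nn.utils.clip_grad_value_(parameters, grad_clip_value)
    else:
        raise ValueError('unsupported grad_clip_mode: %s' % grad_clip_mode)
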
Example No. 6
def train(opt):
    torch.cuda.set_device(opt.device)
    # opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_att = True
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    print(opt.seq_length)
    print(opt.checkpoint_path)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    critic_loss_history = histories.get('critic_loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    variance_history = histories.get('variance_history', {})
    time_history = histories.get('time_history', {})
    pseudo_num_history = histories.get('pseudo_num_history', {})
    pseudo_num_length_history = histories.get('pseudo_num_length_history', {})
    pseudo_num_batch_history = histories.get('pseudo_num_batch_history', {})
    sum_logits_history = histories.get('sum_logits_history', {})
    reward_main_history = histories.get('reward_main_history', {})
    first_order = histories.get('first_order_history', np.zeros(1))
    second_order = histories.get('second_order_history', np.zeros(1))
    first_order = torch.from_numpy(first_order).float().cuda()
    second_order = torch.from_numpy(second_order).float().cuda()

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = model

    target_actor = models.setup(opt).cuda()

    ####################### Critic pretrain #####################################################################
    ##### Critic with state as input
    # if opt.critic_model == 'state_critic':
    #     critic_model = CriticModel(opt)
    # else:
    critic_model = AttCriticModel(opt)
    target_critic = AttCriticModel(opt)
    if vars(opt).get('start_from_critic', None) is not None:
        # check if all necessary files exist
        assert os.path.isdir(opt.start_from_critic), " %s must be a path" % opt.start_from_critic
        print(os.path.join(opt.start_from_critic, opt.critic_model + '_model.pth'))
        critic_model.load_state_dict(torch.load(os.path.join(opt.start_from_critic, opt.critic_model + '_model.pth')))
        target_critic.load_state_dict(torch.load(os.path.join(opt.start_from_critic, opt.critic_model + '_model.pth')))
    critic_model = critic_model.cuda()
    target_critic = target_critic.cuda()
    critic_optimizer = utils.build_optimizer(critic_model.parameters(), opt)
    dp_model.eval()
    critic_iter = 0
    init_scorer(opt.cached_tokens)
    critic_model.train()
    error_sum = 0
    loss_vector_sum = 0
    while opt.pretrain_critic == 1:
        if critic_iter > opt.pretrain_critic_steps:
            print('****************Finished critic training!')
            break
        data = loader.get_batch('train')
        torch.cuda.synchronize()
        start = time.time()
        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        critic_model.train()
        critic_optimizer.zero_grad()
        # assert opt.critic_model == 'att_critic_vocab'
        # crit_loss, reward, std = critic_loss_fun(fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data)
        crit_loss, reward, std = target_critic_loss_fun_mask(fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data, target_critic, target_actor)
        crit_loss.backward()
        critic_optimizer.step()
        #TODO update target.
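        # Soft (Polyak-style) target update: target <- target + gamma_critic * (online - target),
        # i.e. an exponential moving average of the online critic's parameters.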
        for cp, tp in zip(critic_model.parameters(), target_critic.parameters()):
            tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data)
        crit_train_loss = crit_loss.item()
        torch.cuda.synchronize()
        end = time.time()
        error_sum += crit_train_loss**0.5-std
        if (critic_iter % opt.losses_log_every == 0):
            print("iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f}, time/batch = {:.3f}" \
                .format(critic_iter, crit_train_loss**0.5, crit_train_loss**0.5-std, error_sum, end - start))
            print(opt.checkpoint_path)
            opt.importance_sampling = 1
            critic_model.eval()
            _, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model, test_critic=True)

        critic_iter += 1

        # Evaluate on the validation set, and save the model
        if (critic_iter % opt.save_checkpoint_every == 0):
            if not os.path.isdir(opt.checkpoint_path):
                os.mkdir(opt.checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path, opt.critic_model + '_model.pth')
            torch.save(critic_model.state_dict(), checkpoint_path)

    ######################### Actor-critic Training #####################################################################

    update_lr_flag = True
    # Ensure the model is in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # first_order = 0
    # second_order = 0
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If starting self-critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch('train')
        # if data['bounds']['it_pos_now'] > 5000:
        #     loader.reset_iterator('train')
        #     continue
        dp_model.train()
        critic_model.eval()

        torch.cuda.synchronize()
        start = time.time()
        gen_result = None
        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:])
        else:
            if opt.rl_type == 'sc':
                gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
                reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda())
                pseudo_num = 0
                pseudo_num_length = 0
            elif opt.rl_type == 'reinforce':
                gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
                reward = get_reward(data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda())
                pseudo_num_length = 0
                pseudo_num = 0

            elif opt.rl_type == 'arsm':
                loss, pseudo_num, pseudo_num_length, pseudo_num_batch, rewards_main, sum_logits = get_arm_loss_daniel(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                #print(loss)
                reward = np.zeros([2,2])
            elif opt.rl_type == 'rf4':
                loss,_,_,_ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                # print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'importance_sampling':
                opt.importance_sampling = 1
                loss, gen_result, reward, sample_logprobs_total = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1], 1)
                std = np.std(reward)
            elif opt.rl_type == 'importance_sampling_critic':
                opt.importance_sampling = 1
                loss, gen_result, reward, sample_logprobs_total = get_rf_loss(target_actor, fc_feats, att_feats, att_masks, data, opt, loader, target_critic)
                reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1], 1)
                std = np.std(reward)
            elif opt.rl_type == 'ar':
                loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = np.zeros([2,2])
            elif opt.rl_type == 'mct':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats,
                                                                                att_masks, data,
                                                                                opt, loader)
                reward = get_reward(data, gen_result, opt)
                pseudo_num = 0
                pseudo_num_length = 0
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                final_reward = torch.cat([mct_baseline[:, 1:], reward_cuda[:, 0:1]], 1)
                final_reward = final_reward - torch.mean(final_reward)
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, final_reward)
            elif opt.rl_type == 'mct_sc':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats,
                                                                                att_masks, data,
                                                                                opt, loader)
                reward = get_reward(data, gen_result, opt)
                pseudo_num = 0
                pseudo_num_length = 0
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                final_reward = torch.cat([mct_baseline[:, 1:], reward_cuda[:, 0:1]], 1)
                gen_result_sc, sample_logprobs_sc = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 1},
                                                       mode='sample')
                reward = get_reward(data, gen_result_sc, opt)
                final_reward = final_reward - torch.from_numpy(reward).float().cuda()
                loss = rl_crit(sample_logprobs, gen_result.data, final_reward)
            elif opt.rl_type == 'mct_critic':
                #TODO change the critic to attention
                if opt.critic_model == 'state_critic':
                    opt.rf_demean = 0
                    gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats,
                                                                                    att_masks, data,
                                                                                    opt, loader)
                    gen_result_pad = torch.cat(
                        [gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result], 1)
                    critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks).squeeze(2)
                    reward, std = get_reward(data, gen_result, opt, critic=True)
                    pseudo_num = 0
                    pseudo_num_length = 0
                    reward_cuda = torch.from_numpy(reward).float().cuda()
                    mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                    final_reward = torch.cat([mct_baseline[:, 1:], reward_cuda[:, 0:1]], 1)
                    print(critic_value.shape)
                    loss = rl_crit(sample_logprobs, gen_result.data, final_reward - critic_value)
                elif opt.critic_model == 'att_critic':
                    opt.rf_demean = 0
                    gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats,
                                                                                    att_masks, data,
                                                                                    opt, loader)
                    gen_result_pad = torch.cat(
                        [gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result], 1)
                    critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks).squeeze(2)
                    reward, std = get_reward(data, gen_result, opt, critic=True)
                    pseudo_num = 0
                    pseudo_num_length = 0
                    reward_cuda = torch.from_numpy(reward).float().cuda()
                    mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                    final_reward = torch.cat([mct_baseline[:, 1:], reward_cuda[:, 0:1]], 1)
                    print(critic_value.shape)
                    loss = rl_crit(sample_logprobs, gen_result.data, final_reward - critic_value)
            elif opt.rl_type =='mct_baseline':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats, att_masks, data,
                                                                         opt, loader)
                reward = get_reward(data, gen_result, opt)
                pseudo_num = 0
                pseudo_num_length = 0
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - mct_baseline)
            elif opt.rl_type == 'arsm_baseline':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0]
                if opt.arm_step_sample == 'greedy' and False:
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda - arm_baseline)
            elif opt.rl_type == 'ars_indicator':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda * arm_baseline)
            elif opt.rl_type == 'arsm_baseline_critic':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model)
                reward, std = get_reward(data, gen_result, opt, critic=True)
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - arm_baseline)
            elif opt.rl_type == 'arsm_critic':
                #print(opt.critic_model)
                tic = time.time()
                loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model)
                #print('arm_loss time', str(time.time()-tic))
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'critic_vocab_sum':
                assert opt.critic_model == 'att_critic_vocab'
                tic = time.time()
                gen_result, sample_logprobs_total = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, total_probs=True,
                                                       mode='sample') #batch, seq, vocab
                #print('generation time', time.time()-tic)
                gen_result_pad = torch.cat(
                    [gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result], 1)
                tic = time.time()
                critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks) #batch, seq, vocab
                #print('critic time', time.time() - tic)
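                # Expected critic value at each step: softmax over the policy's
                # per-token logits, weighted by the critic's per-token values and
                # summed over the vocabulary dimension.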
                probs = torch.sum(F.softmax(sample_logprobs_total, 2) * critic_value.detach(), 2)
                mask = (gen_result > 0).float()
                mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)
                loss = -torch.sum(probs * mask) / torch.sum(mask)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'reinforce_critic':
                #TODO change the critic to attention
                if opt.critic_model == 'state_critic':
                    critic_value, gen_result, sample_logprobs = critic_model(dp_model, fc_feats, att_feats, opt, att_masks)
                    reward, std = get_reward(data, gen_result, opt, critic=True)
                    loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - critic_value[:,:-1].data)
                elif opt.critic_model == 'att_critic':
                    gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0},
                                                           mode='sample')
                    gen_result_pad = torch.cat(
                        [gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result], 1)
                    critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks).squeeze(2)

                    reward, std = get_reward(data, gen_result, opt, critic=True)
                    loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - critic_value.data)
        if opt.mle_weights != 0:
            loss += opt.mle_weights * crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:])
        #TODO make sure all sampling replaced by greedy for critic
        #### update the actor
        loss.backward()
        # with open(os.path.join(opt.checkpoint_path, 'best_embed.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f)
        # with open(os.path.join(opt.checkpoint_path, 'best_logit.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.logit.parameters())[0].data.cpu().numpy(), f)
        ## compute variance
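        # Maintain exponential moving averages of the flattened gradient and of its
        # elementwise square; mean(|E[g^2] - E[g]^2|) serves as a running estimate
        # of the gradient variance that is logged below.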
        gradient = torch.zeros([0]).cuda()
        for i in model.parameters():
            gradient = torch.cat((gradient, i.grad.view(-1)), 0)
        first_order = 0.9999 * first_order + 0.0001 * gradient
        second_order = 0.9999 * second_order + 0.0001 * gradient.pow(2)
        # print(torch.max(torch.abs(gradient)))
        variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item()
        if opt.rl_type != 'arsm' or not sc_flag:
            utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # ### update the critic
        if 'critic' in opt.rl_type:
            dp_model.eval()
            critic_model.train()
            utils.set_lr(critic_optimizer, opt.critic_learning_rate)
            critic_optimizer.zero_grad()
            #assert opt.critic_model == 'att_critic_vocab'
            crit_loss, reward, std = target_critic_loss_fun_mask(fc_feats, att_feats, att_masks, dp_model, critic_model, opt,
                                                            data, target_critic, target_actor, gen_result=gen_result, sample_logprobs_total=sample_logprobs_total, reward=reward)
            crit_loss.backward()
            critic_optimizer.step()
            for cp, tp in zip(critic_model.parameters(), target_critic.parameters()):
                tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data)
            for cp, tp in zip(dp_model.parameters(), target_actor.parameters()):
                tp.data = tp.data + opt.gamma_actor * (cp.data - tp.data)
            crit_train_loss = crit_loss.item()
            error_sum += crit_train_loss ** 0.5 - std
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if (iteration % opt.losses_log_every == 0):
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
                print(opt.checkpoint_path)
            elif 'critic' in opt.rl_type:
                print(
                    "iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f},variance = {:g}, time/batch = {:.3f}" \
                    .format(iteration, crit_train_loss ** 0.5, crit_train_loss ** 0.5 - std, error_sum, variance, end - start))
                print(opt.checkpoint_path)
                critic_model.eval()
                _, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model, test_critic=True)
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start))
                print("pseudo num: ", pseudo_num)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward), iteration)
                add_summary_value(tb_summary_writer, 'variance', variance, iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward)
            critic_loss_history[iteration] = crit_train_loss if 'critic' in opt.rl_type else 0
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
            variance_history[iteration] = variance
            pseudo_num_history[iteration] = pseudo_num
            reward_main_history[iteration] = rewards_main
            #print(pseudo_num_length)
            #print(type(pseudo_num_length).__module__)
            if type(pseudo_num_length).__module__ != 'torch':
                print('not right')
                pseudo_num_length_history[iteration] = pseudo_num_length
                pseudo_num_batch_history[iteration] = pseudo_num_batch
                sum_logits_history[iteration] = sum_logits
            else:
                pseudo_num_length_history[iteration] = pseudo_num_length.data.cpu().numpy()
                pseudo_num_batch_history[iteration] = pseudo_num_batch.data.cpu().numpy()
                sum_logits_history[iteration] = sum_logits.data.cpu().numpy()
            time_history[iteration] = end - start


        # Evaluate on the validation set, and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            if lang_stats is not None:
                for k,v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True: # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path, opt.critic_model + '_model.pth')
                torch.save(critic_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['critic_loss_history'] = critic_loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                histories['variance_history'] = variance_history
                histories['pseudo_num_history'] = pseudo_num_history
                histories['pseudo_num_length_history'] = pseudo_num_length_history
                histories['pseudo_num_batch_history'] = pseudo_num_batch_history
                histories['sum_logits_history'] = sum_logits_history
                histories['reward_main_history'] = reward_main_history
                histories['time'] = time_history
                histories['first_order_history'] = first_order.data.cpu().numpy()
                histories['second_order_history'] = second_order.data.cpu().numpy()
                # histories['variance'] = 0
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
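
Example No. 6 dispatches on opt.rl_type; its 'sc' branch is self-critical sequence training, where each sampled caption's reward is its metric score minus the score of the same model's greedy caption, and that advantage is shared by every generated token. Below is a minimal sketch of that baseline idea, assuming per-image scores (e.g. CIDEr) for the sampled and greedy captions are already available; the actual get_self_critical_reward used above is not shown in this snippet.

import numpy as np


def self_critical_advantage(sample_scores, greedy_scores, seq_length):
    # sample_scores / greedy_scores: per-image caption scores, shape (batch,).
    # Returns a (batch, seq_length) array so every token of a caption shares
    # the same sample-minus-greedy advantage.
    advantage = (np.asarray(sample_scores, dtype=np.float32) -
                 np.asarray(greedy_scores, dtype=np.float32))
    return np.repeat(advantage[:, np.newaxis], seq_length, axis=1)
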
Example No. 7
def train(opt):
    # opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_att = True
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    print(opt.checkpoint_path)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    critic_loss_history = histories.get('critic_loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    variance_history = histories.get('variance_history', {})
    time_history = histories.get('time_history', {})
    pseudo_num_history = histories.get('pseudo_num_history', {})
    pseudo_num_depth_history = histories.get('pseudo_num_depth_history', {})
    pseudo_num_length_history = histories.get('pseudo_num_length_history', {})
    pseudo_num_batch_history = histories.get('pseudo_num_batch_history', {})
    reward_batch_history = histories.get('reward_batch_history', {})
    entropy_batch_history = histories.get('entropy_batch_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = model

    ######################### Actor-critic Training #####################################################################

    update_lr_flag = True
    # Ensure the model is in training mode
    dp_model.train()
    #TODO: change this to a flag
    crit = utils.LanguageModelCriterion_binary()
    rl_crit = utils.RewardCriterion_binary()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    first_order = 0
    second_order = 0
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If starting self-critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch('train')
        if data['bounds']['it_pos_now'] > 10000:
            loader.reset_iterator('train')
            continue
        dp_model.train()

        torch.cuda.synchronize()
        start = time.time()
        gen_result = None
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels,
                                 att_masks), labels[:, 1:], masks[:, 1:],
                        dp_model.depth, dp_model.vocab2code, dp_model.phi_list,
                        dp_model.cluster_size)
        else:
            if opt.rl_type == 'sc':
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                reward = get_self_critical_reward(dp_model, fc_feats,
                                                  att_feats, att_masks, data,
                                                  gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda(),
                               dp_model.depth)
            elif opt.rl_type == 'reinforce':
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                reward = get_reward(data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda(),
                               dp_model.depth)
            elif opt.rl_type == 'arm':
                loss, pseudo_num, pseudo_num_depth, pseudo_num_length, pseudo_num_batch, reward_batch, entropy_batch = dp_model.get_arm_loss_binary_fast(
                    fc_feats, att_feats, att_masks, opt, data, loader)
                #print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'rf4':
                loss, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats,
                                            att_masks, data, opt, loader)
                # print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'ar':
                loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks,
                                   data, opt, loader)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'mct_baseline':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(
                    sample_logprobs, gen_result.data,
                    torch.from_numpy(reward).float().cuda() - mct_baseline)
            elif opt.rl_type == 'arsm_baseline':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0]
                if opt.arm_step_sample == 'greedy' and False:
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data,
                               reward_cuda - arm_baseline)
            elif opt.rl_type == 'ars_indicator':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = get_self_critical_reward(dp_model, fc_feats,
                                                  att_feats, att_masks, data,
                                                  gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                loss = rl_crit(sample_logprobs, gen_result.data,
                               reward_cuda * arm_baseline)
        if opt.mle_weights != 0:
            loss += opt.mle_weights * crit(
                dp_model(fc_feats, att_feats, labels, att_masks),
                labels[:, 1:], masks[:, 1:])
        #TODO make sure all sampling replaced by greedy for critic
        #### update the actor
        loss.backward()
        # with open(os.path.join(opt.checkpoint_path, 'embeddings.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f)
        ## compute variance
        gradient = torch.zeros([0]).cuda()
        for i in model.parameters():
            gradient = torch.cat((gradient, i.grad.view(-1)), 0)
        first_order = 0.999 * first_order + 0.001 * gradient
        second_order = 0.999 * second_order + 0.001 * gradient.pow(2)
        # print(torch.max(torch.abs(gradient)))
        variance = torch.mean(torch.abs(second_order -
                                        first_order.pow(2))).item()
        if opt.rl_type != 'arsm' or not sc_flag:
            utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # ### update the critic

        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if (iteration % opt.losses_log_every == 0):
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
                print(opt.checkpoint_path)
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}, pseudo num = {:.3f}, " \
                      .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start, pseudo_num))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward), iteration)
                add_summary_value(tb_summary_writer, 'variance', variance,
                                  iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward)
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
            variance_history[iteration] = variance
            time_history[iteration] = end - start
            pseudo_num_history[iteration] = pseudo_num.item()
            pseudo_num_length_history[iteration] = pseudo_num_length.data.cpu(
            ).numpy()
            pseudo_num_depth_history[iteration] = pseudo_num_depth.data.cpu(
            ).numpy()
            pseudo_num_batch_history[iteration] = pseudo_num_batch.data.cpu(
            ).numpy()
            reward_batch_history[iteration] = reward_batch
            entropy_batch_history[iteration] = entropy_batch.data.cpu().numpy()

        # Evaluate on the validation set, and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils_binary.eval_split(
                dp_model, crit, loader, eval_kwargs)
            print('1')
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss
            print('2')
            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print('3')
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               opt.critic_model + '_model.pth')
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)
                print('4')
                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['critic_loss_history'] = critic_loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                histories['variance_history'] = variance_history
                histories['time'] = time_history
                histories['pseudo_num_history'] = pseudo_num_history
                histories[
                    'pseudo_num_length_history'] = pseudo_num_length_history
                histories[
                    'pseudo_num_depth_history'] = pseudo_num_depth_history
                histories[
                    'pseudo_num_batch_history'] = pseudo_num_batch_history
                histories['reward_batch_history'] = reward_batch_history
                histories['entropy_batch_history'] = entropy_batch_history
                # histories['variance'] = 0
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
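The gradient bookkeeping at the top of this example keeps exponential moving averages of the gradient and of its square, and reports the gap between E[g^2] and E[g]^2 as a variance estimate. Below is a minimal, self-contained sketch of that computation; the tensor size and the random gradient are stand-ins, and the first_order update is assumed to mirror the second_order one shown above.

import torch

# Running estimates of E[g] and E[g^2] over a flattened gradient vector.
first_order = torch.zeros(10)
second_order = torch.zeros(10)

for step in range(200):
    gradient = torch.randn(10)  # stand-in for the model's flattened gradient
    # Exponential moving averages with the 0.999 / 0.001 coefficients from the snippet.
    first_order = 0.999 * first_order + 0.001 * gradient
    second_order = 0.999 * second_order + 0.001 * gradient.pow(2)
    # Var[g] ~ E[g^2] - E[g]^2, averaged over parameters.
    variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item()

print("estimated gradient variance: {:.6f}".format(variance))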
Ejemplo n.º 8
0
def train(opt):
    # Load data
    print('Loading dataset...')
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    loader.input_encoding_size = opt.input_encoding_size

    BERT_features = None
    if opt.cached_bert_features == "":
        # Extract BERT features
        print('Extracting pretrained BERT features...')
        BERT_features = process_bert.extract_BERT_features(loader, opt)
        with open(opt.data_path + 'BERT_features.pkl', 'wb') as f:
            pickle.dump(BERT_features, f)
    else:
        # Load BERT tokenization results
        print('Loading pretrained BERT features...')
        with open(opt.data_path + 'BERT_features.pkl', 'rb') as f:
            BERT_features = pickle.load(f)

    bert_vocab_path = opt.data_path + 'bert-base-cased-vocab.txt'
    opt.vocab_size = loader.update_bert_tokens(bert_vocab_path, BERT_features)
    print('Vocabulary size: ' + str(opt.vocab_size))

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create model
    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)
    dp_model.train()

    # Loss function
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    # Optimizer and learning rate adjustment flag
    optimizer = utils.build_optimizer(model.parameters(), opt)
    update_lr_flag = True

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # Training loop
    while True:
        # Update learning rate once per epoch
        if update_lr_flag:

            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [
            data['bert_feats'], data['fc_feats'], data['att_feats'],
            data['labels'], data['masks'], data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        bert_feats, fc_feats, att_feats, labels, masks, att_masks = tmp
        bert_feats.requires_grad = False

        # Forward pass and loss
        optimizer.zero_grad()
        outputs = dp_model(bert_feats, fc_feats, att_feats, labels, att_masks)
        loss = crit(outputs, labels[:, 1:], masks[:, 1:])

        # Backward pass
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Print
        total_time = time.time() - start
        if iteration % opt.print_freq == 1:
            print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Validate and save model
        if (iteration % opt.save_checkpoint_every == 0):

            # Evaluate model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format(
                    iteration, best_val_score)
                infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname)
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, infos_fname),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
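The learning-rate logic in this example is a plain step decay: once epoch passes learning_rate_decay_start, the base rate is multiplied by learning_rate_decay_rate once per learning_rate_decay_every epochs. A standalone sketch of the same schedule; the option values here are made up for illustration.

def decayed_lr(epoch, base_lr=5e-4, decay_start=0, decay_every=3, decay_rate=0.8):
    """Step decay used in these loops: lr = base_lr * rate ** ((epoch - start) // every)."""
    if decay_start >= 0 and epoch > decay_start:
        frac = (epoch - decay_start) // decay_every
        return base_lr * decay_rate ** frac
    return base_lr

# With the defaults: epoch 0 -> 5.0e-4, epoch 4 -> 4.0e-4, epoch 7 -> 3.2e-4.
for epoch in (0, 4, 7):
    print(epoch, decayed_lr(epoch))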
Ejemplo n.º 9
0
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_fc = utils.if_use_fc(opt.caption_model)

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    infos = load_info(opt)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Define and load model, optimizer, critics
    decoder = setup(opt).train().cuda()
    crit = utils.LanguageModelCriterion().cuda()
    rl_crit = utils.RewardCriterion().cuda()
    optimizer = utils.build_optimizer(decoder.parameters(), opt)
    models = {'decoder': decoder}
    optimizers = {'decoder': optimizer}
    save_nets_structure(models, opt)
    load_checkpoint(models, optimizers, opt)

    epoch_done = True
    sc_flag = False
    while True:
        if epoch_done:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                decoder.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False
            epoch_done = False

        # 1. fetch a batch of data from train split
        data = loader.get_batch('train')
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        sg_data = {
            key: data['sg_data'][key] if data['sg_data'][key] is None else
            torch.from_numpy(data['sg_data'][key]).cuda()
            for key in data['sg_data']
        }

        # 2. Forward model and compute loss
        torch.cuda.synchronize()
        optimizer.zero_grad()
        if not sc_flag:
            out = decoder(sg_data, fc_feats, att_feats, labels, att_masks)
            loss = crit(out, labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs, core_args = decoder(
                sg_data,
                fc_feats,
                att_feats,
                att_masks,
                opt={
                    'sample_max': 0,
                    'return_core_args': True
                },
                mode='sample')
            reward = get_self_critical_reward(decoder, core_args, sg_data,
                                              fc_feats, att_feats, att_masks,
                                              data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        # 3. Update model
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Update the iteration and epoch
        iteration += 1
        # Write the training loss summary
        if (iteration % opt.log_loss_every == 0):
            # logging log
            logger.info("{} ({}), loss: {:.3f}".format(iteration, epoch,
                                                       train_loss))
            tb.add_values('loss', {'train': train_loss}, iteration)

        if data['bounds']['wrapped']:
            epoch += 1
            epoch_done = True

        # Make evaluation and save checkpoint
        if (opt.save_checkpoint_every > 0
                and iteration % opt.save_checkpoint_every
                == 0) or (opt.save_checkpoint_every == -1 and epoch_done):
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'expand_features': False
            }
            eval_kwargs.update(vars(opt))
            predictions, lang_stats = eval_utils.eval_split(
                decoder, loader, eval_kwargs)
            # log val results
            if lang_stats is not None:
                logger.info("Scores: {}".format(lang_stats))
                tb.add_values('scores', lang_stats, epoch)
            val_result_history[epoch] = {
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            current_score = 0 if lang_stats is None else lang_stats['CIDEr']
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            infos['val_result_history'] = val_result_history

            save_checkpoint(models, optimizers, infos, best_flag, opt)

        # Stop if reaching max epochs
        if epoch > opt.max_epochs and opt.max_epochs != -1:
            break
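utils.RewardCriterion is referenced but not defined in these examples. The sketch below is an assumed REINFORCE-style implementation in the spirit of self-critical sequence training: each sampled token's log-probability is weighted by its (baseline-subtracted) reward, and padding positions are masked out. Treat it as an illustration rather than the repository's exact code.

import torch


def reward_criterion(sample_logprobs, gen_result, reward):
    """Assumed REINFORCE-style loss: -logprob * reward, averaged over non-padding tokens."""
    sample_logprobs = sample_logprobs.view(-1)
    reward = reward.view(-1)
    # Keep every token up to and including the first end-of-sequence (index 0) position.
    mask = (gen_result > 0).float()
    mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], dim=1).view(-1)
    loss = -sample_logprobs * reward * mask
    return loss.sum() / mask.sum()


# Toy usage: batch of 2 sampled sequences of length 5, with per-token copies of the
# self-critical reward (e.g. CIDEr(sample) - CIDEr(greedy)).
logprobs = torch.randn(2, 5)
seq = torch.randint(0, 5, (2, 5))
rewards = torch.rand(2, 5)
print(reward_criterion(logprobs, seq, rewards))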
Ejemplo n.º 10
0
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    training_mode = 0
    optimizer_reset = 0
    change_mode1 = 0
    change_mode2 = 0

    use_rela = getattr(opt, 'use_rela', 0)
    if use_rela:
        opt.rela_dict_size = loader.rela_dict_size
    # Need another parameter to control how to train the model

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(
                os.path.join(
                    opt.checkpoint_path, 'infos_' + opt.id +
                    format(int(opt.start_from), '04') + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(
                    opt.checkpoint_path, 'histories_' + opt.id +
                    format(int(opt.start_from), '04') + '.pkl')):
            with open(
                    os.path.join(
                        opt.checkpoint_path, 'histories_' + opt.id +
                        format(int(opt.start_from), '04') + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    if epoch >= opt.step2_train_after and epoch < opt.step3_train_after:
        training_mode = 1
    elif epoch >= opt.step3_train_after:
        training_mode = 2
    else:
        training_mode = 0

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    #dp_model = torch.nn.DataParallel(model)
    #dp_model = torch.nn.DataParallel(model, [0, 1])
    dp_model = model
    for name, param in model.named_parameters():
        print(name)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    optimizer_mem = optim.Adam([model.memory_cell],
                               opt.learning_rate,
                               (opt.optim_alpha, opt.optim_beta),
                               opt.optim_epsilon,
                               weight_decay=opt.weight_decay)

    # Load the optimizer

    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(
                opt.checkpoint_path, 'optimizer' + opt.id +
                format(int(opt.start_from), '04') + '.pth')):
        optimizer.load_state_dict(
            torch.load(
                os.path.join(
                    opt.checkpoint_path, 'optimizer' + opt.id +
                    format(int(opt.start_from), '04') + '.pth')))
        if (training_mode == 1 or training_mode == 2) and os.path.isfile(
                os.path.join(
                    opt.checkpoint_path, 'optimizer_mem' + opt.id +
                    format(int(opt.start_from), '04') + '.pth')):
            optimizer_mem.load_state_dict(
                torch.load(
                    os.path.join(
                        opt.checkpoint_path, 'optimizer_mem' + opt.id +
                        format(int(opt.start_from), '04') + '.pth')))

    optimizer.zero_grad()
    optimizer_mem.zero_grad()
    accumulate_iter = 0
    reward = np.zeros([1, 1])
    train_loss = 0

    while True:
        # if optimizer_reset == 1:
        #     print("++++++++++++++++++++++++++++++")
        #     print('reset optimizer')
        #     print("++++++++++++++++++++++++++++++")
        #     optimizer = utils.build_optimizer(model.parameters(), opt)
        #     optimizer_mem = optim.Adam([model.memory_cell], opt.learning_rate, (opt.optim_alpha, opt.optim_beta),
        #                                opt.optim_epsilon,
        #                                weight_decay=opt.weight_decay)
        #     optimizer_reset = 0

        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        if epoch >= opt.step2_train_after and epoch < opt.step3_train_after:
            training_mode = 1
            if change_mode1 == 0:
                change_mode1 = 1
                optimizer_reset = 1
        elif epoch >= opt.step3_train_after:
            training_mode = 2
            if change_mode2 == 0:
                change_mode2 = 1
                optimizer_reset = 1
        else:
            training_mode = 0

        fc_feats = None
        att_feats = None
        att_masks = None
        ssg_data = None
        rela_data = None

        tmp = [data['fc_feats'], data['labels'], data['masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, labels, masks = tmp

        tmp = [
            data['att_feats'], data['att_masks'], data['rela_rela_matrix'],
            data['rela_rela_masks'], data['rela_attr_matrix'],
            data['rela_attr_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]

        att_feats, att_masks, rela_rela_matrix, rela_rela_masks, \
            rela_attr_matrix, rela_attr_masks = tmp

        rela_data = {}
        rela_data['att_feats'] = att_feats
        rela_data['att_masks'] = att_masks
        rela_data['rela_matrix'] = rela_rela_matrix
        rela_data['rela_masks'] = rela_rela_masks
        rela_data['attr_matrix'] = rela_attr_matrix
        rela_data['attr_masks'] = rela_attr_masks

        tmp = [
            data['ssg_rela_matrix'], data['ssg_rela_masks'], data['ssg_obj'],
            data['ssg_obj_masks'], data['ssg_attr'], data['ssg_attr_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks = tmp
        ssg_data = {}
        ssg_data['ssg_rela_matrix'] = ssg_rela_matrix
        ssg_data['ssg_rela_masks'] = ssg_rela_masks
        ssg_data['ssg_obj'] = ssg_obj
        ssg_data['ssg_obj_masks'] = ssg_obj_masks
        ssg_data['ssg_attr'] = ssg_attr
        ssg_data['ssg_attr_masks'] = ssg_attr_masks

        if not sc_flag:
            loss = crit(
                dp_model(fc_feats, att_feats, labels, att_masks, rela_data,
                         ssg_data, use_rela, training_mode), labels[:, 1:],
                masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   rela_data,
                                                   ssg_data,
                                                   use_rela,
                                                   training_mode,
                                                   opt={'sample_max': 0},
                                                   mode='sample')

            rela_data = {}
            rela_data['att_feats'] = att_feats
            rela_data['att_masks'] = att_masks
            rela_data['rela_matrix'] = rela_rela_matrix
            rela_data['rela_masks'] = rela_rela_masks
            rela_data['attr_matrix'] = rela_attr_matrix
            rela_data['attr_masks'] = rela_attr_masks

            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, rela_data, ssg_data,
                                              use_rela, training_mode, data,
                                              gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()

        if accumulate_iter % opt.accumulate_number == 0:
            if training_mode == 0:
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
                optimizer.zero_grad()
            elif training_mode == 1:
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
                optimizer.zero_grad()

                utils.clip_gradient(optimizer_mem, opt.grad_clip)
                optimizer_mem.step()
                optimizer_mem.zero_grad()
            elif training_mode == 2:
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
                optimizer.zero_grad()

                utils.clip_gradient(optimizer_mem, opt.grad_clip)
                optimizer_mem.step()
                optimizer_mem.zero_grad()

            iteration += 1
            accumulate_iter = 0
            train_loss = loss.item() * opt.accumulate_number
            end = time.time()
            text_file = open(opt.id + '.txt', "a")
            if not sc_flag:
                print("iter {} (epoch {}), train_model {}, train_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, training_mode, train_loss, end - start))
                text_file.write("iter {} (epoch {}), train_model {}, train_loss = {:.3f}, time/batch = {:.3f}\n" \
                      .format(iteration, epoch, training_mode, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), end - start))
                text_file.write("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}\n" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), end - start))
            text_file.close()

        torch.cuda.synchronize()

        # Update the iteration and epoch

        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every
                == 0) and (accumulate_iter % opt.accumulate_number == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every
                == 0) and (accumulate_iter % opt.accumulate_number == 0):
            # eval model

            eval_kwargs = {
                'split': 'test',
                'dataset': opt.input_json,
                'use_rela': use_rela,
                'num_images': 1,
            }
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils_mem.eval_split(
                dp_model, crit, loader, training_mode, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                save_id = iteration / opt.save_checkpoint_every
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'model' + opt.id + format(int(save_id), '04') + '.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(
                    opt.checkpoint_path,
                    'optimizer' + opt.id + format(int(save_id), '04') + '.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                if training_mode == 1 or training_mode == 2 or opt.caption_model == 'lstm_mem':
                    optimizer_mem_path = os.path.join(
                        opt.checkpoint_path, 'optimizer_mem' + opt.id +
                        format(int(save_id), '04') + '.pth')
                    torch.save(optimizer_mem.state_dict(), optimizer_mem_path)

                    memory_cell = dp_model.memory_cell.data.cpu().numpy()
                    memory_cell_path = os.path.join(
                        opt.checkpoint_path, 'memory_cell' + opt.id +
                        format(int(save_id), '04') + '.npz')
                    np.savez(memory_cell_path, memory_cell=memory_cell)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(
                            opt.checkpoint_path, 'infos_' + opt.id +
                            format(int(save_id), '04') + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(
                            opt.checkpoint_path, 'histories_' + opt.id +
                            format(int(save_id), '04') + '.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
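This example accumulates gradients over opt.accumulate_number mini-batches: each loss is divided by that count, backward() adds the scaled gradients, and the optimizer only steps (with clipping) every accumulate_number iterations. A compact sketch of the same pattern with a toy model; the value-clipping call stands in for utils.clip_gradient.

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
accumulate_number = 4  # stand-in for opt.accumulate_number

optimizer.zero_grad()
for it in range(1, 13):
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = nn.functional.mse_loss(model(x), y) / accumulate_number  # scale each micro-batch
    loss.backward()                                                 # gradients accumulate across calls
    if it % accumulate_number == 0:
        nn.utils.clip_grad_value_(model.parameters(), clip_value=0.1)  # stand-in for utils.clip_gradient
        optimizer.step()
        optimizer.zero_grad()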
Ejemplo n.º 11
0
def train(opt):

    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}

    if opt.start_from is not None:
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    #ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create model
    model = convcap(opt).cuda()
    #  pretrained_dict = torch.load('log_xe_final_before_review/all2model12000.pth')
    #   model.load_state_dict(pretrained_dict, strict=False)
    back_model = convcap(opt).cuda()
    back_model.train()
    #   d_pretrained_dict = torch.load('log_xe_final_before_review/all2d_model12000.pth')
    #   back_model.load_state_dict(d_pretrained_dict, strict=False)
    dp_model = model
    dp_model.train()
    dis_model = Discriminator(512, 512, 512, 0.2)
    dis_model = dis_model.cuda()
    dis_model.train()
    #    dis_pretrained_dict = torch.load('./log_xe_final_before_review/all2dis_model12000.pth')
    #    dis_model.load_state_dict(dis_pretrained_dict, strict=False)
    d_optimizer = utils.build_optimizer(dis_model.parameters(), opt)
    back_model.train()
    # Loss function
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    # Optimizer and learning rate adjustment flag

    optimizer = utils.build_optimizer_adam(
        chain(model.parameters(), back_model.parameters()), opt)

    #back_optimizer = utils.build_optimizer(back_model.parameters(), opt)
    update_lr_flag = True

    #Load the optimizer

    #   if os.path.isfile(os.path.join('log_xe_final_before_review/',"optimizer.pth")):
    #      optimizer.load_state_dict(torch.load(os.path.join('log_xe_final_before_review/', 'optimizer.pth')))
    #      print ('optimiser loaded')
    #   print (optimizer)
    # Training loop
    while True:

        # Update learning rate once per epoch
        if update_lr_flag:

            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                #opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                #model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['dist'],
            data['masks'], data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, attmasks = tmp
        labels = labels.long()
        labels[:, :, 0] = 8667
        nd_labels = labels
        batchsize = fc_feats.size(0)
        # Forward pass and loss
        optimizer.zero_grad()
        d_steps = 1
        g_steps = 1
        #print (torch.sum(labels!=0), torch.sum(masks!=0))
        if 1:
            if iteration >= 0:

                if 1:
                    dp_model.eval()
                    back_model.eval()
                    with torch.no_grad():
                        _, x_all_d = dp_model(fc_feats, att_feats,
                                              nd_labels.long(), 30, 6)

                        labels_nd = nd_labels.view(batchsize, -1)
                        idx = [
                            i for i in range(labels_nd.size()[1] - 1, -1, -1)
                        ]
                        labels_flip_nd = labels_nd[:, idx]
                        labels_flip_nd = labels_flip_nd.view(batchsize, 6, 30)
                        labels_flip_nd[:, :, 0] = 8667
                        _, x_all_flip_d = back_model(fc_feats, att_feats,
                                                     labels_flip_nd, 30, 6)

                        x_all_d = x_all_d[:, :, :-1]
                        x_all_flip_d = x_all_flip_d[:, :, :-1]

                        idx = [
                            i
                            for i in range(x_all_flip_d.size()[2] - 1, -1, -1)
                        ]
                        idx = torch.LongTensor(idx[1:])
                        idx = Variable(idx).cuda()
                        invert_backstates = x_all_flip_d.index_select(2, idx)

                        x_all_d = x_all_d.detach()
                        invert_backstates = invert_backstates.detach()
                    x_all_d = x_all_d[:, :, :-1]

                    autoregressive_scores = dis_model(
                        x_all_d.transpose(2, 1).cuda())
                    teacher_forcing_scores = dis_model(
                        invert_backstates.transpose(2, 1).cuda())

                    tf_loss, ar_loss = _calcualte_discriminator_loss(
                        teacher_forcing_scores, autoregressive_scores)

                    tf_loss.backward(retain_graph=True)
                    ar_loss.backward()

                    d_optimizer.step()
                    for p in dis_model.parameters():
                        p.data.clamp_(-0.01, 0.01)

                    torch.cuda.synchronize()
                    total_time = time.time() - start

                if 1:
                    dp_model.train()
                    back_model.train()
                    wordact, x_all = dp_model(fc_feats, att_feats, labels, 30,
                                              6)
                    mask = masks.view(batchsize, -1)
                    mask = mask[:, 1:].contiguous()
                    wordact = wordact[:, :, :-1]
                    wordact_t = wordact.permute(0, 2, 1).contiguous()
                    wordact_t = wordact_t.view(
                        wordact_t.size(0) * wordact_t.size(1), -1)
                    labels_flat = labels.view(batchsize, -1)
                    wordclass_v = labels_flat[:, 1:]
                    wordclass_t = wordclass_v.contiguous().view(\
                     wordclass_v.size(0) * wordclass_v.size(1), 1)
                    maskids = torch.nonzero(
                        mask.view(-1).cpu()).numpy().reshape(-1)
                    loss_xe = F.cross_entropy(wordact_t[maskids, ...], \
                     wordclass_t[maskids, ...].contiguous().view(maskids.shape[0])).cuda()

                    idx = [i for i in range(labels_flat.size()[1] - 1, -1, -1)]
                    labels_flip = labels_flat[:, idx]
                    labels_flip = labels_flip.view(batchsize, 6, 30)
                    labels_flip[:, :, 0] = 8667
                    wordact, x_all_flip = back_model(fc_feats, att_feats,
                                                     labels_flip, 30, 6)
                    mask = masks.view(batchsize, -1).flip((1, ))
                    reverse_mask = mask[:, 1:].contiguous()
                    wordact = wordact[:, :, :-1]
                    wordact_t = wordact.permute(0, 2, 1).contiguous()
                    wordact_t = wordact_t.view(
                        wordact_t.size(0) * wordact_t.size(1), -1)
                    labels_flip = labels_flip.contiguous().view(-1, 6 * 30)
                    wordclass_v = labels_flip[:, 1:]
                    wordclass_t = wordclass_v.contiguous().view(\
                     wordclass_v.size(0) * wordclass_v.size(1), 1)
                    maskids = torch.nonzero(
                        reverse_mask.view(-1).cpu()).numpy().reshape(-1)

                    loss_xe_flip = F.cross_entropy(wordact_t[maskids, ...], \
                     wordclass_t[maskids, ...].contiguous().view(maskids.shape[0])).cuda()

                    train_loss = loss_xe

                    x_all_flip = x_all_flip[:, :, :-1].cuda()
                    x_all = x_all[:, :, :-1].cuda()

                    idx = [i for i in range(x_all_flip.size()[2] - 1, -1, -1)]
                    idx = torch.LongTensor(idx[1:])
                    idx = Variable(idx).cuda()

                    invert_backstates = x_all_flip.index_select(2, idx)
                    invert_backstates = invert_backstates.detach()
                    l2_loss = ((x_all[:, :, :-1] -
                                invert_backstates)**2).mean()

                    autoregressive_scores = dis_model(
                        x_all.transpose(2, 1).cuda())

                    ad_loss = _calculate_generator_loss(
                        autoregressive_scores).sum()

                    all_loss = loss_xe + loss_xe_flip + l2_loss
                    ad_loss.backward(retain_graph=True)
                    all_loss.backward()
                    #            utils.clip_gradient(optimizer, opt.grad_clip)
                    optimizer.step()

            if 1:
                if iteration % opt.print_freq == 1:
                    print('Read data:', time.time() - start)
                    if not sc_flag:
                        print("iter {} (epoch {}), train_loss = {:.3f},l2_loss= {:.3f}, flip_loss = {:.3f}, loss_ad = {:.3f}, fake = {:.3f}, real = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                          .format(iteration, epoch, loss_xe, l2_loss, loss_xe_flip, ad_loss, ar_loss, tf_loss, data_time, total_time))
                    else:
                        print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                            .format(iteration, epoch, np.mean(reward[:,0]), data_time, total_time))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                update_lr_flag = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tb_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                #add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward',
                                      np.mean(reward[:, 0]), iteration)
                loss_history[
                    iteration] = train_loss if not sc_flag else np.mean(
                        reward[:, 0])
                lr_history[iteration] = opt.current_lr
                #ss_prob_history[iteration] = model.ss_prob

            # Validate and save model
            if (iteration % opt.save_checkpoint_every == 0):
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'all2model{:05d}.pth'.format(iteration))
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'all2d_model{:05d}.pth'.format(iteration))
                torch.save(back_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'all2dis_model{:05d}.pth'.format(iteration))
                torch.save(dis_model.state_dict(), checkpoint_path)
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)
                # Evaluate model
        if (iteration % 1000 == 0):
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path, 'd_model.pth')
            torch.save(back_model.state_dict(), checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'dis_model.pth')
            torch.save(dis_model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            #histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format(
                    iteration, best_val_score)
                infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname)
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'd_model-best.pth')
                torch.save(back_model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'dis_model-best.pth')
                torch.save(dis_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, infos_fname),
                          'wb') as f:
                    cPickle.dump(infos, f)
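A detail worth isolating from this example is the index_select trick: the backward decoder's hidden states are reversed along the time axis (dropping one position) so they line up with the forward states before the L2 consistency term. A tiny sketch of that realignment on dummy tensors:

import torch

# (batch, channels, time) hidden states from the "backward" decoder.
x_all_flip = torch.arange(2 * 3 * 5, dtype=torch.float32).view(2, 3, 5)

# Reverse along time and drop one position, mirroring
# idx = range(T - 1, -1, -1); idx = idx[1:] in the example above.
idx = torch.arange(x_all_flip.size(2) - 2, -1, -1)
invert_backstates = x_all_flip.index_select(2, idx)

# The forward states (minus their last step) then line up for the L2 consistency term.
x_all = torch.randn_like(x_all_flip)
l2_loss = ((x_all[:, :, :-1] - invert_backstates) ** 2).mean()
print(invert_backstates.shape, l2_loss.item())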
Ejemplo n.º 12
0
def main(configs, args):
    global net, dataloader, optimizer, lr_scheduler, writer, epochs, logger
    best_acc = 0

    torch.manual_seed(6666)
    configs = init_configs(configs)
    net = build_model(configs)
    net = init_model(net, configs)
    net = net.cuda().train()
    print(net)

    if args.debug:
        configs.log_dir = os.path.join('debug', configs.log_dir)
        configs.ckpt.save_config_path = os.path.join(
            'debug', configs.ckpt.save_config_path)
        configs.ckpt.save_model_path = os.path.join(
            'debug', configs.ckpt.save_model_path)
        configs.ckpt.save_optim_path = os.path.join(
            'debug', configs.ckpt.save_optim_path)

    check_dir(configs.log_dir)
    if not configs.do_test:
        config_path = configs.ckpt.save_config_path
        torch.save({'configs': configs},
                   os.path.join(config_path, 'configs.pth'))

    logger = create_logger(configs.log_dir, configs.cfg_name)
    writer = SummaryWriter(configs.log_dir)

    for name, param in net.named_parameters():
        print('%s required grad is %s' % (name, param.requires_grad))

    dataloader = build_dataset(configs)
    optimizer = build_optimizer(net.parameters(), configs.optimizer)
    optimizer = init_optim(optimizer, configs)
    lr_scheduler = get_lr_scheduler(configs.training)

    max_iterations = configs.training.max_episodes
    test_every_iterations = configs.testing.test_every_episodes
    for iteration in range(1, max_iterations + 1):
        try:
            if iteration % test_every_iterations == 0 or configs.do_test or (
                    args.debug and args.debug_test):
                epochs += 1
                acc = test('test', configs)
                optim_path = configs.ckpt.save_optim_path
                model_path = configs.ckpt.save_model_path
                z, ind_z, den_z, images, labels = extract_features(
                    'test', configs)
                if not configs.do_test:
                    torch.save({'model': net.state_dict()},
                               os.path.join(model_path,
                                            'model_%d.pth' % iteration))
                    torch.save({'optim': optimizer.state_dict()},
                               os.path.join(optim_path,
                                            'optim_%d.pth' % iteration))
                    torch.save(
                        {
                            'z': z.numpy(),
                            'ind_z': ind_z.numpy(),
                            'den_z': den_z.numpy(),
                            'labels': labels,
                            'images': images
                        },
                        os.path.join(model_path, 'results_%d.pth' % iteration))
                    if acc > best_acc:
                        best_acc = acc
                        torch.save({'model': net.state_dict()},
                                   os.path.join(model_path, 'model_best.pth'))
                        torch.save({'optim': optimizer.state_dict()},
                                   os.path.join(optim_path, 'optim_best.pth'))
                if configs.do_test or (args.debug and args.debug_test):
                    return
            train(iteration, configs)
        except KeyboardInterrupt:
            import ipdb
            ipdb.set_trace()
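
# --- Illustrative sketch (added for clarity; not part of the example above) ---
# A minimal, hedged version of the periodic evaluate-and-checkpoint pattern used in
# main(): every `test_every` iterations the network is evaluated, the current weights
# and optimizer state are saved, and a separate "best" checkpoint is kept whenever
# validation accuracy improves. `run_eval`, `net`, `optimizer`, and `save_dir` are
# placeholder names, not the example's real API.
import os
import torch

def checkpoint_loop(net, optimizer, run_eval, save_dir, max_iters, test_every):
    best_acc = 0.0
    for it in range(1, max_iters + 1):
        if it % test_every == 0:
            acc = run_eval(net)  # assumed to return a scalar validation accuracy
            torch.save({'model': net.state_dict()}, os.path.join(save_dir, 'model_%d.pth' % it))
            torch.save({'optim': optimizer.state_dict()}, os.path.join(save_dir, 'optim_%d.pth' % it))
            if acc > best_acc:  # keep a separate copy of the best-performing weights
                best_acc = acc
                torch.save({'model': net.state_dict()}, os.path.join(save_dir, 'model_best.pth'))
        # ... one training step per iteration would go here ...
    return best_acc
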
Ejemplo n.º 13
0
def train(opt):
    # opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_att = True
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    opt.vocab_size = 50
    opt.seq_length = 10
    opt.fc_feat_size = 100
    opt.train_true = True
    opt.train_true_step = 100
    np.random.seed(0)
    data_num = 5000
    data_features = np.random.normal(size=[data_num, opt.fc_feat_size])
    test_data_num = 1000
    test_data_features = np.random.normal(
        size=[test_data_num, opt.fc_feat_size])
    print(opt.checkpoint_path)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    critic_loss_history = histories.get('critic_loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    variance_history = histories.get('variance_history', {})
    time_history = histories.get('time_history', {})

    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = model
    #TODO: save true model
    true_model = models.setup(opt).cuda()
    if vars(opt).get('start_from', None) is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            opt.start_from), " %s must be a a path" % opt.start_from
        assert os.path.isfile(
            os.path.join(opt.start_from, "infos_" + opt.id + ".pkl")
        ), "infos.pkl file does not exist in path %s" % opt.start_from
        true_model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'truemodel.pth')))
    true_model.eval()
    ######################### Actor-critic Training #####################################################################

    update_lr_flag = True
    # Ensure the model is in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    tm_optimizer = utils.build_optimizer(true_model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    first_order = 0
    second_order = 0
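    # first_order / second_order hold exponential moving averages of the flattened
    # gradient and of its square; the loop below uses them to estimate gradient variance.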
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If starting self-critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        dp_model.train()

        torch.cuda.synchronize()
        start = time.time()
        gen_result = None
        start_index = (iteration * opt.batch_size) % data_num
        end_index = start_index + opt.batch_size
        fc_feats = torch.from_numpy(
            data_features[start_index:end_index, :]).cuda().float()
        att_feats = None
        att_masks = None
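        # Sample teacher labels greedily from the fixed true model, prepend a zero
        # (BOS) column, and build a mask that keeps only non-padding tokens.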
        labels, total_logits = true_model(fc_feats,
                                          att_feats,
                                          att_masks,
                                          opt={'sample_max': 1},
                                          total_probs=True,
                                          mode='sample')
        labels = torch.cat(
            [torch.zeros(labels.size(0), 1).cuda().long(), labels], 1)
        masks = (labels > 0).float()

        # train true model:
        if iteration < opt.train_true_step and opt.train_true:
            tm_optimizer.zero_grad()
            loss = -((total_logits * F.softmax(total_logits, 2)).sum(2)).mean()
            loss.backward()
            tm_optimizer.step()

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
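            # RL fine-tuning: pick a policy-gradient estimator via opt.rl_type
            # ('sc' uses the greedy decode as a self-critical baseline; 'reinforce',
            # 'arsm', 'ars', 'ar', and 'mct_baseline' are alternative estimators).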
            if opt.rl_type == 'sc':
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                gen_result_sc, _ = dp_model(fc_feats,
                                            att_feats,
                                            att_masks,
                                            opt={'sample_max': 1},
                                            mode='sample')
                reward = reward_fun(gen_result, fc_feats,
                                    true_model).unsqueeze(1).repeat(
                                        1, sample_logprobs.size(1))
                reward_sc = reward_fun(gen_result_sc, fc_feats,
                                       true_model).unsqueeze(1).repeat(
                                           1, sample_logprobs.size(1))
                reward = reward - reward_sc
                loss = rl_crit(sample_logprobs, gen_result.data, reward)
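                # reward is replaced with a dummy array below; only the loss is used
                # from here on, so the avg_reward log prints ~0 for this estimator.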
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'reinforce':
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                reward = reward_fun(gen_result, fc_feats,
                                    true_model).unsqueeze(1).repeat(
                                        1, sample_logprobs.size(1))
                loss = rl_crit(sample_logprobs, gen_result.data, reward)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'reinforce_demean':
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                reward = reward_fun(gen_result, fc_feats,
                                    true_model).unsqueeze(1).repeat(
                                        1, sample_logprobs.size(1))
                loss = rl_crit(sample_logprobs, gen_result.data,
                               reward - reward.mean())
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'arsm':
                loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks,
                                    true_model, opt)
                #print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'ars':
                loss = get_arm_loss(dp_model,
                                    fc_feats,
                                    att_feats,
                                    att_masks,
                                    true_model,
                                    opt,
                                    type='ars')
                #print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'ar':
                loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks,
                                   true_model, opt)
                # print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'mct_baseline':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(
                    dp_model, fc_feats, att_feats, att_masks, opt, true_model)
                reward = reward_fun(gen_result, fc_feats,
                                    true_model).unsqueeze(1).repeat(
                                        1, sample_logprobs.size(1))
                reward_cuda = reward
                #mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                loss = rl_crit(sample_logprobs, gen_result.data,
                               reward - mct_baseline)
        if opt.mle_weights != 0:
            loss += opt.mle_weights * crit(
                dp_model(fc_feats, att_feats, labels, att_masks),
                labels[:, 1:], masks[:, 1:])
        #TODO make sure all sampling replaced by greedy for critic
        #### update the actor
        loss.backward()
        # with open(os.path.join(opt.checkpoint_path, 'best_embed.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f)
        # with open(os.path.join(opt.checkpoint_path, 'best_logit.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.logit.parameters())[0].data.cpu().numpy(), f)
        ## compute variance
        gradient = torch.zeros([0]).cuda()
        for i in model.parameters():
            gradient = torch.cat((gradient, i.grad.view(-1)), 0)
        first_order = 0.9999 * first_order + 0.0001 * gradient
        second_order = 0.9999 * second_order + 0.0001 * gradient.pow(2)
        # print(torch.max(torch.abs(gradient)))
        variance = torch.mean(torch.abs(second_order -
                                        first_order.pow(2))).item()
        if opt.rl_type != 'arsm' or not sc_flag:
            utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if (iteration % opt.losses_log_every == 0):
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
                print(opt.checkpoint_path)
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}" \
                      .format(iteration, epoch, reward.mean(), variance, end - start))

        # Update the iteration and epoch
        iteration += 1
        if (iteration * opt.batch_size) % data_num == 0:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  reward.mean(), iteration)
                add_summary_value(tb_summary_writer, 'variance', variance,
                                  iteration)

            #loss_history[iteration] = train_loss if not sc_flag else reward.mean()
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
            variance_history[iteration] = variance
            time_history[iteration] = end - start

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model

            val_loss, lang_stats = eval_utils_syn(dp_model, true_model,
                                                  test_data_features,
                                                  opt.batch_size, crit)

            lang_stats = lang_stats.item()
            val_loss = val_loss.item()
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats
            }
            # Save model if is improving on validation result
            print('loss', val_loss, 'lang_stats', lang_stats)
            if True:  # always save a checkpoint at this interval
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'truemodel.pth')
                torch.save(true_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)
                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = opt.vocab_size
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['critic_loss_history'] = critic_loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                histories['variance_history'] = variance_history
                histories['time'] = time_history
                # histories['variance'] = 0
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
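
# --- Illustrative sketch (added for clarity; not part of the example above) ---
# A hedged, self-contained version of the self-critical ('sc') baseline used in the
# RL branch: the greedy decode's sequence reward is subtracted from the sampled
# decode's reward and broadcast over time steps. `sampled_reward` / `greedy_reward`
# are placeholder names for the per-sequence rewards.
import torch

def self_critical_advantage(sampled_reward, greedy_reward, seq_len):
    # sampled_reward, greedy_reward: (batch,) sequence-level rewards
    # returns: (batch, seq_len) advantages, one copy per time step
    return (sampled_reward - greedy_reward).unsqueeze(1).repeat(1, seq_len)

# Toy usage: a sample that beats its greedy baseline gets a positive advantage.
adv = self_critical_advantage(torch.tensor([0.7, 0.2]), torch.tensor([0.5, 0.4]), seq_len=3)
# adv is approximately [[0.2, 0.2, 0.2], [-0.2, -0.2, -0.2]]
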
Ejemplo n.º 14
0
def train(opt, num_switching=None):
    global internal
    if opt.gpu2 is None:
        torch.cuda.set_device(opt.gpu)
    RL_count = 0
    pure_reward = None

    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    # set dataloader
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.baseline_concat = 0

    # set up where results and logs are recorded
    result_path = '/mnt/workspace2019/nakamura/selfsequential/log_python3/' + opt.checkpoint_path
    tb_summary_writer = tb and tb.SummaryWriter(result_path)

    infos = {}
    histories = {}


    # --- pretrained model loading --- #
    if opt.start_from is not None:
        opt.start_from = '/mnt/workspace2019/nakamura/selfsequential/log_python3/' + opt.start_from
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        infos = cPickle.load(open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), mode='rb'))
        saved_model_opt = infos['opt']
        # need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
        need_be_same = ["rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[
                checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            histories = cPickle.load(open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl') , mode='rb'))
    if opt.sf_epoch is not None and opt.sf_itr is not None:
        iteration = opt.sf_itr
        epoch = opt.sf_epoch
    else:
        iteration = infos.get('iter', 0)
        epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    #---------------------------------------#

    # I forget about these parameters; they may not be used.
    b_regressor = None
    opt.regressor = b_regressor

    # model setting
    if opt.gpu2 is not None:
        model = models.setup(opt).cuda()
        dp_model = torch.nn.DataParallel(model)
    else:
        model = models.setup(opt).cuda()
        dp_model = model

    update_lr_flag = True
    # Ensure the model is in training mode
    dp_model.train()

    # set RL mode, internal critic, and similarity model
    info_json = json.load(open(opt.input_json))
    sim_model = None
    new_internal = None
    if opt.internal_model in ('sim', 'sim_newr', 'sim_dammy'):

        # setting internal critic and similarity prediction network
        sim_model = sim.Sim_model(opt.input_encoding_size, opt.rnn_size, vocab_size=len(info_json['ix_to_word']))

        if opt.region_bleu_flg == 0:
            if opt.sim_pred_type == 0:
                # model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim2/model_13_1700.pt'
                model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_bu/model_6_0.pt'
            elif opt.sim_pred_type == 1:
                model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_noshuffle04model_71_1300.pt'
            elif opt.sim_pred_type == 2:
                model_root = '/mnt/workspace2019/nakamura/selfsequential/sim_model/subset_similarity/model_0_3000.pt'
            else:
                print('select 0, 1 or 2')
                exit()
            checkpoint = torch.load(model_root, map_location='cuda:0')
            sim_model.load_state_dict(checkpoint['model_state_dict'])
            sim_model.cuda()
            sim_model.eval()
            for param in sim_model.parameters():
                param.requires_grad = False
            sim_model_optimizer = None
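            # In this branch the similarity network stays frozen and is only used as a
            # fixed scorer for generated captions.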
        elif opt.region_bleu_flg == 1:
            sim_model.cuda()
            if opt.sf_internal_epoch is not None:
                sim_model.load_state_dict(
                    torch.load(os.path.join(opt.start_from, 'sim_model_' + str(opt.sf_internal_epoch) + '_' + str(
                        opt.sf_internal_itr) + '.pth')))
                # sim_model_optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'internal_optimizer_' + str(
                #     opt.sf_internal_epoch) + '_' + str(opt.sf_internal_itr) + '.pth')))
            sim_model_optimizer = utils.build_internal_optimizer(sim_model.parameters(), opt)
        else:
            print('not implemented')
            exit()


        if opt.only_critic_train == 1:
            random.seed(100)
        if opt.critic_encode==1:
            internal = models.CriticModel_with_encoder(opt)
        elif opt.bag_flg == 1:
            internal = models.CriticModel_bag(opt)
        elif opt.ppo == 1:
            # internal = models.CriticModel_sim(opt)
            internal = models.CriticModel_nodropout(opt)
            new_internal = models.CriticModel_nodropout(opt)
            internal.load_state_dict(new_internal.state_dict())
        elif opt.input_h_flg == 1:
            internal = models.CriticModel_sim(opt)
        else:
            internal = models.CriticModel_sim_h(opt)

        internal = internal.cuda()
        if new_internal is not None:
            new_internal = new_internal.cuda()

        if opt.ppo == 1:
            internal_optimizer = utils.build_internal_optimizer(new_internal.parameters(), opt)
        else:
            internal_optimizer = utils.build_internal_optimizer(internal.parameters(), opt)

        if opt.sf_internal_epoch is not None:
            internal.load_state_dict(torch.load(os.path.join(opt.start_from,'internal_' + str(opt.sf_internal_epoch) + '_' + str(
                                                                 opt.sf_internal_itr) + '.pth')))
            internal_optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'internal_optimizer_' + str(
                opt.sf_internal_epoch) + '_' + str(opt.sf_internal_itr) + '.pth')))
            # new_internal = models.CriticModel_nodropout(opt)
            new_internal.load_state_dict(torch.load(os.path.join(opt.start_from,'internal_' + str(opt.sf_internal_epoch) + '_' + str(
                                                                 opt.sf_internal_itr) + '.pth')))
        if opt.multi_learn_flg != 1:
            if opt.internal_rl_flg == 1:
                internal_rl_flg = True
                dp_model.eval()
            else:
                internal.eval()
                internal_rl_flg = False
        else:
            internal_rl_flg = True
    else:
        if opt.sim_reward_flg > 0:
            # setting internal critic and similarity prediction network
            sim_model = sim.Sim_model(opt.input_encoding_size, opt.rnn_size, vocab_size=len(info_json['ix_to_word']))
            if opt.sim_pred_type == 0:
                # model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim2/model_13_1700.pt'
                # model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_bu/model_6_0.pt'
                model_root = '/mnt/workspace2019/nakamura/selfsequential/sim_model/no_shuffle_simforcoco/model_37_34000.pt'
            elif opt.sim_pred_type == 1:
                model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_noshuffle04model_71_1300.pt'
            elif opt.sim_pred_type == 2:
                model_root = '/mnt/workspace2019/nakamura/selfsequential/sim_model/subset_similarity/model_0_3000.pt'
            else:
                print('select 0, 1 or 2')
                exit()

            if opt.region_bleu_flg == 0:
                if opt.sim_pred_type == 0:
                    # model_root = '/mnt/workspace2018/nakamura/vg_feature/model_cossim2/model_13_1700.pt'
                    opt.sim_model_dir = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_bu/model_6_0.pt'
                elif opt.sim_pred_type == 1:
                    opt.sim_model_dir = '/mnt/workspace2018/nakamura/vg_feature/model_cossim_noshuffle04model_71_1300.pt'
                elif opt.sim_pred_type == 2:
                    opt.sim_model_dir = '/mnt/workspace2019/nakamura/selfsequential/sim_model/subset_similarity/model_0_3000.pt'
                else:
                    opt.sim_model_dir = '/mnt/workspace2019/nakamura/selfsequential/log_python3/log_' + opt.id + '/sim_model' + opt.model[-13:-4] + '.pth'

                checkpoint = torch.load(opt.sim_model_dir, map_location='cuda:0')
                sim_model.load_state_dict(checkpoint['model_state_dict'])
                sim_model.cuda()
                sim_model.eval()
                for param in sim_model.parameters():
                    param.requires_grad = False
                sim_model_optimizer = None
            elif opt.region_bleu_flg == 1:
                sim_model_optimizer = utils.build_internal_optimizer(sim_model.parameters(), opt)
                sim_model.cuda()

        internal = None
        internal_optimizer = None
        internal_rl_flg = False
        opt.c_current_lr = 0
    # opt.internal = internal

    # set Discriminator
    if opt.discriminator_weight > 0:
        dis_opt = opt
        if opt.dis_type == 'coco':
            discrimiantor_model_dir = '/mnt/workspace2018/nakamura/selfsequential/discriminator_log/coco/discriminator_150.pth'
            dis_opt.input_label_h5 = '/mnt/poplin/share/dataset/MSCOCO/cocotalk_coco_for_discriminator_label.h5'
            dis_opt.input_json = '/mnt/poplin/share/dataset/MSCOCO/cocotalk_coco_for_discriminator.json'
        elif opt.dis_type == 'iapr':
            discrimiantor_model_dir = '/mnt/workspace2018/nakamura/selfsequential/discriminator_log/iapr_dict/discriminator_125.pth'
            dis_opt.input_label_h5 = '/mnt/workspace2019/visual_genome_pretrain/iapr_talk_cocodict_label.h5'
            dis_opt.input_json = '/mnt/workspace2018/nakamura/IAPR/iapr_talk_cocodict.json'
        elif opt.dis_type == 'ss':
            discrimiantor_model_dir = '/mnt/workspace2019/nakamura/selfsequential/discriminator_log/shuttorstock_dict/discriminator_900.pth'
            dis_opt.input_label_h5 = '/mnt/workspace2019/nakamura/shutterstock/shuttorstock_talk_cocodict_label.h5'
            dis_opt.input_json = '/mnt/workspace2019/nakamura/shutterstock/shuttorstock_talk_cocodict.json'
        elif opt.dis_type == 'sew':
            discrimiantor_model_dir = '/mnt/workspace2019/nakamura/selfsequential/discriminator_log/sew/discriminator_900.pth'
            dis_opt.input_label_h5 = '/mnt/poplin/share/dataset/simple_english_wikipedia/sew_talk_label.h5'
            dis_opt.input_json = '/mnt/poplin/share/dataset/simple_english_wikipedia/sew_talk.json'
        elif opt.dis_type == 'sew_cut5':
            discrimiantor_model_dir = '/mnt/workspace2019/nakamura/selfsequential/discriminator_log/sew_cut5/discriminator_90.pth'
            dis_opt.input_label_h5 = '/mnt/poplin/share/dataset/simple_english_wikipedia/sew_talk_label.h5'
            dis_opt.input_json = '/mnt/poplin/share/dataset/simple_english_wikipedia/sew_talk.json'
            opt.cut_length = 5
        elif opt.dis_type == 'vg_cut5':
            opt.cut_length = 5
            discrimiantor_model_dir = '/mnt/workspace2019/nakamura/selfsequential/discriminator_log/vg_cut5/discriminator_200.pth'
            dis_opt.input_label_h5 = '/mnt/poplin/share/dataset/MSCOCO/cocotalk_subset_vg_larger_label.h5'
            dis_opt.input_json = '/mnt/poplin/share/dataset/MSCOCO/cocotalk_subset_vg_larger_addregions.json'
        else:
            print('select existing discriminative model!')
            exit()

        discriminator_path_learned = os.path.join(result_path, 'discriminator_{}_{}.pth'.format(epoch, iteration))
        Discriminator = dis_utils.Discriminator(opt)
        if os.path.isfile(discriminator_path_learned):
            Discriminator.load_state_dict(torch.load(discriminator_path_learned, map_location='cuda:' + str(opt.gpu)))
        else:
            Discriminator.load_state_dict(torch.load(discrimiantor_model_dir, map_location='cuda:' + str(opt.gpu)))
        Discriminator = Discriminator.cuda()
        # change discriminator learning rate
        # opt.learning_rate = opt.learning_rate/10
        dis_optimizer = utils.build_optimizer(Discriminator.parameters(), opt)
        # for group in dis_optimizer.param_groups:
        #     group['lr'] = opt.learning_rate/100
        Discriminator.eval()
        dis_loss_func = nn.BCELoss().cuda()
        dis_loader = dis_dataloader.DataLoader(dis_opt)
    else:
        Discriminator = None
        dis_loader = None
        dis_optimizer = None

    # set Actor-Critic network
    if opt.actor_critic_flg == 1:
        Q_net = models.Actor_Critic_Net_upper(opt)
        target_Q_net = models.Actor_Critic_Net_upper(opt)
        Q_net.load_state_dict(target_Q_net.state_dict())
        target_model = models.setup(opt).cuda()
        target_model.load_state_dict(model.state_dict())
        target_model.eval()
        Q_net.cuda()
        target_Q_net.cuda()
        Q_net_optimizer = utils.build_optimizer(Q_net.parameters(), opt)
    elif opt.actor_critic_flg == 2:
        Q_net = models.Actor_Critic_Net_seq(opt)
        target_Q_net = models.Actor_Critic_Net_seq(opt)
        Q_net.load_state_dict(target_Q_net.state_dict())
        target_model = models.setup(opt).cuda()
        target_model.load_state_dict(model.state_dict())
        target_model.eval()
        Q_net.cuda()
        target_Q_net.cuda()
        Q_net_optimizer = utils.build_optimizer(Q_net.parameters(), opt)

        seq_mask = torch.zeros((opt.batch_size * opt.seq_per_img, opt.seq_length, opt.seq_length)).cuda().type(torch.cuda.LongTensor)
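        # Build a causal mask over the generated sequence: row i exposes only the
        # positions generated before step i.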
        for i in range(opt.seq_length):
            seq_mask[:, i, :i] += 1
    elif opt.t_model_flg == 1:
        target_model = models.setup(opt).cuda()
        target_model.load_state_dict(model.state_dict())
        target_model.eval()
    else:
        target_model = None

    baseline = None
    new_model = None
    # set functions calculating loss
    if opt.caption_model in ('hcatt_hard', 'basicxt_hard', 'hcatt_hard_nregion', 'basicxt_hard_nregion'):
        if opt.ppo == 1:
            new_model = models.setup(opt).cuda()
            new_model.load_state_dict(model.state_dict())
            # new_optimizer = utils.build_optimizer(new_model.parameters(), opt)
            # new_model.eval()

        # If you use hard attention, use this setting (but it is not completely implemented)
        crit = utils.LanguageModelCriterion_hard()
    else:
        crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    rl_crit_hard = utils.RewardCriterion_hard()
    rl_crit_conly = utils.RewardCriterion_conly()
    rl_crit_hard_base = utils.RewardCriterion_hard_baseline()
    att_crit = utils.AttentionCriterion()

    if opt.caption_model == 'hcatt_hard' and opt.ppo == 1:
        optimizer = utils.build_optimizer(new_model.parameters(), opt)
    else:
        # set optimizer
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        if opt.sf_epoch is None:
            optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
        else:
            optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer_' + str(opt.sf_epoch) + '_' +str(opt.sf_itr) + '.pth')))

    critic_train_count = 0
    total_critic_reward = 0
    pre_para = None

    #------------------------------------------------------------------------------------------------------------#
    # training start
    while True:
        train_loss = 0
        if update_lr_flag:
            # change lr
            opt, optimizer, model, internal_optimizer, dis_optimizer = utils.change_lr(opt, epoch, optimizer, model, internal_optimizer, dis_optimizer)

            # If starting self-critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                # internal_rl_flg == False
                init_scorer(opt.cached_tokens, len(info_json['ix_to_word']))
            else:
                sc_flag = False

            update_lr_flag = False

        # # !!!!!
        # internal_rl_flg = False
        # model.train()
        # internal.eval()
        # #!!!!!

        # Load data from train split (0)
        data = loader.get_batch('train')

        torch.cuda.synchronize()
        start = time.time()

        # get batch
        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'],
               data['bbox'], data['sub_att'], data['fixed_region']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks, bbox, sub_att, fixed_region = tmp

        optimizer.zero_grad()
        # calculating loss...
        if not sc_flag:
            # use cross entropy
            if opt.weight_deterministic_flg > 0:
                weight_index = np.array(data['weight_index'])
                # fc_feats = fc_feats * 0.0
                output = dp_model(fc_feats, att_feats, labels, att_masks, internal, weight_index=weight_index)
                # output = dp_model(fc_feats, att_feats, labels, att_masks, internal, weight_index=None)
            else:
                output = dp_model(fc_feats, att_feats, labels, att_masks, internal)
            if opt.caption_model == 'hcatt_prob':
                print(torch.exp(output).mean(),  model.probs.mean())
                output = output + model.probs.view(output.size(0), output.size(1), 1)
                loss = crit(output, labels[:,1:], masks[:,1:])
            elif opt.caption_model not in ('hcatt_hard', 'hcatt_hard_nregion', 'basicxt_hard_nregion', 'basicxt_hard'):
                loss = crit(output, labels[:,1:], masks[:,1:])
            else:
                if baseline is None:
                    baseline = torch.zeros((output.size()[0], output.size()[1]))/output.size()[1]
                    baseline = baseline.cuda()
                    # baseline = torch.log(baseline)
                # print('pre:', baseline.mean().item())
                loss, baseline = crit(output, labels[:,1:], masks[:,1:], baseline, dp_model.weights_p, dp_model.weights)
                # print('after:', baseline.mean().item())
        else:
            # use rl
            if opt.weight_deterministic_flg > 0:
                weight_index = np.array(data['weight_index'])
            else:
                weight_index = None

            if dp_model.training:
                sample_max_flg = 0
            else:
                sample_max_flg = 1

            # get predicted captions and logprops, similarity
            gen_result, sample_logprobs, word_exist_seq = dp_model(fc_feats, att_feats, att_masks,internal,
                                                   opt={'sample_max':sample_max_flg}, sim_model = sim_model, New_Critic=new_internal,
                                                   bbox=bbox, sub_att=sub_att, label_region = data['label_region'], weight_index=weight_index,mode='sample')
            train_similarity = dp_model.similarity

            # ---------- learning discriminator ----------------
            if Discriminator is not None and opt.dis_adv_flg == 1 and internal_rl_flg == False:
                correct = 0
                Discriminator.train()
                fake_data = gen_result.data.cpu()
                hokan = torch.zeros((len(fake_data), 1)).type(torch.LongTensor)
                fake_data = torch.cat((hokan, fake_data, hokan), 1).cuda()
                fake_data = fake_data[:, 1:]
                label = torch.ones((fake_data.size(0))).cuda()
                # pdb.set_trace()
                Discriminator, dis_optimizer, correct, neg_loss = \
                    dis_utils.learning_func(Discriminator, dis_optimizer, fake_data, label, correct, 0, opt.cut_length, opt.random_disc, opt.all_switch_end_dis, opt.all_switch_dis,
                                            loss_func=dis_loss_func, weight_index=weight_index, model_gate=model.gate.data.cpu().numpy())

                dis_data = dis_loader.get_batch('train', batch_size=fake_data.size(0))
                real_data = torch.from_numpy(dis_data['labels']).cuda()
                real_data = real_data[:, 1:]
                Discriminator, dis_optimizer, correct, pos_loss = \
                    dis_utils.learning_func(Discriminator, dis_optimizer, real_data, label, correct, 1, opt.cut_length, 0, 0, 0,
                                            loss_func=dis_loss_func, weight_index=weight_index)

                loss_mean = (pos_loss + neg_loss) / 2
                dis_accuracy = correct/(fake_data.size(0) * 2)
                print('Discriminator loss: {}, accuracy: {}'.format(loss_mean, dis_accuracy))
                Discriminator.eval()
            else:
                loss_mean = -1.0
                dis_accuracy = -1.0
            # --------------------------------------------------


            # ---------- calculate att loss -----------
            if opt.att_reward_flg == 1 and model.training:
            # if opt.att_reward_flg == 1 :
                att_loss = att_crit(model, gen_result.data.cpu().numpy())
                att_loss_num = att_loss.data.cpu().numpy()
            else:
                att_loss = 0.0
                att_loss_num = 0.0
            # ------------------------------------------

            # --- get states and actions xt and weights, ccs, seqs ---
            if opt.actor_critic_flg==1 and model.training:
                xts = model.all_xts
                weights_p = model.weights_p
                ccs = internal.output_action
            if opt.actor_critic_flg == 2 and model.training:
                all_logprops = model.all_logprops
                weight_state = model.state_weights
                # xts = model.all_xts
                gen_result_repeat = gen_result.repeat(1, opt.seq_length).view(all_logprops.size(0), opt.seq_length, opt.seq_length)
                # xts = seq_mask * gen_result_repeat
                xts = gen_result_repeat
                weights_p = model.weights_p
                # pdb.set_trace()
                if internal is not None:
                    ccs = internal.output_action
                else:
                    ccs = torch.zeros((len(xts), weights_p.size(1))).cuda()
            if opt.caption_model == 'hcatt_hard' and opt.ppo==1:
                xts = model.all_xts
                weights_p = model.weights_p
                weights = model.weights
            # ----------------------------------------------------------

            # ---------------- Calculate reward (CIDEr, Discriminator, Similarity...)---------------------
            if opt.actor_critic_flg == 2 and model.training:
                reward, pure_reward = get_self_critical_and_similarity_reward_for_actor_critic(dp_model,
                                                                                                   fc_feats,
                                                                                                   att_feats,
                                                                                                   att_masks, data,
                                                                                                   gen_result, opt,
                                                                                                   train_similarity,
                                                                                                   internal=internal,
                                                                                                   sim_model=sim_model,
                                                                                               label_region=data['label_region'],
                                                                                               D=Discriminator)
            else:
                reward, pure_reward, actor_critic_reward, target_update_flg = get_self_critical_and_similarity_reward(dp_model, fc_feats, att_feats,
                                                                          att_masks, data, gen_result, opt,
                                                                          train_similarity,
                                                                          internal=internal,
                                                                          sim_model=sim_model,
                                                                        label_region=data['label_region'],
                                                                          bbox=bbox,
                                                                        D=Discriminator,
                                                                        weight_index=weight_index,
                                                                        fixed_region=fixed_region,
                                                                        target_model=target_model)
                if target_update_flg and target_model is not None:
                    print('----- target model updated ! -----')
                    target_model.load_state_dict(model.state_dict())

                # print(train_similarity.mean(), model.similarity.mean())
            #----------------------------------------------------------


            #-------------------------------- calculate captioning model loss -----------------------------------------
            #------------ Calculate actor critic loss ----------------
            if opt.actor_critic_flg == 1 and model.training:
                # get q_value
                q_value = Q_net(fc_feats, att_feats, xts, weights_p, gen_result)
                # get target_sample
                with torch.no_grad():
                    gen_result_sample, __ = target_model(fc_feats, att_feats, att_masks,
                                                           seqs=gen_result, ccs=ccs, mode='sample')
                    target_q_value = target_Q_net(fc_feats, att_feats, target_model.all_xts, target_model.weights_p, gen_result)
                # calculate actor critic loss
                actor_critic_loss = Q_net.loss_func(actor_critic_reward, q_value, target_q_value)
                add_summary_value(tb_summary_writer, 'actor_critic_loss', actor_critic_loss.item(), iteration, opt.tag)
                Q_net_optimizer.zero_grad()
            elif opt.actor_critic_flg == 2 and model.training:
                # get q_value
                q_value = Q_net(fc_feats, att_feats, xts, weight_state.detach(), weights_p, all_logprops[:,:-1,:], gen_result)
                # get target_sample
                with torch.no_grad():
                    gen_result_sample, __ = target_model(fc_feats, att_feats, att_masks,
                                                         seqs=gen_result, ccs=ccs, mode='sample', state_weights=weight_state)
                    # pdb.set_trace()
                    target_q_value = target_Q_net(fc_feats, att_feats, xts, target_model.state_weights,
                                                  target_model.weights_p, target_model.all_logprops[:,:-1,:], gen_result)
                # calculate actor critic loss
                if reward is None:
                    pdb.set_trace()
                actor_critic_loss = Q_net.loss_func(reward, q_value, target_q_value, gen_result)
                print('actor_critic_loss', actor_critic_loss.item())
                add_summary_value(tb_summary_writer, 'actor_critic_loss', actor_critic_loss.item(), iteration,
                                  opt.tag)
                Q_net_optimizer.zero_grad()
            else:
                actor_critic_loss = 0

            model.att_score = att_loss_num

            # update ppo old policy
            if new_internal is not None and internal.iteration % 1 == 0:
                internal.load_state_dict(new_internal.state_dict())
            if opt.caption_model == 'hcatt_hard' and opt.ppo == 1:
                model.load_state_dict(new_model.state_dict())

            if not internal_rl_flg or opt.multi_learn_flg == 1:
                # if opt.ppo == 1 and opt.caption_model == 'hcatt_hard':
                # -------------- calculate self-critical loss ---------------
                if False:
                    # get coefficient and calculate
                    new_gen_result, new_sample_logprobs = new_model(fc_feats, att_feats, att_masks,
                                                         seqs=gen_result,  mode='sample', decided_att=weights)
                    new_model.pre_weights_p = new_model.weights_p
                    new_model.pre_weights = new_model.weights
                    att_index = np.where(weights.data.cpu() > 0)
                    weights_p_ = weights_p[att_index].view(weights_p.size(0), weights_p.size(1))  # (batch, seq_length)
                    reward_coefficient = 1 / (torch.exp(sample_logprobs) * weights_p_).data.cpu()
                    # train caption network get reward and calculate loss
                    reward_loss, baseline = utils.calculate_loss(opt, rl_crit, rl_crit_hard, rl_crit_hard_base,
                                                                 new_sample_logprobs, gen_result, reward,
                                                                 baseline, new_model, reward_coefficient=reward_coefficient)
                elif (not internal_rl_flg or opt.multi_learn_flg == 1) and opt.actor_critic_flg == 0:
                    # train caption network get reward and calculate loss
                    if opt.weight_deterministic_flg == 7:
                        reward_loss, baseline = utils.calculate_loss(opt, rl_crit, rl_crit_hard, rl_crit_hard_base,
                                                                     sample_logprobs, word_exist_seq, reward,
                                                                     baseline, model)
                    else:
                        reward_loss, baseline = utils.calculate_loss(opt, rl_crit, rl_crit_hard, rl_crit_hard_base,
                                                                     sample_logprobs, gen_result, reward,
                                                                     baseline, model)
                else:
                    reward_loss = 0

                # -------------- calculate auxiliary cross-entropy loss ---------------
                if (opt.caption_model == 'hcatt_simple' or  opt.caption_model == 'hcatt_simple_switch') and opt.xe_weight > 0.0:
                    output = dp_model(fc_feats, att_feats, labels, att_masks, internal, weight_index=weight_index)
                    xe_loss = crit(output, labels[:, 1:], masks[:, 1:])
                    print('r_loss: {}, xe_loss: {}'.format(reward_loss.item(), xe_loss.item()))
                    add_summary_value(tb_summary_writer, 'xe_loss', xe_loss.item(), iteration, opt.tag)
                    add_summary_value(tb_summary_writer, 'r_loss', reward_loss.item(), iteration, opt.tag)
                else:
                    xe_loss = 0.0

                loss = opt.rloss_weight * reward_loss + opt.att_lambda * att_loss + actor_critic_loss + opt.xe_weight * xe_loss
            # --------------------------------------------------------------------------------------------------------


        # ------------------------- calculate internal critic loss and update ---------------------------
        if internal_optimizer is not None and internal_rl_flg == True and sc_flag:

            internal_optimizer.zero_grad()
            if opt.region_bleu_flg == 1:
                sim_model_optimizer.zero_grad()
            if opt.only_critic_train == 0:
                internal_loss = rl_crit(internal.pre_output, gen_result.data, torch.from_numpy(reward).float().cuda(),
                                    reward_coefficient=internal.pre_reward_coefficient)
            else:
                internal_loss = rl_crit_conly(internal.pre_output, gen_result.data, torch.from_numpy(reward).float().cuda(),
                                        reward_coefficient=internal.pre_reward_coefficient, c_count=critic_train_count)
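            # Binary entropy of the critic's gate probability (exp of its log-output);
            # computed here for monitoring/logging.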
            q_value_prop = torch.exp(internal.pre_output)
            entropy = torch.mean(-1 * q_value_prop * torch.log2(q_value_prop + 1e-8) + -1 * (1 - q_value_prop) * torch.log2(
                    1 - q_value_prop + 1e-8))

            internal_loss = internal_loss
            internal_loss.backward()
            internal_optimizer.step()
            if opt.region_bleu_flg == 1:
                sim_model_optimizer.step()

            # ----- record loss and reward to tensorboard -----
            # q_value_prop = torch.exp(internal.pre_output)
            # entropy = torch.mean(-1 * q_value_prop * torch.log2(q_value_prop + 1e-8) + -1 * (1 - q_value_prop) * torch.log2(1 - q_value_prop + 1e-8))
            if opt.only_critic_train == 1:
                if internal is not None and sc_flag:
                    num_internal_switching = internal.same_action_flg.mean().item()
                else:
                    num_internal_switching = 0
                total_critic_reward += np.mean(pure_reward)
                total_critic_reward = utils.record_tb_about_critic(model, internal_loss.cpu().data, critic_train_count, opt.tag,
                                                                   tb_summary_writer, reward,
                                                                   pure_reward, entropy,
                                                                   opt.sim_sum_flg,num_internal_switching,
                                                                   total_critic_reward=total_critic_reward)
            else:
                if internal is not None and sc_flag:
                    num_internal_switching = internal.same_action_flg.mean().item()
                else:
                    num_internal_switching = 0
                total_critic_reward = utils.record_tb_about_critic(model, internal_loss.cpu().data, iteration, opt.tag,
                                         tb_summary_writer, reward, pure_reward, entropy, opt.sim_sum_flg, num_internal_switching)
            # -------------------------------------------------

            critic_train_count += 1

            internal.reset()
            internal.iteration+=1

            print('iter {} (epoch {}), internal_loss: {}, avg_reward: {}, entropy: {}'.format(iteration, epoch,internal_loss, reward.mean(), entropy))
        # --------------------------------------------------------------------------------------------------------
        else:
            #------------------------- updating captioning model ----------------------------
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            if opt.actor_critic_flg > 0 and model.training:
                utils.clip_gradient(Q_net_optimizer, opt.grad_clip)
                Q_net_optimizer.step()
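                # Polyak (soft) updates of the target caption model and target Q-network
                # with tau = 0.001.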
                utils.soft_update(target_model, model, 0.001)
                utils.soft_update(target_Q_net, Q_net, 0.001)
                # if iteration % 1000 == 0:
                #     utils.hard_update(target_model, model)
                #     utils.hard_update(target_Q_net, Q_net)
                # else:
                #     utils.soft_update(target_model, model, 0.001)
                #     utils.soft_update(target_Q_net, Q_net, 0.001)

            train_loss = loss.item()
            torch.cuda.synchronize()
            del loss
            end = time.time()
            if internal is not None and sc_flag:
                num_internal_switching = internal.same_action_flg.mean().item()
            else:
                num_internal_switching = 0
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                try:
                    print("iter {} (epoch {}), avg_reward = {:.3f}, att_loss = {:.3f}. time/batch = {:.3f}" \
                        .format(iteration, epoch, np.mean(reward[:,0]), model.att_score.item(), end - start))
                    utils.record_tb_about_model(model, pure_reward, tb_summary_writer, iteration, opt.tag,
                                                opt.sim_sum_flg, loss_mean, dis_accuracy, num_internal_switching)
                except AttributeError:
                    print("iter {} (epoch {}), avg_reward = {:.3f}, att_loss = {:.3f}. time/batch = {:.3f}" \
                          .format(iteration, epoch, np.mean(reward[:, 0]), model.att_score, end - start))
                    utils.record_tb_about_model(model, pure_reward, tb_summary_writer, iteration, opt.tag,
                                                opt.sim_sum_flg, loss_mean, dis_accuracy, num_internal_switching)
                RL_count += 1

            # --------------------------------------------------------------------------------



        # Update the iteration and epoch
        iteration += 1

        # -------------------- switch between training the internal critic and the caption network -----------------------------
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True
            if opt.cycle is None and internal is not None and opt.multi_learn_flg != 1:
                # and entropy < 1.0
                if internal_rl_flg == True and opt.only_critic_train == 0:
                    if opt.actor_critic_flg == 1:
                        utils.hard_update(target_model, model)
                        utils.hard_update(target_Q_net, Q_net)
                    internal_rl_flg = False
                    internal.eval()
                    dp_model.train()
                    if weight_index is not None and loader.weight_deterministic_flg == 4:
                        loader.weight_deterministic_flg = 5

                    if opt.region_bleu_flg == 1:
                        sim_model.eval()
                    train_loss = None
                # elif internal_optimizer is not None and internal_rl_flg == False:
                # elif internal_optimizer is not None and internal_rl_flg == False and (epoch + 1) % 3 == 0 and opt.internal_model != 'sim_dammy':
                # elif internal_optimizer is not None and internal_rl_flg == False and opt.internal_model != 'sim_dammy':
                else:
                    internal_rl_flg = True
                    # internal.load_state_dict(torch.load(result_path + '/internal_best.pth'))
                    if opt.ppo == 1:
                        internal_optimizer = optim.Adam(new_internal.parameters(), opt.c_learning_rate,
                                                        weight_decay=1e-5)
                    else:
                        internal_optimizer = optim.Adam(internal.parameters(), opt.c_learning_rate, weight_decay=1e-5)
                    internal.train()
                    if opt.region_bleu_flg == 1:
                        sim_model.train()
                    dp_model.eval()
                    if weight_index is not None and loader.weight_deterministic_flg == 5:
                        loader.weight_deterministic_flg = 4
                    internal.reset()
                    internal.max_r = 0
        # --------------------------------------------------------------------------------------------------

        # ------------------- Write the training loss summary ------------------------------
        if (iteration % opt.losses_log_every == 0) and internal_rl_flg == False and train_loss is not None:
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration, opt.tag)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration, opt.tag)
            add_summary_value(tb_summary_writer, 'critic_learning_rate', opt.c_current_lr, iteration, opt.tag)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration, opt.tag)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration, opt.tag)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
        # ----------------------------------------------------------------------------------

        # ------------------------ make evaluation on validation set, and save model ------------------------------
        wdf7_eval_flg = (opt.weight_deterministic_flg != 7 or sc_flag)
        if ((iteration % opt.save_checkpoint_every == 0) or iteration == 39110 or iteration == 113280 or iteration == 151045 or iteration == 78225 or iteration == 31288 or iteration == 32850 or iteration == 46934) and train_loss is not None:
            if sc_flag and (opt.caption_model == 'hcatt_hard' or opt.caption_model == 'basicxt_hard' or opt.caption_model == 'hcatt_hard_nregion' or opt.caption_model == 'basicxt_hard_nregion'):
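                # For the hard-attention variants, lazily create an all-zero reward baseline
                # of shape (batch, seq_len + 1) on the GPU; it is passed to evaluation below.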
                if baseline is None:
                    baseline = torch.zeros((sample_logprobs.size()[0], sample_logprobs.size()[1] + 1)) / sample_logprobs.size()[1]
                    baseline = baseline.cuda()
                    # baseline = torch.log(baseline)
            # eval model
            verbose_loss = not sc_flag


            eval_kwargs = {'split': 'val',
                           'internal': internal,
                           'sim_model': sim_model,
                           'caption_model': opt.caption_model,
                           'baseline': baseline,
                           'gts': data['gts'],
                           'dataset': opt.dataset,
                           'verbose_loss': verbose_loss,
                           'weight_deterministic_flg': opt.weight_deterministic_flg
                           }
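            # Note: update(vars(opt)) runs last, so any key that also exists in opt
            # (e.g. 'dataset', 'caption_model') is overwritten by the command-line value.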
            eval_kwargs.update(vars(opt))

            # pdb.set_trace()
            if wdf7_eval_flg:
                # eval_utils.eval_writer(dp_model, iteration, loader, tb_summary_writer, eval_kwargs)
                val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)

                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration, opt.tag)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration, opt.tag)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss
            else:
                val_result_history[iteration] = {'loss': None, 'lang_stats': None, 'predictions': None}
                current_score = 0

            best_flag = False
            if True:  # always run the checkpointing/saving block below
                if internal_rl_flg == False:
                    if best_val_score is None or current_score > best_val_score:
                        best_val_score = current_score
                        best_flag = True
                    checkpoint_path = os.path.join(result_path, 'model_{}_{}.pth'.format(epoch, iteration))
                    torch.save(model.state_dict(), checkpoint_path)

                    optimizer_path = os.path.join(result_path, 'optimizer_{}_{}.pth'.format(epoch, iteration))
                    torch.save(optimizer.state_dict(), optimizer_path)
                    print("model saved to {}".format(checkpoint_path))
                    if internal is not None:
                        internal.eval()
                        checkpoint_path = os.path.join(result_path, 'internal_{}_{}.pth'.format(epoch, iteration))
                        torch.save(internal.state_dict(), checkpoint_path)
                        optimizer_path = os.path.join(result_path,
                                                      'internal_optimizer_{}_{}.pth'.format(epoch, iteration))
                        torch.save(internal_optimizer.state_dict(), optimizer_path)
                        print("internal model saved to {}".format(checkpoint_path))
                        checkpoint_path = os.path.join(result_path, 'sim_model_{}_{}.pth'.format(epoch, iteration))
                        torch.save(sim_model.state_dict(), checkpoint_path)
                        print("sim_model saved to {}".format(checkpoint_path))

                else:
                    checkpoint_path = os.path.join(result_path, 'model_{}_{}.pth'.format(epoch, iteration))
                    torch.save(model.state_dict(), checkpoint_path)
                    optimizer_path = os.path.join(result_path, 'optimizer_{}_{}.pth'.format(epoch, iteration))
                    torch.save(optimizer.state_dict(), optimizer_path)
                    print("model saved to {}".format(checkpoint_path))
                    if best_val_score is None or current_score > best_val_score:
                        best_val_score = current_score
                        best_flag = True
                    checkpoint_path = os.path.join(result_path, 'internal_{}_{}.pth'.format(epoch, iteration))
                    torch.save(internal.state_dict(), checkpoint_path)

                    optimizer_path = os.path.join(result_path, 'internal_optimizer_{}_{}.pth'.format(epoch, iteration))
                    torch.save(internal_optimizer.state_dict(), optimizer_path)
                    print("internal model saved to {}".format(checkpoint_path))
                    checkpoint_path = os.path.join(result_path, 'sim_model_{}_{}.pth'.format(epoch, iteration))
                    torch.save(sim_model.state_dict(), checkpoint_path)
                    print("sim_model saved to {}".format(checkpoint_path))
                    dp_model.eval()

                if Discriminator is not None:
                    discriminator_path = os.path.join(result_path, 'discriminator_{}_{}.pth'.format(epoch, iteration))
                    torch.save(Discriminator.state_dict(), discriminator_path)
                    dis_optimizer_path = os.path.join(result_path, 'dis_optimizer_{}_{}.pth'.format(epoch, iteration))
                    torch.save(dis_optimizer.state_dict(), dis_optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()
                infos['internal_rl_flg'] = internal_rl_flg

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                with open(os.path.join(result_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(result_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)
                if best_flag:
                    checkpoint_path = os.path.join(result_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))

                    with open(os.path.join(result_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)
                # pdb.set_trace()
        # ---------------------------------------------------------------------------------------------------------

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Example No. 15
def train(opt):
    acc_steps = getattr(opt, 'acc_steps', 1)

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.ix_to_word = loader.ix_to_word

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()
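    # Three optimizer options: the Noam warmup/decay schedule (as typically used with
    # transformer captioners), Adam wrapped in ReduceLROnPlateau (halve the LR after 3
    # evaluations without improvement), or a plain optimizer with the manual decay below.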

    if opt.noamopt:
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
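        """Persist model/optimizer weights plus the infos (and optionally histories) pickles.

        A non-empty `append` becomes a '-<append>' filename suffix, e.g. 'best' ->
        model-best.pth and str(iteration) -> model-<iteration>.pth.
        """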
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
                utils.pickle_dump(histories, f)

    try:
        while True:
            sys.stdout.flush()
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
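                    # Step decay: current_lr = learning_rate *
                    #   decay_rate ** floor((epoch - decay_start) / decay_every)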
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)
                    print('Learning Rate: ', opt.current_lr)
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False
                epoch_done = False

            data = loader.get_batch('train')
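            # Gradient accumulation: zero the grads every acc_steps iterations, scale the
            # loss by 1/acc_steps, and only clip and step the optimizer once acc_steps
            # backward passes have been accumulated.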
            if (iteration % acc_steps == 0):
                optimizer.zero_grad()

            torch.cuda.synchronize()
            start = time.time()
            tmp = [data['fc_feats'], data['att_feats'], data['c3d_feats'], data['labels'], data['masks'], data['att_masks'], data['c3d_masks']]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, c3d_feats, labels, masks, att_masks, c3d_masks = tmp

            model_out = dp_lw_model(fc_feats, att_feats, c3d_feats, labels, masks, att_masks, c3d_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag)

            loss = model_out['loss'].mean()
            loss_sp = loss / acc_steps

            loss_sp.backward()
            if ((iteration + 1) % acc_steps == 0):
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
            torch.cuda.synchronize()
            train_loss = loss.item()
            end = time.time()
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}".format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), reward1 = {:.3f}, reward2 = {:.3f}, reward3 = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}".format(iteration, epoch, model_out['reward_layer1'].mean(), model_out['reward_layer2'].mean(), model_out['reward_layer3'].mean(), train_loss, end - start))

            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'reward1', model_out['reward_layer1'].mean(), iteration)
                    add_summary_value(tb_summary_writer, 'reward2', model_out['reward_layer2'].mean(), iteration)
                    add_summary_value(tb_summary_writer, 'reward3', model_out['reward_layer3'].mean(), iteration)

                loss_history[iteration] = train_loss
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': opt.val_split, 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, lw_model.crit, loader, eval_kwargs)
                print('Summary Epoch {} Iteration {}: CIDEr: {} BLEU-4: {}'.format(epoch, iteration, lang_stats['CIDEr'], lang_stats['Bleu_4']))
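                # The plateau scheduler treats lower values as better, so "higher is better"
                # language metrics are negated before being passed to scheduler_step.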

                if opt.reduce_on_plateau:
                    if opt.reward_metric == 'cider':
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    elif opt.reward_metric == 'bleu':
                        optimizer.scheduler_step(-lang_stats['Bleu_4'])
                    elif opt.reward_metric == 'meteor':
                        optimizer.scheduler_step(-lang_stats['METEOR'])
                    elif opt.reward_metric == 'rouge':
                        optimizer.scheduler_step(-lang_stats['ROUGE_L'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    if opt.reward_metric == 'cider':
                        current_score = lang_stats['CIDEr']
                    elif opt.reward_metric == 'bleu':
                        current_score = lang_stats['Bleu_4']
                    elif opt.reward_metric == 'meteor':
                        current_score = lang_stats['METEOR']
                    elif opt.reward_metric == 'rouge':
                        current_score = lang_stats['ROUGE_L']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Example No. 16
def train(opt):
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model",
                            "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[
                    checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # cnn_model = utils.build_cnn(opt)
    cnn_model = create_extractor("/root/PycharmProjects/vgg_vae_best_model.pth")
    cnn_model = cnn_model.cuda()

    if vars(opt).get('start_from', None) is not None:
        cnn_model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model-cnn.pth')))
        print("load cnn model parameters from {}".format(os.path.join(opt.start_from, 'model-cnn.pth')))

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)
    # dp_lw_model = lw_model
    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(
            model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)

    # if opt.finetune_cnn_after != -1:
    #     # only finetune the layer2 to layer4
    cnn_optimizer = optim.Adam([
            {'params': module.parameters()} for module in cnn_model.finetune_modules
        ], lr=opt.cnn_learning_rate, weight_decay=opt.cnn_weight_decay)
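    # The CNN optimizer is built unconditionally but is only zeroed/stepped inside the
    # training loop once epoch >= opt.finetune_cnn_after (see the finetune checks below).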

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        if os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
            optimizer.load_state_dict(torch.load(
                os.path.join(opt.start_from, 'optimizer.pth')))
        if opt.finetune_cnn_after != -1:
            if os.path.isfile(os.path.join(opt.start_from, 'optimizer-cnn.pth')):
                cnn_optimizer.load_state_dict(torch.load(
                    os.path.join(opt.start_from, 'optimizer-cnn.pth')))

    def save_checkpoint(model, cnn_model, infos, optimizer, cnn_optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(
            opt.checkpoint_path, 'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))

        cnn_checkpoint_path = os.path.join(
            opt.checkpoint_path, 'model-cnn%s.pth' % (append))
        torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
        print("cnn model saved to {}".format(cnn_checkpoint_path))

        optimizer_path = os.path.join(
            opt.checkpoint_path, 'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)

        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            cnn_optimizer_path = os.path.join(
                opt.checkpoint_path, 'optimizer%s-cnn.pth' % (append))
            torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path)

        with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
                utils.pickle_dump(histories, f)

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (
                            epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    # set the decayed rate
                    utils.set_lr(optimizer, opt.current_lr)
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (
                        epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # Update the training stage of cnn
                if opt.finetune_cnn_after == -1 or epoch < opt.finetune_cnn_after:
                    for p in cnn_model.parameters():
                        p.requires_grad = False
                    cnn_model.eval()
                else:
                    for p in cnn_model.parameters():
                        p.requires_grad = True
                    # Fix the first few layers:
                    for module in cnn_model.fixed_modules:
                        for p in module.parameters():
                            p.requires_grad = False
                    cnn_model.train()

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            torch.cuda.synchronize()
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [data['fc_feats'], data['att_feats'],
                   data['labels'], data['masks'], data['att_masks']]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp
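            # The loader delivers flattened raw images in att_feats: reshape to (N, 3, 224, 224),
            # run the CNN backbone, move channels last, and flatten the 7x7 feature map into 49
            # regions. Both att_feats and fc_feats are then repeated seq_per_img times so each
            # of the per-image captions is paired with the same image features.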

            # att_feats 8x672x224
            att_feats = att_feats.view(att_feats.size(0), 3, 224, 224)
            att_feats, fc_feats = cnn_model(att_feats)
            # fc_feats = att_feats.mean(3).mean(2)
            # att_feats = torch.nn.functional.adaptive_avg_pool2d(
            #     att_feats, [7, 7]).permute(0, 2, 3, 1)
            att_feats = att_feats.permute(0, 2, 3, 1)
            att_feats = att_feats.view(att_feats.size(0), 49, -1)

            att_feats = att_feats.unsqueeze(1).expand(*((att_feats.size(0), opt.seq_per_img,) + att_feats.size(
            )[1:])).contiguous().view((att_feats.size(0) * opt.seq_per_img), -1, att_feats.size()[-1])
            fc_feats = fc_feats.unsqueeze(1).expand(*((fc_feats.size(0), opt.seq_per_img,) + fc_feats.size(
            )[1:])).contiguous().view(*((fc_feats.size(0) * opt.seq_per_img,) + fc_feats.size()[1:]))

            optimizer.zero_grad()
            if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                cnn_optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks,
                                 data['gts'], torch.arange(0, len(data['gts'])), sc_flag)

            loss = model_out['loss'].mean()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()

            if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                utils.clip_gradient(cnn_optimizer, opt.grad_clip)
                cnn_optimizer.step()

            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer,
                                  'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer,
                                  'learning_rate', opt.current_lr, iteration)
                add_summary_value(
                    tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(
                        tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean(
                )
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                               'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    cnn_model, model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer,
                                  'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {
                    'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, cnn_model, infos,
                                optimizer, cnn_optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, cnn_model, infos, optimizer, cnn_optimizer,
                                    append=str(iteration))

                if best_flag:
                    save_checkpoint(model, cnn_model, infos,
                                    optimizer, cnn_optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, cnn_model, infos, optimizer, cnn_optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    # test model
    test_kwargs = {'split': 'test',
                   'dataset': opt.input_json}
    test_kwargs.update(vars(opt))
    val_loss, predictions, lang_stats = eval_utils.eval_split(
        cnn_model, model, lw_model.crit, loader, test_kwargs)

    if opt.reduce_on_plateau:
        if 'CIDEr' in lang_stats:
            optimizer.scheduler_step(-lang_stats['CIDEr'])
        else:
            optimizer.scheduler_step(val_loss)
    # Write the test result into summary
    add_summary_value(tb_summary_writer,
                      'test loss', val_loss, iteration)
    if lang_stats is not None:
        for k, v in lang_stats.items():
            add_summary_value(tb_summary_writer, k, v, iteration)
    val_result_history[iteration] = {
        'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}
Example No. 17
def train(opt):
    print(opt)

    # To reproduce training results
    init_seed()
    # Image Preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
                             ])
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt, transform=transform)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '-best.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[
                    checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    opt.vocab = loader.get_vocab()
    if torch.cuda.is_available():
        model = models.setup(opt).cuda()
    else:
        model = models.setup(opt)
    del opt.vocab
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)
    #fgm = FGM(model)

    cnn_model = ResnetBackbone()
    if torch.cuda.is_available():
        cnn_model = cnn_model.cuda()
    if opt.start_from is not None:
        model_dict = cnn_model.state_dict()
        predict_dict = torch.load(os.path.join(opt.start_from, 'cnn_model-best.pth'))
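        # The checkpoint was saved from a DataParallel wrapper, so its keys carry a
        # 'module.' prefix; remap them onto the bare backbone's state_dict keys.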
        model_dict = {k: predict_dict["module."+k] for k, _ in model_dict.items() if "module."+ k in predict_dict}
        cnn_model.load_state_dict(model_dict)
    cnn_model = torch.nn.DataParallel(cnn_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer-best.pth')))

    def save_checkpoint(model, cnn_model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        #Transformer model
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        #CNN model
        checkpoint_path = os.path.join(opt.checkpoint_path, 'cnn_model%s.pth' % (append))
        if not os.path.exists(checkpoint_path):
            torch.save(cnn_model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '%s.pkl' % (append)), 'wb') as f:
                utils.pickle_dump(histories, f)

    cnn_after = 3
    try:
        while True:
            if epoch_done:
                if  opt.fix_cnn or epoch < cnn_after:
                    for p in cnn_model.parameters():
                        p.requires_grad = False
                    cnn_model.eval()
                    cnn_optimizer = None
                else:
                    for p in cnn_model.parameters():
                        p.requires_grad = True
                    # Fix the first few layers:
                    for module in cnn_model._modules['module']._modules['resnet_conv'][:5]._modules.values():
                        for p in module.parameters():
                            p.requires_grad = False
                    cnn_model.train()
                    # Constructing CNN parameters for optimization, only fine-tuning higher layers
                    cnn_optimizer = torch.optim.Adam(
                        (filter(lambda p: p.requires_grad, cnn_model.parameters())),
                        lr=2e-6 if (opt.self_critical_after != -1 and epoch >= opt.self_critical_after) else 5e-5, betas=(0.8, 0.999))

                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            if iteration % opt.losses_log_every == 0:
                print('Read data:', time.time() - start)

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start = time.time()

            if torch.cuda.is_available():
                data['att_feats'] = cnn_model( data['att_feats'].cuda())
            else:
                data['att_feats'] = cnn_model( data['att_feats'] )
            data['att_feats'] = repeat_feat(data['att_feats'])
            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
            if torch.cuda.is_available():
                tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            optimizer.zero_grad()
            if cnn_optimizer is not None:
                cnn_optimizer.zero_grad()

            # if epoch >= cnn_after:
            #     att_feats.register_hook(save_grad("att_feats"))
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag)

            loss = model_out['loss'].mean()

            loss.backward()

            #loss.backward(retain_graph=True)

            # adversarial training
            #fgm.attack(emb_name='model.tgt_embed.0.lut.weight')
            #adv_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'],
            #                      torch.arange(0, len(data['gts'])), sc_flag)

            #adv_loss = adv_out['loss'].mean()
            #adv_loss.backward()
            #fgm.restore(emb_name="model.tgt_embed.0.lut.weight")


            # utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            if cnn_optimizer is not None:
                cnn_optimizer.step()
            train_loss = loss.item()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            end = time.time()
            if not sc_flag and iteration % opt.losses_log_every == 0:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, end - start))
            elif iteration % opt.losses_log_every == 0:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                               'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                eval_kwargs["cnn_model"] = cnn_model
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, cnn_model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, cnn_model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, cnn_model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, cnn_model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Example No. 18
def train(opt):
    loader = Loader(opt)
    infos = {}
    histories = {}

    Model = model.setup(opt).cuda()
    LW_model = LossWrapper(Model, opt)
    # DP_lw_model = torch.nn.DataParallel(LW_model)
    LW_model.train()
    optimizer = utils.build_optimizer(Model.parameters(), opt)

    if opt.start_from is not None:
        with open(os.path.join(opt.start_from, 'infos-best.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)

        if os.path.isfile(os.path.join(opt.start_from, 'histories-best.pkl')):
            with open(os.path.join(opt.start_from, 'histories-best.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)

        if os.path.isfile(os.path.join(opt.start_from, 'optimizer-best.pth')):
            optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer-best.pth')))
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['opt'] = opt
        infos['label2id'] = load_label(opt.input_label2id)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    best_val_score = infos.get('best_val_score', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    epoch_done = True
    best_epoch = -1
    try:
        while True:
            if epoch_done:
                iteration = 0
                if epoch != 0:
                    predictions, targets, _ ,metrics = eval_utils.evaluate(Model, loader, infos['label2id'], opt.eval_batch_size, opt.rel_num, 'dev')
                    val_result_history[iteration] = {'predictions': predictions, 'metrics': metrics, 'targets': targets}
                    #print('dev res: ', metrics)
                    current_score = metrics['F1']
                    histories['val_result_history'] = val_result_history
                    histories['loss_history'] = loss_history
                    histories['lr_history'] = lr_history

                    best_flag = False
                    if current_score > best_val_score:
                        best_epoch = epoch
                        best_val_score = current_score
                        best_flag = True
                    infos['best_val_score'] = best_val_score

                    save_checkpoint(Model, infos, optimizer, histories)
                    if best_flag:
                        save_checkpoint(Model, infos, optimizer, append='best')


                epoch_done = False
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay ** frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)
            start = time.time()
            data = loader.get_batch_train(opt.batch_size)
            #data = sorted(data, key=lambda x: x[-1], reverse=True)
            wrapped = data[-1]
            data = data[:-1]
            #print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()
            data = [t.cuda() for t in data]
            sents, rels, labels, poses, chars, sen_lens = data
            if not opt.use_char:
                chars = None
            if not opt.use_pos:
                poses = None
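            # mask marks the valid (non-padding) token positions of each sentence; mask2
            # additionally weights label id 8 with 1 and every other label with 10
            # (presumably down-weighting the majority/background class) before applying
            # the padding mask.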
            mask = torch.zeros(sents.size()).cuda()
            for i in range(sents.size(0)):
                mask[i][:sen_lens[i]] = 1

            mask2 = torch.where(labels == 8, torch.ones_like(sents), torch.ones_like(sents)*10).cuda()
            mask2 = mask2.float() * mask.float()

            optimizer.zero_grad()
            sum_loss = LW_model(sents, sen_lens, rels, mask, labels, mask2, poses, chars)

            loss = sum_loss/sents.shape[0]
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            if iteration % 200 == 0:
                end = time.time()
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))

            iteration += 1
            if wrapped:
                epoch += 1
                epoch_done = True
            infos['iter'] = iteration
            infos['epoch'] = epoch

            if iteration % opt.save_loss_every == 0:
                loss_history[iteration] = train_loss
                lr_history[iteration] = opt.current_lr
            if opt.max_epoch != -1 and epoch >= opt.max_epoch:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(Model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
def train(opt):
    if vars(opt).get('start_from', None) is not None:
        opt.checkpoint_path = opt.start_from
        opt.id = opt.checkpoint_path.split('/')[-1]
        print('Point to folder: {}'.format(opt.checkpoint_path))
    else:
        opt.id = datetime.datetime.now().strftime(
            '%Y%m%d_%H%M%S') + '_' + opt.caption_model
        opt.checkpoint_path = os.path.join(opt.checkpoint_path, opt.id)

        if not os.path.exists(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        print('Create folder: {}'.format(opt.checkpoint_path))

    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    # opt.use_att = False
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader_UP(opt)
    opt.vocab_size = loader.vocab_size
    if opt.use_rela == 1:
        opt.rela_dict_size = loader.rela_dict_size
    opt.seq_length = loader.seq_length
    use_rela = getattr(opt, 'use_rela', 0)

    try:
        tb_summary_writer = tf and tf.compat.v1.summary.FileWriter(
            opt.checkpoint_path)
    except:
        print('Set tensorboard error!')
        pdb.set_trace()

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.checkpoint_path, 'infos.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.checkpoint_path, 'histories.pkl')):
            with open(os.path.join(opt.checkpoint_path, 'histories.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    # dp_model = torch.nn.DataParallel(model)
    # dp_model = torch.nn.DataParallel(model, [0,2,3])
    dp_model = model

    print('### Model summary below###\n {}\n'.format(str(model)))
    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('model parameter:{}'.format(model_params))

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    parameters = model.named_children()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), opt)

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    reward = np.zeros([1, 1])

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)
        # print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        fc_feats = None
        att_feats = None
        att_masks = None
        ssg_data = None
        rela_data = None

        if getattr(opt, 'use_ssg', 0) == 1:
            if getattr(opt, 'use_isg', 0) == 1:
                tmp = [
                    data['fc_feats'], data['labels'], data['masks'],
                    data['att_feats'], data['att_masks'],
                    data['isg_rela_matrix'], data['isg_rela_masks'],
                    data['isg_obj'], data['isg_obj_masks'], data['isg_attr'],
                    data['isg_attr_masks'], data['ssg_rela_matrix'],
                    data['ssg_rela_masks'], data['ssg_obj'],
                    data['ssg_obj_masks'], data['ssg_attr'],
                    data['ssg_attr_masks']
                ]

                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, labels, masks, att_feats, att_masks, \
                isg_rela_matrix, isg_rela_masks, isg_obj, isg_obj_masks, isg_attr, isg_attr_masks, \
                ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks = tmp

                # image graph domain
                isg_data = {}
                isg_data['att_feats'] = att_feats
                isg_data['att_masks'] = att_masks

                isg_data['isg_rela_matrix'] = isg_rela_matrix
                isg_data['isg_rela_masks'] = isg_rela_masks
                isg_data['isg_obj'] = isg_obj
                isg_data['isg_obj_masks'] = isg_obj_masks
                isg_data['isg_attr'] = isg_attr
                isg_data['isg_attr_masks'] = isg_attr_masks
                # text graph domain
                ssg_data = {}
                ssg_data['ssg_rela_matrix'] = ssg_rela_matrix
                ssg_data['ssg_rela_masks'] = ssg_rela_masks
                ssg_data['ssg_obj'] = ssg_obj
                ssg_data['ssg_obj_masks'] = ssg_obj_masks
                ssg_data['ssg_attr'] = ssg_attr
                ssg_data['ssg_attr_masks'] = ssg_attr_masks
            else:
                tmp = [
                    data['fc_feats'], data['ssg_rela_matrix'],
                    data['ssg_rela_masks'], data['ssg_obj'],
                    data['ssg_obj_masks'], data['ssg_attr'],
                    data['ssg_attr_masks'], data['labels'], data['masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks, labels, masks = tmp
                ssg_data = {}
                ssg_data['ssg_rela_matrix'] = ssg_rela_matrix
                ssg_data['ssg_rela_masks'] = ssg_rela_masks
                ssg_data['ssg_obj'] = ssg_obj
                ssg_data['ssg_obj_masks'] = ssg_obj_masks
                ssg_data['ssg_attr'] = ssg_attr
                ssg_data['ssg_attr_masks'] = ssg_attr_masks

                isg_data = None
        else:
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

        if not sc_flag:
            # loss = crit(dp_model(model_zh,model_en,itow_zh,itow, fc_feats, labels, isg_data, ssg_data), labels[:, 1:], masks[:, 1:])
            # print('ssg:')
            # print(ssg_data['ssg_obj'])
            # print('predict:')
            # print(dp_model(fc_feats, labels, isg_data, ssg_data))
            # print('label:')
            # print(labels[:, 1:])
            loss = crit(dp_model(fc_feats, labels, isg_data, ssg_data),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   isg_data,
                                                   ssg_data,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, isg_data,
                                              ssg_data, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()
        if accumulate_iter % opt.accumulate_number == 0:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            accumulate_iter = 0
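            # multiply back by accumulate_number so the logged loss matches the un-scaled batch loss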
            train_loss = loss.item() * opt.accumulate_number
            end = time.time()

            if not sc_flag:
                print("{}/{}/{}|train_loss={:.3f}|time/batch={:.3f}" \
                      .format(opt.id, iteration, epoch, train_loss, end - start))
            else:
                print("{}/{}/{}|avg_reward={:.3f}|time/batch={:.3f}" \
                      .format(opt.id, iteration, epoch, np.mean(reward[:, 0]), end - start))

        torch.cuda.synchronize()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0) and (iteration != 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        # if (iteration %2 == 0) and (iteration != 0):
        if (iteration % opt.save_checkpoint_every == 0) and (iteration != 0):
            # eval model
            if use_rela:
                eval_kwargs = {
                    'split': 'val',
                    'dataset': opt.input_json,
                    'use_real': 1
                }
            else:
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            # val_loss, predictions, lang_stats = eval_utils.eval_split(model_zh,model_en,itow_zh,itow, dp_model, crit, loader, eval_kwargs)
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if it is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                save_id = iteration / opt.save_checkpoint_every
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos.pkl'),
                          'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories.pkl'),
                          'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
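
The loop that just ended accumulates gradients instead of stepping every batch: each mini-batch loss is divided by opt.accumulate_number, backward() runs on every iteration, and the optimizer only steps (with gradient clipping) once accumulate_iter reaches a multiple of opt.accumulate_number. A minimal, self-contained sketch of the same pattern follows; all names are illustrative, and the standard torch.nn.utils.clip_grad_norm_ stands in for the repo's utils.clip_gradient.

import torch

# Stand-alone sketch of gradient accumulation: gradients from `accumulate_number`
# micro-batches are summed in .grad before a single optimizer step.
def run_accumulated_updates(model, optimizer, criterion, batches,
                            accumulate_number=4, grad_clip=0.1):
    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(batches, start=1):
        loss = criterion(model(inputs), targets) / accumulate_number
        loss.backward()                               # grads add up across micro-batches
        if step % accumulate_number == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()
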
Example #20
def load_info(opt, loader, start_from, checkpoint_path, p_flag):
    infos = {}
    histories = {}
    if start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(checkpoint_path, 'infos.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            # need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(checkpoint_path, 'histories.pkl')):
            with open(os.path.join(checkpoint_path, 'histories.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    opt.p_flag = p_flag
    if getattr(opt, 'p_flag', 0) == 0:
        opt.caption_model = opt.caption_model_zh
    else:
        opt.caption_model = opt.caption_model_en

    model = models.setup(opt).cuda()
    # dp_model = torch.nn.DataParallel(model)
    # dp_model = torch.nn.DataParallel(model, [0,2,3])
    dp_model = model

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    parameters = model.named_children()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), opt)

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    train_loss_kl = 0
    train_loss_all = 0

    reward = np.zeros([1, 1])
    return loader,iteration,epoch,val_result_history,loss_history,lr_history,ss_prob_history,best_val_score,\
           infos,histories,update_lr_flag,model,dp_model,parameters,crit,rl_crit,optimizer,accumulate_iter,train_loss,reward,train_loss_kl,train_loss_all
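
The train() loops in these examples all recompute the learning rate with the same step decay: once epoch exceeds learning_rate_decay_start, the base rate is multiplied by learning_rate_decay_rate once for every learning_rate_decay_every epochs elapsed. A small stand-alone sketch of that schedule (the function name is invented):

def step_decay_lr(epoch, base_lr, decay_start, decay_every, decay_rate):
    # Before the decay kicks in (or if it is disabled with decay_start < 0),
    # keep the base learning rate, matching the `epoch > decay_start` test above.
    if decay_start < 0 or epoch <= decay_start:
        return base_lr
    frac = (epoch - decay_start) // decay_every       # completed decay periods
    return base_lr * (decay_rate ** frac)

# For example, with base_lr=5e-4, decay_start=0, decay_every=3, decay_rate=0.8:
# epochs 1-2 keep 5e-4, epochs 3-5 use 4e-4, epoch 6 drops to 3.2e-4, and so on.
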
Example #21
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from_path is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from_path,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from_path,
                             'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from_path,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    #print(val_result_history.get(3000))
    #exit(0)
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    no = sum(p.numel() for p in model.parameters())
    pytorch_total_params = sum(p.numel() for p in model.parameters()
                               if p.requires_grad)
    print("Trainable Params:" + str(pytorch_total_params))
    print("Total Params:" + str(no))
    #exit(0)
    dp_model = torch.nn.DataParallel(model)

    epoch_done = True
    # Assure in training mode
    dp_model.train()
    if (opt.use_obj_mcl_loss == 1):
        mcl_crit = utils.MultiLabelClassification()
    if opt.label_smoothing > 0:
        crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
    else:
        crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from_path', None) is not None and os.path.isfile(
            os.path.join(opt.start_from_path, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from_path, 'optimizer.pth')))

    time_epoch_start = time.time()
    data_time_sum = 0
    batch_time_sum = 0
    while True:
        if epoch_done:
            torch.cuda.synchronize()
            time_epoch_end = time.time()
            time_elapsed = (time_epoch_end - time_epoch_start)
            print('[DEBUG] Epoch Time: ' + str(time_elapsed))
            print('[DEBUG] Sum Data Time: ' + str(data_time_sum))
            print('[DEBUG] Sum Batch Time: ' + str(batch_time_sum))
            #if epoch==1:
            #    exit(0)
            if not opt.noamopt and not opt.reduce_on_plateau:
                # Assign the learning rate
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            epoch_done = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)
        data_time_sum += time.time() - start
        torch.cuda.synchronize()
        start = time.time()

        if (opt.use_obj_mcl_loss == 0):
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp
        else:
            if opt.use_obj_att and opt.use_seg_feat:
                tmp = [
                    data['fc_feats'], data['att_feats'], data['obj_att_feats'],
                    data['seg_feat_feats'], data['labels'], data['masks'],
                    data['obj_labels'], data['att_masks'],
                    data['obj_att_masks'], data['seg_feat_masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, att_feats, obj_att_feats, seg_feat_feats, labels, masks, obj_labels, att_masks, obj_att_masks, seg_feat_masks = tmp
            elif not opt.use_obj_att and opt.use_seg_feat:
                tmp = [
                    data['fc_feats'], data['att_feats'],
                    data['seg_feat_feats'], data['labels'], data['masks'],
                    data['obj_labels'], data['att_masks'],
                    data['seg_feat_masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, att_feats, seg_feat_feats, labels, masks, obj_labels, att_masks, seg_feat_masks = tmp
            elif not opt.use_obj_att and not opt.use_seg_feat:
                tmp = [
                    data['fc_feats'], data['att_feats'], data['labels'],
                    data['masks'], data['obj_labels'], data['att_masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, att_feats, labels, masks, obj_labels, att_masks = tmp
            elif opt.use_obj_att and not opt.use_seg_feat:
                tmp = [
                    data['fc_feats'], data['att_feats'], data['obj_att_feats'],
                    data['labels'], data['masks'], data['obj_labels'],
                    data['att_masks'], data['obj_att_masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, att_feats, obj_att_feats, labels, masks, obj_labels, att_masks, obj_att_masks = tmp

        optimizer.zero_grad()
        if (opt.use_obj_mcl_loss == 0):
            if not sc_flag:
                loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                            labels[:, 1:], masks[:, 1:])
            else:
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                reward = get_self_critical_reward(dp_model, fc_feats,
                                                  att_feats, att_masks, data,
                                                  gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda())
        else:
            if opt.use_obj_att and opt.use_seg_feat:
                if not sc_flag:
                    logits, out = dp_model(
                        fc_feats, [att_feats, obj_att_feats, seg_feat_feats],
                        labels, [att_masks, obj_att_masks, seg_feat_masks])
                    caption_loss = crit(logits, labels[:, 1:], masks[:, 1:])
                    obj_loss = mcl_crit(out, obj_labels)
                    loss = opt.lambda_caption * caption_loss + opt.lambda_obj * obj_loss
                    #loss = 0.1*caption_loss + obj_loss
                    #loss = caption_loss + 0 * obj_loss
                else:
                    gen_result, sample_logprobs = dp_model(
                        fc_feats,
                        att_feats,
                        att_masks,
                        opt={'sample_max': 0},
                        mode='sample')
                    reward = get_self_critical_reward(dp_model, fc_feats,
                                                      att_feats, att_masks,
                                                      data, gen_result, opt)
                    loss = rl_crit(sample_logprobs, gen_result.data,
                                   torch.from_numpy(reward).float().cuda())
            elif not opt.use_obj_att and opt.use_seg_feat:
                if not sc_flag:
                    logits, out = dp_model(fc_feats,
                                           [att_feats, seg_feat_feats], labels,
                                           [att_masks, seg_feat_masks])
                    caption_loss = crit(logits, labels[:, 1:], masks[:, 1:])
                    obj_loss = mcl_crit(out, obj_labels)
                    loss = opt.lambda_caption * caption_loss + opt.lambda_obj * obj_loss
                    #loss = caption_loss + 0 * obj_loss
                else:
                    gen_result, sample_logprobs = dp_model(
                        fc_feats,
                        att_feats,
                        att_masks,
                        opt={'sample_max': 0},
                        mode='sample')
                    reward = get_self_critical_reward(dp_model, fc_feats,
                                                      att_feats, att_masks,
                                                      data, gen_result, opt)
                    loss = rl_crit(sample_logprobs, gen_result.data,
                                   torch.from_numpy(reward).float().cuda())
            elif not opt.use_obj_att and not opt.use_seg_feat:
                if not sc_flag:
                    logits, out = dp_model(fc_feats, att_feats, labels,
                                           att_masks)
                    caption_loss = crit(logits, labels[:, 1:], masks[:, 1:])
                    obj_loss = mcl_crit(out, obj_labels)
                    loss = opt.lambda_caption * caption_loss + opt.lambda_obj * obj_loss
                    #loss = caption_loss + 0 * obj_loss
                else:
                    gen_result, sample_logprobs = dp_model(
                        fc_feats,
                        att_feats,
                        att_masks,
                        opt={'sample_max': 0},
                        mode='sample')
                    reward = get_self_critical_reward(dp_model, fc_feats,
                                                      att_feats, att_masks,
                                                      data, gen_result, opt)
                    loss = rl_crit(sample_logprobs, gen_result.data,
                                   torch.from_numpy(reward).float().cuda())
            elif opt.use_obj_att and not opt.use_seg_feat:
                if not sc_flag:
                    logits, out = dp_model(fc_feats,
                                           [att_feats, obj_att_feats], labels,
                                           [att_masks, obj_att_masks])
                    caption_loss = crit(logits, labels[:, 1:], masks[:, 1:])
                    obj_loss = mcl_crit(out, obj_labels)
                    loss = 0.1 * caption_loss + obj_loss
                    #loss = caption_loss + 0 * obj_loss
                else:
                    gen_result, sample_logprobs = dp_model(
                        fc_feats,
                        att_feats,
                        att_masks,
                        opt={'sample_max': 0},
                        mode='sample')
                    reward = get_self_critical_reward(dp_model, fc_feats,
                                                      att_feats, att_masks,
                                                      data, gen_result, opt)
                    loss = rl_crit(sample_logprobs, gen_result.data,
                                   torch.from_numpy(reward).float().cuda())
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        batch_time_sum += end - start
        if not sc_flag:
            if (opt.use_obj_mcl_loss == 1):
                print("iter {} (epoch {}), train_loss = {:.3f}, caption_loss = {:.3f}, object_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, caption_loss.item(), obj_loss.item(), end - start))
            else:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            epoch_done = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            if (opt.use_obj_mcl_loss == 1):
                add_summary_value(tb_summary_writer, 'obj_loss',
                                  obj_loss.item(), iteration)
                add_summary_value(tb_summary_writer, 'caption_loss',
                                  caption_loss.item(), iteration)
            if opt.noamopt:
                opt.current_lr = optimizer.rate()
            elif opt.reduce_on_plateau:
                opt.current_lr = optimizer.current_lr
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            orig_batch_size = opt.batch_size
            opt.batch_size = 1
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            loader.batch_size = eval_kwargs.get('batch_size', 1)
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)
            opt.batch_size = orig_batch_size
            loader.batch_size = orig_batch_size

            if opt.reduce_on_plateau:
                if 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            for k, v in lang_stats.items():
                add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if it is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
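
When opt.use_obj_mcl_loss is set, the example above mixes the captioning criterion with a multi-label object classification loss through two scalar weights, opt.lambda_caption and opt.lambda_obj. The sketch below spells out that weighted multi-task objective; everything except the two lambdas is invented for illustration, and standard torch.nn.functional losses stand in for the repo's LanguageModelCriterion and MultiLabelClassification.

import torch
import torch.nn as nn

# Illustrative only: the model is assumed to return caption logits plus per-image
# object scores, mirroring the `logits, out = dp_model(...)` call above.
def multitask_loss(caption_logits, caption_targets, caption_mask,
                   obj_scores, obj_labels, lambda_caption=1.0, lambda_obj=1.0):
    # Token-level cross entropy averaged over the valid caption positions.
    ce = nn.functional.cross_entropy(
        caption_logits.reshape(-1, caption_logits.size(-1)),
        caption_targets.reshape(-1), reduction='none')
    caption_loss = (ce * caption_mask.reshape(-1)).sum() / caption_mask.sum()
    # Multi-label object loss over sigmoid scores.
    obj_loss = nn.functional.binary_cross_entropy_with_logits(obj_scores, obj_labels.float())
    return lambda_caption * caption_loss + lambda_obj * obj_loss
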
Example #22
def train(opt):
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)
    #opt.ss_prob=0.0
    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
        infos['pix_perss']=loader.get_personality()
    infos['opt'] = opt
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    print("current epoch: ",epoch)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None
    best_val_score_vse = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_vse = infos.get('best_val_score_vse', None)

    opt.vocab = loader.get_vocab()
    opt.xpersonality=loader.get_personality()
    if opt.use_joint==0:
        #torch.cuda.set_device(0)
        model = models.setup(opt).cuda()
    elif opt.use_joint==1:
        model = models.JointModel(opt)
        model.cuda()
    #model=models.setup(opt)
    del opt.vocab
    if opt.start_from is not None:
        opt.model=os.path.join(opt.start_from, 'model'+'.pth')
        model.load_state_dict(torch.load(opt.model))
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)
    #dp_lw_model=LossWrapper(model, opt)  # this is for no cuda
    epoch_done = True
    # Assure in training mode
    #dp_lw_model=lw_model
    dp_lw_model.train()
    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer([p for p in model.parameters() if p.requires_grad], opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer([p for p in model.parameters() if p.requires_grad], opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
    else:
        print('Optimizer param group number not matched? There must be new parameters. Reinit the optimizer.')


    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
                utils.pickle_dump(histories, f)

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate  ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False
                # Assign retrieval loss weight
                if epoch > opt.retrieval_reward_weight_decay_start and opt.retrieval_reward_weight_decay_start >= 0:
                    frac = (epoch - opt.retrieval_reward_weight_decay_start) // opt.retrieval_reward_weight_decay_every
                    model.retrieval_reward_weight = opt.retrieval_reward_weight * (opt.retrieval_reward_weight_decay_rate  ** frac)
                epoch_done = False
                    
            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()
            with torch.autograd.set_detect_anomaly(True):
                tmp = [data['fc_feats'], data['att_feats'],data['densecap'], data['labels'], data['masks'], data['att_masks'], data['personality']]
                tmp = [_ if _ is None else _.cuda() for _ in tmp]
                fc_feats, att_feats,densecap, labels, masks, att_masks,personality = tmp
                optimizer.zero_grad()
                model_out = dp_lw_model(fc_feats, att_feats,densecap, labels, masks, att_masks,personality, data['gts'], torch.arange(0, len(data['gts'])), sc_flag)

                loss = model_out['loss'].mean()
                
                loss.backward()
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
                train_loss = loss.item()
                torch.cuda.synchronize()
                end = time.time()
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f},train_loss = {:.3f}" \
                    .format(iteration, epoch, model_out['reward'].mean(), end - start,train_loss))

            if opt.use_joint==1:
                prt_str = ""
                for k, v in model.loss().items():
                    prt_str += "{} = {:.3f} ".format(k, v)
                print(prt_str)

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            
            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                                'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:                    
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if it is improving on validation result
                    
                if opt.language_eval == 1:
                    if opt.use_joint==1:
                        current_score = lang_stats['SPICE']*100
                    elif opt.use_joint==0:
                        current_score = lang_stats['CIDEr'] # could use SPICE
                else:
                    if opt.use_joint==0:
                        current_score = - val_loss
                    elif opt.use_joint==1:
                        current_score= - val_loss['loss_cap']
                if opt.use_joint==1:
                    current_score_vse = val_loss.get(opt.vse_eval_criterion, 0)*100

                best_flag = False
                best_flag_vse= False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if opt.use_joint==1:
                    if best_val_score_vse is None or current_score_vse > best_val_score_vse:
                        best_val_score_vse = current_score_vse
                        best_flag_vse = True
                    infos['best_val_score_vse'] = best_val_score_vse
                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')
                if best_flag_vse:
                    save_checkpoint(model, infos, optimizer, append='vse-best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
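
The scheduled sampling probability in these loops follows the same schedule everywhere: after scheduled_sampling_start epochs it increases by scheduled_sampling_increase_prob once per scheduled_sampling_increase_every epochs, capped at scheduled_sampling_max_prob, and the result is written into model.ss_prob. A stand-alone sketch of that schedule (the function name is invented; before the schedule starts the original code simply leaves model.ss_prob untouched, which is approximated here by returning 0.0):

def scheduled_sampling_prob(epoch, start, increase_every, increase_prob, max_prob):
    if start < 0 or epoch <= start:
        return 0.0                                    # schedule not active yet
    frac = (epoch - start) // increase_every          # completed increase periods
    return min(increase_prob * frac, max_prob)

# For example, with start=0, increase_every=5, increase_prob=0.05, max_prob=0.25
# the probability rises by 0.05 every 5 epochs until it saturates at 0.25.
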
Example #23
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(log_dir=opt.checkpoint_path)
    print(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        print(os.getcwd())
        with open(
                os.path.join(os.getcwd(), opt.start_from,
                             'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    # dp_model = torch.nn.DataParallel(model)
    dp_model = model

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        # # [added] reproduce straight line learning rate decay in supplementary
        # #      ---- the original paper used 60k iters
        # #      ---- if lr goes to zero just stay at the last lr
        # linear_lr = -(iteration+1)*opt.learning_rate/60000 + opt.learning_rate
        # if linear_lr <= 0:
        #     pass
        # else:
        #     opt.current_lr = linear_lr
        #     utils.set_lr(optimizer, opt.current_lr)

        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        # [naxin] knn_data is the nearest neighbour batch, the format is identical to data
        data, knn_data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        # tmp = [knn_data['fc_feats'], knn_data['att_feats'], knn_data['labels'], knn_data['masks'], knn_data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs, eval_knn=opt.use_knn)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it improves on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # always save; placeholder for an optional gating condition
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
def train(opt):
    import random
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(0)
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5
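    # The +5 presumably accounts for the normalized box coordinates and relative
    # area that are appended to each region feature when boxes are used.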

    from dataloader_pair import DataLoader

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    if opt.log_to_file:
        if os.path.exists(os.path.join(opt.checkpoint_path, 'log')):
            suffix = time.strftime("%Y-%m-%d %X", time.localtime())
            print('Warning: %s already exists, using a suffixed log file' %
                  os.path.join(opt.checkpoint_path, 'log'))
            sys.stdout = open(
                os.path.join(opt.checkpoint_path, 'log' + suffix), "w")
        else:
            print('logging to file %s' %
                  os.path.join(opt.checkpoint_path, 'log'))
            sys.stdout = open(os.path.join(opt.checkpoint_path, 'log'), "w")

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        if os.path.isfile(opt.start_from):
            with open(opt.infos) as f:
                infos = cPickle.load(f)
                saved_model_opt = infos['opt']
                need_be_same = [
                    "caption_model", "rnn_type", "rnn_size", "num_layers"
                ]
                for checkme in need_be_same:
                    assert vars(saved_model_opt)[checkme] == vars(
                        opt
                    )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        else:
            if opt.load_best != 0:
                print('loading best info')
                with open(
                        os.path.join(opt.start_from,
                                     'infos_' + opt.id + '-best.pkl')) as f:
                    infos = cPickle.load(f)
                    saved_model_opt = infos['opt']
                    need_be_same = [
                        "caption_model", "rnn_type", "rnn_size", "num_layers"
                    ]
                    for checkme in need_be_same:
                        assert vars(saved_model_opt)[checkme] == vars(
                            opt
                        )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
            else:
                with open(
                        os.path.join(opt.start_from,
                                     'infos_' + opt.id + '.pkl')) as f:
                    infos = cPickle.load(f)
                    saved_model_opt = infos['opt']
                    need_be_same = [
                        "caption_model", "rnn_type", "rnn_size", "num_layers"
                    ]
                    for checkme in need_be_same:
                        assert vars(saved_model_opt)[checkme] == vars(
                            opt
                        )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                try:
                    histories = cPickle.load(f)
                except Exception:
                    print('load history error!')
                    histories = {}

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    start_epoch = epoch

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Ensure the model is in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    if opt.caption_model == 'att2in2p':
        optimized = [
            'logit2', 'ctx2att2', 'core2', 'prev_sent_emb', 'prev_sent_wrap'
        ]
        optimized_param = []
        optimized_param1 = []
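        # optimized_param will hold the second-stage decoder parameters (plus the
        # shared embedding); optimized_param1 the first-stage ones (plus the
        # embedding).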

        for name, param in model.named_parameters():
            second = False
            for n in optimized:
                if n in name:
                    print('second', name)
                    optimized_param.append(param)
                    second = True
            if 'embed' in name:
                print('all', name)
                optimized_param1.append(param)
                optimized_param.append(param)
            elif not second:
                print('first', name)
                optimized_param1.append(param)

    while True:
        if opt.val_only:
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            print('start evaluating')
            val_loss, predictions, lang_stats = eval_utils_pair.eval_split(
                dp_model, crit, loader, eval_kwargs)
            exit(0)
        if update_lr_flag:
            # Assign the learning rate
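            # lr = learning_rate * decay_rate ** floor((epoch - decay_start) / decay_every)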
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['pair_fc_feats'], data['pair_att_feats'], data['pair_labels'],
            data['pair_masks'], data['pair_att_masks']
        ]

        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        masks = masks.float()

        optimizer.zero_grad()

        if not sc_flag:
            if opt.onlysecond:
                # only using the second sentence from a visual paraphrase pair. opt.caption_model should be a one-stage decoding model
                loss = crit(
                    dp_model(fc_feats, att_feats, labels[:, 1, :], att_masks),
                    labels[:, 1, 1:], masks[:, 1, 1:])
                loss1 = loss2 = loss / 2
            elif opt.first:
                # using the first sentence
                tmp = [
                    data['first_fc_feats'], data['first_att_feats'],
                    data['first_labels'], data['first_masks'],
                    data['first_att_masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, att_feats, labels, masks, att_masks = tmp
                masks = masks.float()
                loss = crit(
                    dp_model(fc_feats, att_feats, labels[:, :], att_masks),
                    labels[:, 1:], masks[:, 1:])
                loss1 = loss2 = loss / 2
            elif opt.onlyfirst:
                # only using the first sentence from a visual paraphrase pair
                loss = crit(
                    dp_model(fc_feats, att_feats, labels[:, 0, :], att_masks),
                    labels[:, 0, 1:], masks[:, 0, 1:])
                loss1 = loss2 = loss / 2
            else:
                # proposed DCVP model, opt.caption_model should be att2in2p
                output1, output2 = dp_model(fc_feats, att_feats, labels,
                                            att_masks, masks[:, 0, 1:])
                loss1 = crit(output1, labels[:, 0, 1:], masks[:, 0, 1:])
                loss2 = crit(output2, labels[:, 1, 1:], masks[:, 1, 1:])
                loss = loss1 + loss2

        else:
            raise NotImplementedError
            # Our DCVP model does not support self-critical sequence training
            # We found that RL(SCST) with CIDEr reward will improve conventional metrics (BLEU, CIDEr, etc.)
            # but harm diversity and descriptiveness
            # Please refer to the paper for the details

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, loss1 = {:.3f}, loss2 = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, loss.item(), loss1.item(), loss2.item(), end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        sys.stdout.flush()
        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils_pair.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it improves on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # always save; placeholder for an optional gating condition
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)
                checkpoint_path = os.path.join(
                    opt.checkpoint_path, 'model' + str(iteration) + '.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(
                            opt.checkpoint_path,
                            'infos_' + opt.id + '_' + str(iteration) + '.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Example #25
def train(opt):

    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if any) and check that the models are compatible
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme
                ), "Command line argument and saved model disagree on '%s' " % checkme
    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'),
                  'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    multi_models_list = []
    for order in range(opt.number_of_models):
        multi_models_list.append(models.setup(opt).cuda())
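    # The second set of models created below serves as EMA ("mean teacher")
    # copies: their parameters are detached from autograd, initialized from the
    # corresponding student models, and later updated as
    # param_ema = alpha * param_ema + (1 - alpha) * param.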
    for order in range(opt.number_of_models):
        multi_models_list.append(models.setup(opt).cuda())
    for order in range(opt.number_of_models, 2 * opt.number_of_models):
        for param in multi_models_list[order].parameters():
            param.detach_()
    for order in range(opt.number_of_models):
        for param, param_ema in zip(
                multi_models_list[order].parameters(),
                multi_models_list[order + opt.number_of_models].parameters()):
            param_ema.data = param.data.clone()
    # multi_models = MultiModels(multi_models_list)
    # multi_models_list.append(SenEncodeModel(opt).cuda())
    multi_models = nn.ModuleList(multi_models_list)
    del opt.vocab
    # Load pretrained weights:
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'model.pth')):
        multi_models.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Wrap the generation model with the loss function (used for training).
    # This allows the loss to be computed separately on each GPU.
    lw_models = nn.ModuleList([
        LossWrapper(multi_models[index], opt)
        for index in range(opt.number_of_models)
    ])
    kdlw_models = nn.ModuleList([
        KDLossWrapper(multi_models[index], opt)
        for index in range(opt.number_of_models)
    ])
    lw_models_ema = nn.ModuleList([
        LossWrapper(multi_models[opt.number_of_models + index], opt)
        for index in range(opt.number_of_models)
    ])
    kdlw_models_ema = nn.ModuleList([
        KDLossWrapper(multi_models[opt.number_of_models + index], opt)
        for index in range(opt.number_of_models)
    ])
    # Wrap with dataparallel
    dp_models = nn.ModuleList([
        torch.nn.DataParallel(multi_models[index])
        for index in range(opt.number_of_models)
    ])
    dp_lw_models = nn.ModuleList([
        torch.nn.DataParallel(lw_models[index])
        for index in range(opt.number_of_models)
    ])
    dp_kdlw_models = nn.ModuleList([
        torch.nn.DataParallel(kdlw_models[index])
        for index in range(opt.number_of_models)
    ])
    dp_models_ema = nn.ModuleList([
        torch.nn.DataParallel(multi_models[opt.number_of_models + index])
        for index in range(opt.number_of_models)
    ])
    dp_lw_models_ema = nn.ModuleList([
        torch.nn.DataParallel(lw_models_ema[index])
        for index in range(opt.number_of_models)
    ])
    dp_kdlw_models_ema = nn.ModuleList([
        torch.nn.DataParallel(kdlw_models_ema[index])
        for index in range(opt.number_of_models)
    ])

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'bert', 'm2transformer'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(multi_models,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(multi_models.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(multi_models.parameters(), opt)
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    ##########################
    #  Build loss
    ##########################
    # triplet_loss = nn.TripletMarginLoss()

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For backward compatibility: convert the legacy 'iterators'/'split_ix'
    # bookkeeping into the loader_state_dict format this DataLoader expects.
    if 'iterators' in infos:
        infos['loader_state_dict'] = {
            split: {
                'index_list': infos['split_ix'][split],
                'iter_counter': infos['iterators'][split]
            }
            for split in [
                'paired_train', 'unpaired_images_train',
                'unpaired_captions_train', 'train', 'val', 'test'
            ]
        }
    loader.load_state_dict(infos['loader_state_dict'])
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    if opt.noamopt:
        optimizer._step = iteration
    # Flag indicating the end of an epoch.
    # Always True at the start so the learning rate etc. are initialized.
    epoch_done = True
    # Ensure the models are in training mode
    dp_lw_models.train()
    dp_kdlw_models.train()
    dp_lw_models_ema.train()
    dp_kdlw_models_ema.train()

    # Build the ensemble model from the EMA copies
    model_ensemble = AttEnsemble(multi_models_list[opt.number_of_models:2 *
                                                   opt.number_of_models],
                                 weights=None)
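    # AttEnsemble presumably averages the per-step word distributions of the
    # wrapped EMA models during decoding; it is used below to generate
    # pseudo-labels for unpaired images.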
    # model_ensemble.seq_length = 20
    model_ensemble.cuda()
    # model_ensemble.eval()
    kd_model_outs_list = []

    # Start training
    try:
        while True:
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    for index in range(opt.number_of_models):
                        multi_models[index].ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                # If start structure loss training
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                if epoch >= opt.paired_train_epoch:
                    opt.current_lambda_x = opt.hyper_parameter_lambda_x * \
                                         (epoch - (opt.paired_train_epoch - 1)) /\
                                         (opt.max_epochs - opt.paired_train_epoch)
                    opt.current_lambda_y = opt.hyper_parameter_lambda_y * \
                                           (epoch - (opt.paired_train_epoch - 1)) / \
                                           (opt.max_epochs - opt.paired_train_epoch)
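                    # lambda_x / lambda_y ramp up linearly over the epochs after
                    # paired pre-training, reaching their full hyper-parameter
                    # values in the final epoch.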

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
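            # Three-phase curriculum: language-model pre-training on captions
            # only, then paired image-caption training, then joint training that
            # also uses unpaired images and unpaired captions.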
            if epoch < opt.language_pretrain_epoch:
                data = loader.get_batch('unpaired_captions_train')
            elif epoch < opt.paired_train_epoch:
                data = loader.get_batch('paired_train')
            else:
                data = loader.get_batch('paired_train')
                unpaired_data = loader.get_batch('unpaired_images_train')
                unpaired_caption = loader.get_batch('unpaired_captions_train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()
            if epoch < opt.language_pretrain_epoch:
                tmp = [
                    data['fc_feats'] * 0, data['att_feats'] * 0,
                    data['labels'], data['masks'], data['att_masks']
                ]
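                # Visual features are zeroed during language pre-training so the
                # decoder is trained as a pure language model.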
            elif epoch < opt.paired_train_epoch:
                tmp = [
                    data['fc_feats'], data['att_feats'], data['labels'],
                    data['masks'], data['att_masks']
                ]
            else:
                tmp = [
                    data['fc_feats'], data['att_feats'], data['labels'],
                    data['masks'], data['att_masks']
                ]
                unpaired_tmp = [
                    unpaired_data['fc_feats'], unpaired_data['att_feats'],
                    unpaired_data['labels'], unpaired_data['masks'],
                    unpaired_data['att_masks']
                ]
                unpaired_caption_tmp = [
                    unpaired_caption['fc_feats'] * 0,
                    unpaired_caption['att_feats'] * 0,
                    unpaired_caption['labels'], unpaired_caption['masks'],
                    unpaired_caption['att_masks']
                ]

            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            if epoch >= opt.paired_train_epoch:
                unpaired_tmp = [
                    _ if _ is None else _.cuda() for _ in unpaired_tmp
                ]
                unpaired_fc_feats, unpaired_att_feats, unpaired_labels, unpaired_masks, unpaired_att_masks = unpaired_tmp
                unpaired_caption_tmp = [
                    _ if _ is None else _.cuda() for _ in unpaired_caption_tmp
                ]
                unpaired_caption_fc_feats, unpaired_caption_att_feats, unpaired_caption_labels, unpaired_caption_masks, unpaired_caption_att_masks = unpaired_caption_tmp
                unpaired_caption_fc_feats = unpaired_caption_fc_feats.repeat(
                    5, 1)
                unpaired_caption_fc_feats = opt.std_pseudo_visual_feature * torch.randn_like(
                    unpaired_caption_fc_feats)
                unpaired_caption_att_feats = unpaired_caption_att_feats.repeat(
                    5, 1, 1)
                unpaired_caption_fc_feats.requires_grad = True
                unpaired_caption_att_feats.requires_grad = True
                unpaired_caption_labels = unpaired_caption_labels.reshape(
                    unpaired_caption_fc_feats.shape[0], -1)
                unpaired_caption_masks = unpaired_caption_masks.reshape(
                    unpaired_caption_fc_feats.shape[0], -1)
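                # For unpaired captions the "visual" features are initialized as
                # Gaussian noise (scaled by opt.std_pseudo_visual_feature) and
                # require gradients, so they can be optimized to fit the captions
                # in the inner loop further below.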

            optimizer.zero_grad()
            if epoch < opt.language_pretrain_epoch:
                language_loss = 0
                model_outs_list = []
                for index in range(opt.number_of_models):
                    model_out = dp_lw_models[index](
                        fc_feats, att_feats, labels, masks,
                        att_masks, data['gts'],
                        torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
                    model_outs_list.append(model_out)
                    language_loss += model_out['loss'].mean()

                loss = language_loss
            elif epoch < opt.paired_train_epoch:
                language_loss = 0
                model_outs_list = []
                for index in range(opt.number_of_models):
                    model_out = dp_lw_models[index](
                        fc_feats, att_feats, labels, masks,
                        att_masks, data['gts'],
                        torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
                    model_outs_list.append(model_out)
                    language_loss += model_out['loss'].mean()

                loss = language_loss
            else:
                language_loss = 0
                model_outs_list = []
                for index in range(opt.number_of_models):
                    model_out = dp_lw_models[index](
                        fc_feats, att_feats, labels, masks,
                        att_masks, data['gts'],
                        torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
                    model_outs_list.append(model_out)
                    language_loss += model_out['loss'].mean()
                loss = language_loss

                # else:
                # for unpaired image sentences
                # # Setup the model
                # model_ensemble = AttEnsemble(multi_models_list[:opt.number_of_models], weights=None)
                # model_ensemble.seq_length = 16
                # model_ensemble.cuda()
                # model_ensemble.eval()

                model_ensemble.eval()
                eval_kwargs = dict()
                eval_kwargs.update(vars(opt))

                with torch.no_grad():
                    seq, seq_logprobs = model_ensemble(unpaired_fc_feats,
                                                       unpaired_att_feats,
                                                       unpaired_att_masks,
                                                       opt=eval_kwargs,
                                                       mode='sample')
                    # val_loss, predictions, lang_stats = eval_utils.eval_split(model_ensemble, lw_models[0].crit, loader,
                    #                                                           eval_kwargs)
                # print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in
                #                  model_ensemble.done_beams[0]]))
                # print('++' * 10)
                # for ii in range(10):
                #     sents = utils.decode_sequence(loader.get_vocab(), seq[ii].unsqueeze(0))
                #     gt_sent = utils.decode_sequence(loader.get_vocab(), labels[ii,0].unsqueeze(0))
                #     a=1

                model_ensemble.train()

                model_ensemble_sudo_labels = labels.new_zeros(
                    (opt.batch_size, opt.beam_size,
                     eval_kwargs['max_length'] + 2))
                model_ensemble_sudo_log_prob = masks.new_zeros(
                    (opt.batch_size,
                     opt.beam_size, eval_kwargs['max_length'] + 2,
                     len(loader.get_vocab()) + 1))
                model_ensemble_sum_log_prob = masks.new_zeros(
                    (opt.batch_size, opt.beam_size))

                for batch_index in range(opt.batch_size):
                    for beam_index in range(opt.beam_size):
                        # for beam_index in range(3):
                        pred = model_ensemble.done_beams[batch_index][
                            beam_index]['seq']
                        log_prob = model_ensemble.done_beams[batch_index][
                            beam_index]['logps']
                        model_ensemble_sudo_labels[batch_index, beam_index,
                                                   1:pred.shape[0] + 1] = pred
                        model_ensemble_sudo_log_prob[batch_index, beam_index,
                                                     1:pred.shape[0] +
                                                     1] = log_prob
                        model_ensemble_sum_log_prob[batch_index][
                            beam_index] = model_ensemble.done_beams[
                                batch_index][beam_index]['p']
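                # "sudo" here means pseudo: the EMA ensemble's beam-search results
                # are collected as pseudo-labels (sequences, per-token log-probs
                # and per-beam scores) for knowledge distillation.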

                # model_ensemble_prob = F.softmax(model_ensemble_sum_log_prob)

                data_ensemble_sudo_gts = list()
                for data_ensemble_sudo_gts_index in range(
                        model_ensemble_sudo_labels.shape[0]):
                    data_ensemble_sudo_gts.append(model_ensemble_sudo_labels[
                        data_ensemble_sudo_gts_index, :,
                        1:-1].data.cpu().numpy())

                # generated_sentences = list()
                # for i in range(unpaired_fc_feats.shape[0]):
                #     generated_sentences.append(
                #         [utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in
                #          model_ensemble.done_beams[i]])
                #
                # pos_tag_results = list()
                # for i in range(unpaired_fc_feats.shape[0]):
                #     generated_sentences_i = generated_sentences[i]
                #     pos_tag_results_i = []
                #     for text in generated_sentences_i:
                #         text_tokenize = nltk.word_tokenize(text)
                #         pos_tag_results_i_jbeam = []
                #         for vob, vob_type in nltk.pos_tag(text_tokenize):
                #             if vob_type == 'NN' or vob_type == 'NNS':
                #                 pos_tag_results_i_jbeam.append(vob)
                #         pos_tag_results_i.append(pos_tag_results_i_jbeam)
                #     pos_tag_results.append(pos_tag_results_i)

                # for i in range(fc_feats.shape[0]):
                #     print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in
                #                      model_ensemble.done_beams[i]]))
                #     print('--' * 10)
                # dets = data['dets']
                #
                # promising_flag = labels.new_zeros(opt.batch_size, opt.beam_size)
                # for batch_index in range(opt.batch_size):
                #     dets_batch = dets[batch_index]
                #     for beam_index in range(opt.beam_size):
                #         indicator = [0] * len(dets_batch)
                #         pos_tag_batch_beam = pos_tag_results[batch_index][beam_index]
                #         for pos_tag_val in pos_tag_batch_beam:
                #             for ii in range(len(dets_batch)):
                #                 possible_list = vob_transform_list[dets_batch[ii]]
                #                 if pos_tag_val in possible_list:
                #                     indicator[ii] = 1
                #         if sum(indicator) == len(dets_batch) or sum(indicator) >= 2:
                #             promising_flag[batch_index, beam_index] = 1
                #
                # # model_ensemble_sudo_log_prob = model_ensemble_sudo_log_prob * promising_flag.unsqueeze(-1).unsqueeze(-1)
                # model_ensemble_sudo_labels = model_ensemble_sudo_labels * promising_flag.unsqueeze(-1)

                #sudo_masks_for_model = sudo_masks_for_model.detach()
                distilling_loss = 0
                # Random-study mechanism: at each iteration one randomly chosen
                # student distills from the ensemble pseudo-labels.
                who_to_study = random.randint(0, opt.number_of_models - 1)

                # for index in range(opt.number_of_models):
                #     model_out = dp_kdlw_models[index](unpaired_fc_feats, unpaired_att_feats, model_ensemble_sudo_labels,
                #                                     model_ensemble_sudo_log_prob, att_masks, data_ensemble_sudo_gts,
                #                                     torch.arange(0, len(data_ensemble_sudo_gts)), sc_flag,
                #                                     struc_flag, model_ensemble_sum_log_prob)
                #     kd_model_outs_list.append(model_out)

                model_out = dp_kdlw_models[who_to_study](
                    unpaired_fc_feats, unpaired_att_feats,
                    model_ensemble_sudo_labels, model_ensemble_sudo_log_prob,
                    att_masks, data_ensemble_sudo_gts,
                    torch.arange(0, len(data_ensemble_sudo_gts)), sc_flag,
                    struc_flag, model_ensemble_sum_log_prob)
                # kd_model_outs_list.append(model_out)
                distilling_loss += model_out['loss'].mean()
                loss += opt.number_of_models * opt.current_lambda_x * distilling_loss

                ###################################################################
                # use unlabelled captions
                # simple_sgd = utils.gradient_descent(unpaired_caption_fc_feats, stepsize=1e3)
                simple_sgd = utils.gradient_descent_adagrad(
                    unpaired_caption_fc_feats, stepsize=1)
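                # Adagrad-style updates are applied directly to the pseudo visual
                # features: each inner iteration back-propagates the EMA models'
                # captioning loss into the features and steps them.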
                gts_tmp = unpaired_caption['gts']
                new_gts = []
                for ii in range(len(data['gts'])):
                    for jj in range(gts_tmp[ii].shape[0]):
                        new_gts.append(gts_tmp[ii][jj])
                unpaired_caption['gts'] = new_gts
                for itr in range(opt.inner_iteration):
                    unlabelled_caption_model_out = dp_lw_models_ema[
                        itr % opt.number_of_models](
                            unpaired_caption_fc_feats,
                            unpaired_caption_att_feats,
                            unpaired_caption_labels, unpaired_caption_masks,
                            unpaired_caption_att_masks,
                            unpaired_caption['gts'],
                            torch.arange(0, len(unpaired_caption['gts'])),
                            sc_flag, struc_flag)
                    unlabelled_caption_loss = unlabelled_caption_model_out[
                        'loss'].mean()
                    unlabelled_caption_loss.backward()
                    # print(unlabelled_caption_loss)
                    simple_sgd.update(unpaired_caption_fc_feats)
                    # a=1

                unpaired_caption_fc_feats.requires_grad = False
                unpaired_caption_att_feats.requires_grad = False
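                # The optimized pseudo features are then frozen and used as input
                # to train the randomly selected student on the unpaired captions.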
                unlabelled_caption_model_out = dp_lw_models[who_to_study](
                    unpaired_caption_fc_feats, unpaired_caption_att_feats,
                    unpaired_caption_labels, unpaired_caption_masks,
                    unpaired_caption_att_masks, unpaired_caption['gts'],
                    torch.arange(0, len(unpaired_caption['gts'])), sc_flag,
                    struc_flag)
                unlabelled_caption_loss = unlabelled_caption_model_out[
                    'loss'].mean()
                loss += opt.number_of_models * opt.current_lambda_y * unlabelled_caption_loss

            loss.backward()
            if opt.grad_clip_value != 0:
                getattr(torch.nn.utils, 'clip_grad_%s_' %
                        (opt.grad_clip_mode))(multi_models.parameters(),
                                              opt.grad_clip_value)
            optimizer.step()

            for order in range(opt.number_of_models):
                for param, param_ema in zip(
                        multi_models_list[order].parameters(),
                        multi_models_list[order +
                                          opt.number_of_models].parameters()):
                    param_ema.data = opt.alpha * param_ema.data + (
                        1 - opt.alpha) * param.data

            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            # if struc_flag:
            #     print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
            #         .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
            # elif not sc_flag:
            #     print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            #         .format(iteration, epoch, train_loss, end - start))
            # else:
            #     print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
            #         .format(iteration, epoch, model_out['reward'].mean(), end - start))
            if struc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss/opt.number_of_models, sum([model_outs_list[index]['lm_loss'].mean().item() for index in range(opt.number_of_models)])/opt.number_of_models,
                            sum([model_outs_list[index]['struc_loss'].mean().item() for index in range(opt.number_of_models)])/opt.number_of_models,
                            end - start))
            elif not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, language_loss.item()/opt.number_of_models, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, sum([model_outs_list[index]['reward'].mean().item() for index in range(opt.number_of_models)])/opt.number_of_models, end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                # tb_summary_writer.add_scalar('train_loss', train_loss, iteration)
                for index in range(opt.number_of_models):
                    model_id = 'model_{}'.format(index)
                    tb_summary_writer.add_scalars('language_loss', {
                        model_id:
                        model_outs_list[index]['loss'].mean().item()
                    }, iteration)
                if epoch >= opt.paired_train_epoch:
                    # for index in range(opt.number_of_models):
                    #     model_id = 'model_{}'.format(index)
                    #     kd_model_outs_val = 0 if len(kd_model_outs_list) == 0 else kd_model_outs_list[index]['loss'].mean().item()
                    #     tb_summary_writer.add_scalars('distilling_loss',
                    #                                   {model_id: kd_model_outs_val},
                    #                                   iteration)
                    tb_summary_writer.add_scalar('distilling_loss',
                                                 distilling_loss.item(),
                                                 iteration)
                    tb_summary_writer.add_scalar(
                        'unlabelled_caption_loss',
                        unlabelled_caption_loss.item(), iteration)
                    tb_summary_writer.add_scalar('hyper_parameter_lambda_x',
                                                 opt.current_lambda_x,
                                                 iteration)
                    tb_summary_writer.add_scalar('hyper_parameter_lambda_y',
                                                 opt.current_lambda_y,
                                                 iteration)
                # tb_summary_writer.add_scalar('triplet_loss', triplet_loss_val.item(), iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob',
                                             multi_models[0].ss_prob,
                                             iteration)
                if sc_flag:
                    for index in range(opt.number_of_models):
                        # tb_summary_writer.add_scalar('avg_reward', model_out['reward'].mean(), iteration)
                        model_id = 'model_{}'.format(index)
                        tb_summary_writer.add_scalars(
                            'avg_reward', {
                                model_id:
                                model_outs_list[index]['reward'].mean().item()
                            }, iteration)
                elif struc_flag:
                    # tb_summary_writer.add_scalar('lm_loss', model_out['lm_loss'].mean().item(), iteration)
                    # tb_summary_writer.add_scalar('struc_loss', model_out['struc_loss'].mean().item(), iteration)
                    # tb_summary_writer.add_scalar('reward', model_out['reward'].mean().item(), iteration)
                    # tb_summary_writer.add_scalar('reward_var', model_out['reward'].var(1).mean(), iteration)
                    for index in range(opt.number_of_models):
                        model_id = 'model_{}'.format(index)
                        tb_summary_writer.add_scalars(
                            'lm_loss', {
                                model_id:
                                model_outs_list[index]
                                ['lm_loss'].mean().item()
                            }, iteration)
                        tb_summary_writer.add_scalars(
                            'struc_loss', {
                                model_id:
                                model_outs_list[index]
                                ['struc_loss'].mean().item()
                            }, iteration)
                        tb_summary_writer.add_scalars(
                            'reward', {
                                model_id:
                                model_outs_list[index]['reward'].mean().item()
                            }, iteration)
                        tb_summary_writer.add_scalars(
                            'reward_var', {
                                model_id:
                                model_outs_list[index]['reward'].var(1).mean()
                            }, iteration)

                histories['loss_history'][
                    iteration] = train_loss if not sc_flag else sum([
                        model_outs_list[index]['reward'].mean().item()
                        for index in range(opt.number_of_models)
                    ]) / opt.number_of_models
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = multi_models[
                    0].ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()

            # Evaluate on the validation set and save the model
            if (iteration % opt.save_checkpoint_every == 0 and not opt.save_every_epoch and epoch >= opt.paired_train_epoch) or \
                (epoch_done and opt.save_every_epoch and epoch >= opt.paired_train_epoch):
                # load ensemble
                # Setup the model
                model = AttEnsemble(multi_models_list[opt.number_of_models:2 *
                                                      opt.number_of_models],
                                    weights=None)
                model.seq_length = opt.max_length
                model.cuda()
                model.eval()
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                # eval_kwargs['beam_size'] = 5
                # eval_kwargs['verbose_beam'] = 1
                # eval_kwargs['verbose_loss'] = 1
                # val_loss, predictions, lang_stats = eval_utils.eval_split(
                #     dp_model, lw_model.crit, loader, eval_kwargs)
                with torch.no_grad():
                    val_loss, predictions, lang_stats = eval_utils.eval_split(
                        model, lw_models[0].crit, loader, eval_kwargs)
                model.train()

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss,
                                             iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)
                histories['val_result_history'][iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save the model if it improves on the validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, multi_models, infos, optimizer,
                                      histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(
                        opt,
                        multi_models,
                        infos,
                        optimizer,
                        append=str(epoch)
                        if opt.save_every_epoch else str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt,
                                          multi_models,
                                          infos,
                                          optimizer,
                                          append='best')

            # if epoch_done and epoch == opt.paired_train_epoch:
            #     utils.save_checkpoint(opt, multi_models, infos, optimizer, histories)
            #     if opt.save_history_ckpt:
            #         utils.save_checkpoint(opt, multi_models, infos, optimizer,
            #                               append=str(epoch) if opt.save_every_epoch else str(iteration))
            #     cmd = 'cp -r ' + 'log_' + opt.id + ' ' + 'log_' + opt.id + '_backup'
            #     os.system(cmd)

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, multi_models, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
def train(opt):

    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if any) and check that the models are compatible
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme
                ), "Command line argument and saved model disagree on '%s' " % checkme
    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'),
                  'rb') as f:
            histories.update(utils.pickle_load(f))
    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    # Load pretrained weights:
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'model.pth')):
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Wrap the generation model with the loss function (used for training).
    # This allows the loss to be computed separately on each machine.
    lw_model = LossWrapper(model, opt)
    # Wrap with dataparallel
    # dp_model = torch.nn.DataParallel(model)
    # dp_lw_model = torch.nn.DataParallel(lw_model)
    dp_model = model
    dp_lw_model = lw_model

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For backward compatibility
    if 'iterators' in infos:
        infos['loader_state_dict'] = {
            split: {
                'index_list': infos['split_ix'][split],
                'iter_counter': infos['iterators'][split]
            }
            for split in ['train', 'val', 'test']
        }
    loader.load_state_dict(infos['loader_state_dict'])
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    if opt.noamopt:
        optimizer._step = iteration
    # Flag indicating the end of an epoch.
    # Always True at the start so the lr, ss_prob, etc. get initialized.
    epoch_done = True
    # Make sure the model is in training mode
    dp_lw_model.train()

    # Start training
    try:
        while True:
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # Switch to self-critical training once the scheduled epoch is reached
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                # Switch to structure-loss training once the scheduled epoch is reached
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')

            torch.cuda.synchronize()
            start = time.time()

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks'], data['topics']
            ]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks, topic_vecs = tmp

            # gts_end = data['gts_end']
            # import ipdb; ipdb.set_trace()

            train_loss = 0.0
            #### Model forward pass
            #### Liangming: Add for loop
            sent_num = opt.seq_per_img
            labels = labels.view(opt.batch_size, sent_num, -1)
            masks = masks.view(opt.batch_size, sent_num, -1)
            # topic_vecs = topic_vecs.view(opt.batch_size, sent_num, -1)

            # Initialize the topic vec; the shape is [batch_size 10, max_seq_len 31, hidden_size 512]
            topic_vec = torch.zeros((opt.batch_size, labels.shape[2] - 1,
                                     opt.rnn_size)).float().cuda()

            total_loss = 0.0
            for sent_n in range(sent_num):
                # prepare sentence data
                sent_label = labels[:, sent_n, :]
                sent_mask = masks[:, sent_n, :]

                # Skip sentence positions where the labels are all zeros for every example in the batch.
                # This typically happens near the end of a paragraph.
                if torch.sum(sent_label).item() == 0:
                    continue

                # model forward pass
                optimizer.zero_grad()
                model_out = dp_lw_model(fc_feats, att_feats, sent_label,
                                        sent_mask, att_masks, topic_vec,
                                        data['gts'],
                                        torch.arange(0, len(data['gts'])),
                                        sc_flag, struc_flag)
                # loss calculation

                loss = model_out['loss'].mean()
                total_loss += loss
                decoder_output = model_out['decoder_output']

                # Cannot backward here, this will cause multiple backwards...
                # Specify retain_graph=True, if you still want to backward here
                #loss.backward()
                #getattr(torch.nn.utils, 'clip_grad_%s_' %(opt.grad_clip_mode))(model.parameters(), opt.grad_clip_value)
                #optimizer.step()
                #train_loss = loss.item()

                # PLM: treat decoder output as "topic vec"
                # topic_vec = decoder_output

                # Optional: Shrink the size of it based on the mask?
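                # Keep only as many decoder time steps as the longest real (unpadded)
                # sentence in the batch; the trimmed output becomes the topic vector
                # fed into the next sentence's forward pass.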
                max_sent_len = -1
                for row in range(sent_mask.shape[0]):
                    sent_len = int(sum(sent_mask[row, :]).data.item())
                    if sent_len > max_sent_len:
                        max_sent_len = sent_len
                topic_vec = decoder_output[:, 0:max_sent_len, :]

            avg_loss = total_loss / sent_num
            avg_loss.backward()
            #loss.backward()
            getattr(torch.nn.utils, 'clip_grad_%s_' % (opt.grad_clip_mode))(
                model.parameters(), opt.grad_clip_value)
            optimizer.step()
            train_loss = avg_loss.item()  # report the loss that was actually optimized

            torch.cuda.synchronize()
            end = time.time()
            if iteration % opt.print_freq == 1:
                print('Read data:', time.time() - start)
                if struc_flag:
                    print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                        .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
                elif not sc_flag:
                    print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                        .format(iteration, epoch, train_loss, end - start))
                else:
                    print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                        .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                tb_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.ss_prob, iteration)
                if sc_flag:
                    tb_summary_writer.add_scalar('avg_reward',
                                                 model_out['reward'].mean(),
                                                 iteration)
                elif struc_flag:
                    tb_summary_writer.add_scalar(
                        'lm_loss', model_out['lm_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'struc_loss', model_out['struc_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'reward', model_out['reward'].mean().item(), iteration)

                histories['loss_history'][
                    iteration] = train_loss if not sc_flag else model_out[
                        'reward'].mean()
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0 and not opt.save_every_epoch) or \
                (epoch_done and opt.save_every_epoch):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': 'val'}
                eval_kwargs.update(vars(opt))

                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss,
                                             iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)
                histories['val_result_history'][iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save the model if it is improving on the validation metric
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(
                        opt,
                        model,
                        infos,
                        optimizer,
                        append=str(epoch)
                        if opt.save_every_epoch else str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          append='best')

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
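
All of these training loops apply the same step decay to the learning rate. As a point of reference, that schedule can be factored into a small standalone helper; the sketch below is not code from any of the listed repositories, and the parameter names simply mirror the opt fields used above.

def step_decayed_lr(epoch, base_lr, decay_start, decay_every, decay_rate):
    """Step decay used above: base_lr * decay_rate ** ((epoch - decay_start) // decay_every)."""
    if decay_start < 0 or epoch <= decay_start:
        return base_lr
    frac = (epoch - decay_start) // decay_every
    return base_lr * (decay_rate ** frac)

# Example: base_lr=5e-4, decay_start=0, decay_every=3, decay_rate=0.8
# epochs 1-2 -> 5e-4, epochs 3-5 -> 4e-4, epochs 6-8 -> 3.2e-4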
Example #27
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    ac = 0

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(
                os.path.join(
                    opt.checkpoint_path, 'infos_' + opt.id +
                    format(int(opt.start_from), '04') + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(
                    opt.checkpoint_path, 'histories_' + opt.id +
                    format(int(opt.start_from), '04') + '.pkl')):
            with open(
                    os.path.join(
                        opt.checkpoint_path, 'histories_' + opt.id +
                        format(int(opt.start_from), '04') + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()

    #dp_model = torch.nn.DataParallel(model)
    #dp_model = torch.nn.DataParallel(model, [0,2,3])
    dp_model = model

    update_lr_flag = True
    # Make sure the model is in training mode
    dp_model.train()

    for name, param in model.named_parameters():
        print(name)

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    CE_ac = utils.CE_ac()

    optim_para = model.parameters()
    optimizer = utils.build_optimizer(optim_para, opt)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(
                opt.checkpoint_path, 'optimizer' + opt.id +
                format(int(opt.start_from), '04') + '.pth')):
        optimizer.load_state_dict(
            torch.load(
                os.path.join(
                    opt.checkpoint_path, 'optimizer' + opt.id +
                    format(int(opt.start_from), '04') + '.pth')))

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    reward = np.zeros([1, 1])
    sim_lambda = opt.sim_lambda

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # Switch to self-critical training once the scheduled epoch is reached
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['labels'], data['masks'], data['mods']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        labels, masks, mods = tmp

        tmp = [
            data['att_feats'], data['att_masks'], data['attr_feats'],
            data['attr_masks'], data['rela_feats'], data['rela_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        att_feats, att_masks, attr_feats, attr_masks, rela_feats, rela_masks = tmp

        rs_data = {}
        rs_data['att_feats'] = att_feats
        rs_data['att_masks'] = att_masks
        rs_data['attr_feats'] = attr_feats
        rs_data['attr_masks'] = attr_masks
        rs_data['rela_feats'] = rela_feats
        rs_data['rela_masks'] = rela_masks

        if not sc_flag:
            logits, cw_logits = dp_model(rs_data, labels)
            ac = CE_ac(logits, labels[:, 1:], masks[:, 1:])
            print('ac :{0}'.format(ac))
            loss_lan = crit(logits, labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs, cw_logits = dp_model(
                rs_data, opt={'sample_max': 0}, mode='sample')
            reward = get_self_critical_reward(dp_model, rs_data, data,
                                              gen_result, opt)
            loss_lan = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda())

        loss_cw = crit(cw_logits, mods[:, 1:], masks[:, 1:])
        ac2 = CE_ac(cw_logits, mods[:, 1:], masks[:, 1:])
        print('ac :{0}'.format(ac2))
        if epoch < opt.step2_train_after:
            loss = loss_lan + sim_lambda * loss_cw
        else:
            loss = loss_lan

        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()
        if accumulate_iter % opt.accumulate_number == 0:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            accumulate_iter = 0
            train_loss = loss.item() * opt.accumulate_number
            train_loss_lan = loss_lan.item()
            train_loss_cw = loss_cw.item()
            end = time.time()

            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, end - start))
                print("train_loss_lan = {:.3f}, train_loss_cw = {:.3f}" \
                      .format(train_loss_lan, train_loss_cw))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), end - start))
                print("train_loss_lan = {:.3f}, train_loss_cw = {:.3f}" \
                      .format(train_loss_lan, train_loss_cw))
            print('lr:{0}'.format(opt.current_lr))

        torch.cuda.synchronize()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every
                == 0) and (accumulate_iter % opt.accumulate_number == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'train_loss_lan',
                              train_loss_lan, iteration)
            add_summary_value(tb_summary_writer, 'train_loss_cw',
                              train_loss_cw, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            add_summary_value(tb_summary_writer, 'ac', ac, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every
                == 0) and (accumulate_iter % opt.accumulate_number == 0):
            # eval model
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            #val_loss, predictions, lang_stats = eval_utils_rs3.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            # add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            # if lang_stats is not None:
            #     for k,v in lang_stats.items():
            #         add_summary_value(tb_summary_writer, k, v, iteration)
            # val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            # if opt.language_eval == 1:
            #     current_score = lang_stats['CIDEr']
            # else:
            #     current_score = - val_loss
            current_score = 0

            best_flag = False
            if True:  # evaluation above is disabled, so always save a checkpoint
                save_id = iteration / opt.save_checkpoint_every
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'model' + opt.id + format(int(save_id), '04') + '.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(
                    opt.checkpoint_path,
                    'optimizer' + opt.id + format(int(save_id), '04') + '.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(
                            opt.checkpoint_path, 'infos_' + opt.id +
                            format(int(save_id), '04') + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(
                            opt.checkpoint_path, 'histories_' + opt.id +
                            format(int(save_id), '04') + '.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
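
This training loop additionally accumulates gradients: every loss is divided by opt.accumulate_number and the optimizer only steps once that many micro-batches have been backpropagated. The sketch below shows the pattern in isolation; it is only an illustration under that assumption, and torch.nn.utils.clip_grad_value_ stands in for the repository's utils.clip_gradient.

import torch

def accumulated_update(model, optimizer, micro_losses, accumulate_number, grad_clip):
    """Scale each micro-batch loss by 1/accumulate_number so the summed gradients
    approximate one large-batch update, then clip and take a single optimizer step."""
    optimizer.zero_grad()
    for loss in micro_losses:  # one scalar loss tensor per micro-batch
        (loss / accumulate_number).backward()
    # stand-in for utils.clip_gradient(optimizer, grad_clip)
    torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip)
    optimizer.step()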
Example #28
def train(opt):
    iteration = 0
    epoch = 0
    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}

    # Create model
    model = convcap(opt).cuda()
    pretrained_dict = torch.load(opt.model)
    model.load_state_dict(pretrained_dict, strict=False)
    start = time.time()
    dp_model = torch.nn.DataParallel(model)
    dp_model.train()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    update_lr_flag = True
    samplenet = sampleNet(dp_model, opt)
    while True:
        # Unpack data
        #torch.cuda.synchronize()
        data = loader.get_batch('train')
        data_time = time.time() - start
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['dist'],
            data['masks'], data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, att_masks = tmp
        batchsize = fc_feats.size(0)
        # Forward pass and loss
        optimizer.zero_grad()
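        # 180 = 6 sentences x 30 tokens per sentence (matches the 30 * 6 used below)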
        labels_decode = labels.view(-1, 180)
        captions = utils.decode_sequence(loader.get_vocab(), labels_decode,
                                         None)
        captions_all = []
        for index, caption in enumerate(captions):
            caption = caption.replace('<start>',
                                      '').replace(' ,', '').replace('  ', ' ')
            captions_all.append(caption)
        #print (captions_all[0])
        #with torch.no_grad():
        target, outcap, sample_right = samplenet(batchsize, 30 * 6,
                                                 loader.get_vocab(), att_feats)
        #wordclass_feed = wordclass_feed.reshape((batchsize, 6, 30))
        #out, _ = dp_model(fc_feats, att_feats, torch.tensor(wordclass_feed))
        #Logprobs = torch.log(F.softmax(out.transpose(2,1)))
        #target = target.view((batchsize, (30*6), -1))
        #sampleLogprobs = Logprobs.gather(2, target.long().unsqueeze(2)) # gather t
        #print (sampleLogprobs.size(), sample_right.size())
        #print (sampleLogprobs.squeeze()[:, :], sample_right[:, :])
        with torch.no_grad():
            reward, cider_sample, cider_greedy = get_self_critical_reward(
                batchsize, dp_model, att_feats, outcap, captions_all,
                loader.get_vocab(), 30 * 6)
        loss_rl = rl_crit(sample_right, target.data,
                          torch.from_numpy(reward).float())
        wordact, x_all = dp_model(fc_feats, att_feats, labels, 30, 6)
        mask = masks[:, 1:].contiguous()
        wordact = wordact[:, :, :-1]
        wordact_t = wordact.permute(0, 2, 1).contiguous()
        wordact_t = wordact_t.view(wordact_t.size(0) * wordact_t.size(1), -1)
        labels = labels.contiguous().view(-1, 6 * 30).cpu()
        wordclass_v = labels[:, 1:]
        wordclass_t = wordclass_v.contiguous().view(\
           wordclass_v.size(0) * wordclass_v.size(1), 1)
        maskids = torch.nonzero(mask.view(-1).cpu()).numpy().reshape(-1)
        loss_xe = F.cross_entropy(wordact_t[maskids, ...], \
           wordclass_t[maskids, ...].contiguous().view(maskids.shape[0]))
        loss_xe_all = loss_rl  #+ F.mse_loss(x_all_inference.cuda(), x_all.cuda()).cuda()
        loss_xe_all.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss_xe_all.item()
        torch.cuda.synchronize()
        # Print
        total_time = time.time() - start
        reward = reward[:, 0].mean()
        cider_sample = cider_sample.mean()
        cider_greedy = cider_greedy.mean()
        if 1:
            if iteration % 2 == 1:
                print('Read data:', time.time() - start)
                if 0:
                    print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                        .format(iteration, epoch, train_loss, data_time, total_time))
                if 1:
                    print("iter {} (epoch {}), train_loss = {:.3f}, avg_reward = {:.3f},cider_sample  = {:.3f}, cider_greedy ={:.3f},  data_time = {:.3f}, time/batch = {:.3f}" \
                        .format(iteration, epoch, train_loss, reward, cider_sample, cider_greedy, data_time, total_time))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                update_lr_flag = True

            # Write the training loss summary
            '''
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            #add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0])
            lr_history[iteration] = opt.current_lr
            #ss_prob_history[iteration] = model.ss_prob
        '''
            # Validate and save model
            if (iteration % opt.save_checkpoint_every == 0):
                checkpoint_path = os.path.join(
                    opt.checkpoint_path, 'model' + str(iteration) + '.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)
                # Evaluate model
                '''
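
get_self_critical_reward is imported from a helper module and is not shown in any of these listings. Conceptually it scores each sampled caption and a greedy-decoded baseline with CIDEr and uses their difference as the reward for every token of the sample. The sketch below is only that idea in outline; cider_score is a hypothetical callable standing in for the pycocoevalcap CIDEr-D scorer, and the real helper differs in its details.

import numpy as np

def self_critical_reward(sampled_caps, greedy_caps, refs_per_image, cider_score, seq_len):
    """Sketch of the self-critical baseline: reward = CIDEr(sample) - CIDEr(greedy)."""
    rewards = np.zeros((len(sampled_caps), seq_len), dtype=np.float32)
    for i, (sample, greedy, refs) in enumerate(zip(sampled_caps, greedy_caps, refs_per_image)):
        advantage = cider_score(sample, refs) - cider_score(greedy, refs)
        rewards[i, :] = advantage  # the same scalar reward is broadcast to every token
    return rewards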
Example #29
def train(opt):

    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create model
    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model, device_ids=[0])
    dp_model.train()

    # Loss function
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    # Optimizer and learning rate adjustment flag
    optimizer = utils.build_optimizer(model.parameters(), opt)
    update_lr_flag = True

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # Training loop
    while True:

        # Update learning rate once per epoch
        if update_lr_flag:

            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # Switch to self-critical training once the scheduled epoch is reached
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        # Forward pass and loss
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        # Backward pass
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Print
        total_time = time.time() - start
        if iteration % opt.print_freq == 1:
            print('Read data:', data_time)
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, data_time, total_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Validate and save model
        if (iteration % opt.save_checkpoint_every == 0):

            # Evaluate model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format(
                    iteration, best_val_score)
                infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname)
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, infos_fname),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
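
utils.RewardCriterion, used in the self-critical branch above, is likewise imported rather than shown. A typical implementation is a masked REINFORCE loss: the negative log-probability of each sampled token, weighted by its reward and masked once the sequence has ended. The class below is a plausible stand-in written under that assumption, not the repository's exact code.

import torch
import torch.nn as nn

class RewardCriterionSketch(nn.Module):
    """Hypothetical stand-in for utils.RewardCriterion (policy-gradient loss)."""

    def forward(self, sample_logprobs, gen_result, reward):
        # sample_logprobs: (batch, seq_len) log-probs of the sampled tokens
        # gen_result:      (batch, seq_len) sampled token ids, where 0 marks end/padding
        # reward:          (batch, seq_len) per-token advantage
        mask = (gen_result > 0).float()
        # keep the first token unconditionally, then every token whose
        # predecessor was still inside the sequence
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], dim=1)
        loss = -sample_logprobs * reward * mask
        return loss.sum() / mask.sum()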
Example #30
def train(opt):
    logger = initialize_logger(os.path.join(opt.checkpoint_path, 'train.log'))
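    # Route all subsequent print() calls in this function through the logger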
    print = logger.info

    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Print out the option variables
    print("*" * 20)
    for k, v in opt.__dict__.items():
        print("%r: %r" % (k, v))
    print("*" * 20)

    infos = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos.json'), 'r') as f:
            infos = json.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None

    model = models.setup(opt).to(device)
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Make sure the model is in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    start_time = time.time()
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if 0 <= opt.learning_rate_decay_start < epoch:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if 0 <= opt.scheduled_sampling_start < epoch:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # Switch to self-critical training once the scheduled epoch is reached
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer()
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        batch_data = loader.get_batch('train')
        torch.cuda.synchronize(device)

        tmp = [
            batch_data['fc_feats'], batch_data['att_feats'],
            batch_data['labels'], batch_data['masks'], batch_data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).to(device) for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            outputs = dp_model(fc_feats, att_feats, labels, att_masks)
            loss = crit(outputs, labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, batch_data,
                                              gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().to(device))

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data
        torch.cuda.synchronize(device)

        # Update the iteration and epoch
        iteration += 1
        if batch_data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Print train loss or avg reward
        if iteration % opt.losses_print_every == 0:
            if not sc_flag:
                print(
                    "iter {} (epoch {}), loss = {:.3f}, time = {:.3f}".format(
                        iteration, epoch, loss.item(),
                        time.time() - start_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time = {:.3f}".
                      format(iteration, epoch, np.mean(reward[:, 0]),
                             time.time() - start_time))
            start_time = time.time()

        # make evaluation on validation set, and save model
        if (opt.save_checkpoint_every > 0 and iteration % opt.save_checkpoint_every == 0)\
                or (opt.save_checkpoint_every <= 0 and update_lr_flag):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.simple_eval_split(
                dp_model, loader, eval_kwargs)

            # Save the model if it is improving on the validation metric
            if not os.path.exists(opt.checkpoint_path):
                os.makedirs(opt.checkpoint_path)

            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False

            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = vars(opt)
            infos['vocab'] = loader.get_vocab()

            with open(os.path.join(opt.checkpoint_path, 'infos.json'),
                      'w') as f:
                json.dump(infos, f, sort_keys=True, indent=4)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, 'infos-best.json'),
                          'w') as f:
                    json.dump(infos, f, sort_keys=True, indent=4)

            # Stop if reaching max epochs
            if opt.max_epochs != -1 and epoch >= opt.max_epochs:
                break
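
Every loop above updates the scheduled-sampling probability with the same rule: once opt.scheduled_sampling_start is passed, the probability grows by scheduled_sampling_increase_prob every scheduled_sampling_increase_every epochs and is capped at scheduled_sampling_max_prob. A minimal sketch of that schedule follows; the parameter names mirror those opt fields, and the pre-schedule probability is assumed to be the usual default of 0.

def scheduled_sampling_prob(epoch, start, increase_every, increase_prob, max_prob):
    """Probability of feeding the model its own previous token instead of the ground truth."""
    if start < 0 or epoch <= start:
        return 0.0
    frac = (epoch - start) // increase_every
    return min(increase_prob * frac, max_prob)

# Example: start=0, increase_every=5, increase_prob=0.05, max_prob=0.25
# epoch 5 -> 0.05, epoch 10 -> 0.10, ..., epoch 25 and beyond -> capped at 0.25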