예제 #1
0
def train(opt):
    """Run the captioning training loop configured by ``opt``.

    The function optionally resumes bookkeeping state (``infos`` /
    ``histories`` pickles and the optimizer state) from ``opt.start_from``,
    builds the model and optimizer, and then loops over training batches
    until ``iteration >= opt.max_iter``.  Each iteration uses either the
    cross-entropy criterion or, once ``epoch >= opt.self_critical_after``,
    the self-critical reward criterion.  A running average of the model
    parameters is maintained and evaluated separately from the raw model.

    Args:
        opt: Option namespace.  It is read extensively and also mutated
            (``use_att``, ``att_feat_size``, ``vocab_size``, ``seq_length``,
            ``learning_rate_decay_start``, ``current_lr``, ``ss_prob``).

    Returns:
        None.  Checkpointing/evaluation is delegated to
        ``eva_original_model`` and ``eva_ave_model``.

    Raises:
        AssertionError: if ``opt.annfile`` is empty, or if the saved options
            disagree with ``opt`` on one of the architecture settings.
    """
    assert opt.annfile is not None and len(opt.annfile) > 0

    print('Checkpoint path is ' + opt.checkpoint_path)
    # CUDA_VISIBLE_DEVICES may be unset; a log line must not abort training.
    print('This program is using GPU ' +
          str(os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>')))
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Evaluates to a falsy value (no writer) when the tensorboard module
    # `tb` itself is falsy.
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        if opt.load_best:
            info_path = os.path.join(opt.start_from,
                                     'infos_' + opt.id + '-best.pkl')
        else:
            info_path = os.path.join(opt.start_from,
                                     'infos_' + opt.id + '.pkl')
        # NOTE(security): unpickling can execute arbitrary code; only resume
        # from checkpoints you trust.
        # Pickle data is binary, so the file must be opened with 'rb'.
        with open(info_path, 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same = [
            "caption_model", "rnn_type", "rnn_size", "num_layers"
        ]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(
                opt
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    if opt.learning_rate_decay_start is None:
        # Inherit the value from the checkpoint.  Without a checkpoint there
        # is nothing to inherit, so fall back to -1, which the schedule
        # below treats as "decay disabled" (it requires a value >= 0).
        opt.learning_rate_decay_start = getattr(
            infos.get('opt', None), 'learning_rate_decay_start', -1)
    elif opt.learning_rate_decay_start == -1 and opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
        opt.learning_rate_decay_start = epoch

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    # loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind the best scores: they are read unconditionally further
    # down, even when the saved scores are not requested.
    best_val_score = None
    best_val_score_ave_model = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_ave_model = infos.get('best_val_score_ave_model', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion(opt.XE_eps)
    rl_crit = utils.RewardCriterion()

    # build_optimizer
    optimizer = build_optimizer(model, opt)

    # Load the optimizer (skipped when resuming from the "best" snapshot or
    # when no optimizer state file exists).
    if opt.load_opti and vars(opt).get(
            'start_from',
            None) is not None and opt.load_best == 0 and os.path.isfile(
                os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # initialize the running average of parameters
    avg_param = deepcopy(list(p.data for p in model.parameters()))

    # make evaluation using original model
    best_val_score, histories, infos = eva_original_model(
        best_val_score, crit, epoch, histories, infos, iteration, loader,
        loss_history, lr_history, model, opt, optimizer, ss_prob_history,
        tb_summary_writer, val_result_history)

    while True:
        # Schedules are recomputed once at start-up and then once per epoch.
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                if opt.lr_decay == 'exp':
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                elif opt.lr_decay == 'cosine':
                    lr_epoch = min((epoch - opt.learning_rate_decay_start),
                                   opt.lr_max_epoch)
                    cosine_decay = 0.5 * (
                        1 + math.cos(math.pi * lr_epoch / opt.lr_max_epoch))
                    decay_factor = (1 - opt.lr_cosine_decay_base
                                    ) * cosine_decay + opt.lr_cosine_decay_base
                    opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate

            lr = [opt.current_lr]
            # A second, scaled learning rate is passed for normalize
            # methods whose name contains '6'.
            if opt.att_normalize_method is not None and '6' in opt.att_normalize_method:
                lr = [opt.current_lr, opt.lr_ratio * opt.current_lr]

            utils.set_lr(optimizer, lr)
            print('learning rate is: ' + str(lr))

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Update the iteration
        iteration += 1

        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)

        # Synchronize so the timer measures only this batch's GPU work.
        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        # Entries that are None are passed through unchanged.
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            output = dp_model(fc_feats, att_feats, labels, att_masks)
            # calculate loss
            loss = crit(output[0], labels[:, 1:], masks[:, 1:])

            # add some middle variable histogram
            if iteration % (4 * opt.losses_log_every) == 0:
                outputs = [
                    _.data.cpu().numpy() if _ is not None else None
                    for _ in output
                ]
                variables_histogram(data, iteration, outputs,
                                    tb_summary_writer, opt)

        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        # grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_max_norm)
        # add_summary_value(tb_summary_writer, 'grad_L2_norm', grad_norm, iteration)

        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()

        # compute the running average of parameters:
        # avg <- beta * avg + (1 - beta) * p  (legacy add_(scalar, tensor) form)
        for p, avg_p in zip(model.parameters(), avg_param):
            avg_p.mul_(opt.beta).add_((1.0 - opt.beta), p.data)

        if iteration % 10 == 0:
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            # In self-critical mode the "loss" history stores the mean reward.
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        if opt.tensorboard_weights_grads and (iteration %
                                              (8 * opt.losses_log_every) == 0):
            # add weights histogram to tensorboard summary
            for name, param in model.named_parameters():
                if (opt.tensorboard_parameters_name is None or sum([
                        p_name in name
                        for p_name in opt.tensorboard_parameters_name
                ]) > 0) and param.grad is not None:
                    tb_summary_writer.add_histogram(
                        'Weights_' + name.replace('.', '/'), param, iteration)
                    tb_summary_writer.add_histogram(
                        'Grads_' + name.replace('.', '/'), param.grad,
                        iteration)

        if opt.tensorboard_buffers and (iteration %
                                        (opt.losses_log_every) == 0):
            for name, buffer in model.named_buffers():
                if (opt.tensorboard_buffers_name is None or sum([
                        p_name in name
                        for p_name in opt.tensorboard_buffers_name
                ]) > 0) and buffer is not None:
                    add_summary_value(tb_summary_writer,
                                      name.replace('.',
                                                   '/'), buffer, iteration)

        if opt.distance_sensitive_coefficient and iteration % (
                4 * opt.losses_log_every) == 0:
            print('The coefficient in intra_att_att_lstm is as follows:')
            print(
                model.core.intra_att_att_lstm.coefficient.data.cpu().tolist())
            print('The coefficient in intra_att_lang_lstm is as follows:')
            print(
                model.core.intra_att_lang_lstm.coefficient.data.cpu().tolist())
        if opt.distance_sensitive_bias and iteration % (
                4 * opt.losses_log_every) == 0:
            print('The bias in intra_att_att_lstm is as follows:')
            print(model.core.intra_att_att_lstm.bias.data.cpu().tolist())
            print('The bias in intra_att_lang_lstm is as follows:')
            print(model.core.intra_att_lang_lstm.bias.data.cpu().tolist())

        # make evaluation using original model
        if (iteration % opt.save_checkpoint_every == 0):
            best_val_score, histories, infos = eva_original_model(
                best_val_score, crit, epoch, histories, infos, iteration,
                loader, loss_history, lr_history, model, opt, optimizer,
                ss_prob_history, tb_summary_writer, val_result_history)

        # make evaluation with the averaged parameters model
        if iteration > opt.ave_threshold and (iteration %
                                              opt.save_checkpoint_every == 0):
            best_val_score_ave_model, infos = eva_ave_model(
                avg_param, best_val_score_ave_model, crit, infos, iteration,
                loader, model, opt, tb_summary_writer)

        # # Stop if reaching max epochs
        # if epoch >= opt.max_epochs and opt.max_epochs != -1:
        #     break

        if iteration >= opt.max_iter:
            break
예제 #2
0
def train(opt):
    """Train a captioning model with cross-entropy loss as configured by ``opt``.

    Optionally resumes ``infos`` / ``histories`` and the optimizer state from
    ``opt.start_from``, then loops over training batches.  Every
    ``opt.save_checkpoint_every`` iterations the model is evaluated on the
    ``'val'`` split and the model, optimizer, ``infos`` and ``histories`` are
    written to ``opt.checkpoint_path``; an additional ``-best`` copy is
    written whenever the validation score improves.  The loop ends once
    ``epoch >= opt.max_epochs`` (unless ``opt.max_epochs == -1``).

    Args:
        opt: Option namespace.  It is read extensively and also mutated
            (``use_att``, ``vocab_size``, ``seq_length``, ``current_lr``,
            ``ss_prob``).

    Returns:
        None.

    Raises:
        AssertionError: if the saved options disagree with ``opt`` on one of
            the architecture settings.
    """
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Evaluates to a falsy value (no writer) when `tf` itself is falsy.
    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        # NOTE(security): unpickling can execute arbitrary code; only resume
        # from checkpoints you trust.
        # Pickle data is binary, so the file must be opened with 'rb'.
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind the best score: the checkpointing code below reads it even
    # when the saved score is not requested.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    # Load the optimizer, but only if a saved state actually exists.
    if vars(opt).get('start_from', None) is not None:
        saved_optimizer_path = os.path.join(opt.start_from, 'optimizer.pth')
        if os.path.isfile(saved_optimizer_path):
            optimizer.load_state_dict(torch.load(saved_optimizer_path))

    while True:
        # Schedules are recomputed once at start-up and then once per epoch.
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        # Synchronize so the timer measures only this batch's GPU work.
        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                for k,v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                # Higher is better, so use the negated loss as the score.
                current_score = - val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # The latest model/optimizer are always saved; the "-best"
            # copies below are written only on improvement.
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #3
0
def train(opt):
    if vars(opt).get('start_from_en', None) is not None:
        opt.checkpoint_path_p = opt.start_from_en
        opt.id_p = opt.checkpoint_path_p.split('/')[-1]
        print('Point to folder: {}'.format(opt.checkpoint_path_p))
    else:
        opt.id_p = datetime.datetime.now().strftime(
            '%Y%m%d_%H%M%S') + '_' + opt.caption_model
        opt.checkpoint_path_p = os.path.join(opt.checkpoint_path_p, opt.id_p)

        if not os.path.exists(opt.checkpoint_path_p):
            os.makedirs(opt.checkpoint_path_p)
        print('Create folder: {}'.format(opt.checkpoint_path_p))

    # # Deal with feature things before anything
    # opt.use_att = utils.if_use_att(opt.caption_model)
    # if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    # loader = DataLoader_UP(opt)
    # opt.vocab_size = loader.vocab_size
    # if opt.use_rela == 1:
    #     opt.rela_dict_size = loader.rela_dict_size
    # opt.seq_length = loader.seq_length
    # use_rela = getattr(opt, 'use_rela', 0)

    try:
        tb_summary_writer = tf and tf.compat.v1.summary.FileWriter(
            opt.checkpoint_path_p)
    except:
        print('Set tensorboard error!')
        pdb.set_trace()

    infos = {}
    histories = {}
    if opt.start_from_en is not None or opt.use_pretrained_setting == 1:
        # open old infos and check if models are compatible
        # with open(os.path.join(opt.checkpoint_path_p, 'infos.pkl')) as f:
        #     infos = cPickle.load(f)
        #     saved_model_opt = infos['opt']
        #     need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        #     for checkme in need_be_same:
        #         assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        #
        #     # override and collect parameters
        #     if len(opt.input_fc_dir) == 0:
        #         opt.input_fc_dir = infos['opt'].input_fc_dir
        #         opt.input_att_dir = infos['opt'].input_att_dir
        #         opt.input_box_dir = infos['opt'].input_box_dir
        #         # opt.input_label_h5 = infos['opt'].input_label_h5
        #     if len(opt.input_json) == 0:
        #         opt.input_json = infos['opt'].input_json
        #     if opt.batch_size == 0:
        #         opt.batch_size = infos['opt'].batch_size
        #     if len(opt.id) == 0:
        #         opt.id = infos['opt'].id
        #         # opt.id = infos['opt'].id_p
        #
        #     ignore = ['checkpoint_path', "use_gfc", "use_isg", "ssg_dict_path", "input_json", "input_label_h5", "id",
        #               "batch_size", "start_from", "language_eval", "use_rela", "input_ssg_dir", "ssg_dict_path",
        #               "input_rela_dir", "use_spectral_norm", "beam_size", 'gpu', 'caption_model','use_att','max_epochs']
        #     beam_size = opt.beam_size
        #
        #     vocab = infos['vocab']  # ix -> word mapping
        #     opt.vocab = vocab
        #     opt.vocab_size = len(vocab)
        #     for k in vars(infos['opt']).keys():
        #         if k != 'model':
        #             if k not in ignore:
        #                 if k in vars(opt):
        #                     # assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent'
        #                     vars(opt).update({k: vars(infos['opt'])[k]})
        #                     print (vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent, will be copyed from pretrained model')
        #                 else:
        #                     vars(opt).update({k: vars(infos['opt'])[k]})  # copy over options from model
        #     opt.input_fc_dir = 'data/cocobu_fc'
        #     opt.p_flag = 0

        # Load infos
        # opt.infos_path=os.path.join(opt.checkpoint_path_p, 'infos.pkl')
        opt.infos_path = os.path.join('data/fc/infos.pkl')
        with open(opt.infos_path) as f:
            infos = cPickle.load(f)

        # override and collect parameters
        if len(opt.input_fc_dir) == 0:
            opt.input_fc_dir = infos['opt'].input_fc_dir
            opt.input_att_dir = infos['opt'].input_att_dir
            opt.input_box_dir = infos['opt'].input_box_dir
            # opt.input_label_h5 = infos['opt'].input_label_h5
        if len(opt.input_json) == 0:
            opt.input_json = infos['opt'].input_json
        if opt.batch_size == 0:
            opt.batch_size = infos['opt'].batch_size
        if len(opt.id) == 0:
            opt.id = infos['opt'].id
            # opt.id = infos['opt'].id_p

        ignore = [
            'checkpoint_path', "use_gfc", "use_isg", "ssg_dict_path",
            "input_json", "input_label_h5", "id", "batch_size", "start_from",
            "language_eval", "use_rela", "input_ssg_dir", "ssg_dict_path",
            "input_rela_dir", "use_spectral_norm", "beam_size", 'gpu',
            'caption_model', 'self_critical_after', 'save_checkpoint_every'
        ]
        beam_size = opt.beam_size
        for k in vars(infos['opt']).keys():
            if k != 'model':
                if k not in ignore:
                    if k in vars(opt):
                        if not vars(opt)[k] == vars(infos['opt'])[k]:
                            print(
                                k +
                                ' option not consistent, copyed from pretrained model'
                            )
                            vars(opt).update({k: vars(infos['opt'])[k]})
                        else:
                            vars(opt).update({
                                k: vars(infos['opt'])[k]
                            })  # copy over options from model

        vocab = infos['vocab']  # ix -> word mapping
        opt.vocab = vocab
        opt.vocab_size = len(vocab)
        opt.input_fc_dir = 'data/cocobu_fc'

        if os.path.isfile(os.path.join(opt.checkpoint_path_p,
                                       'histories.pkl')):
            with open(os.path.join(opt.checkpoint_path_p,
                                   'histories.pkl')) as f:
                histories = cPickle.load(f)

    # Create the Data Loader instance
    loader = DataLoader_UP(opt)
    if opt.use_rela == 1:
        opt.rela_dict_size = loader.rela_dict_size
    opt.seq_length = loader.seq_length
    use_rela = getattr(opt, 'use_rela', 0)
    # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json
    # So make sure to use the vocab in infos file.
    try:  # if use pretrained model
        loader.ix_to_word = infos['vocab']
    except:  # if train from scratch
        infos = json.load(open(opt.input_json))
        opt.ix_to_word = infos['ix_to_word']
        opt.vocab_size = len(opt.ix_to_word)

    # iteration = infos.get('iter', 0)
    # epoch = infos.get('epoch', 0)
    iteration = 0
    epoch = 0

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Setup the model
    try:
        opt.caption_model = opt.caption_model_zh
    except:
        opt.caption_model = opt.caption_model
    model = models.setup(opt).cuda()
    # dp_model = torch.nn.DataParallel(model)
    # dp_model = torch.nn.DataParallel(model, [0,2,3])
    dp_model = model

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    parameters = model.named_children()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), opt)

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    reward = np.zeros([1, 1])

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)
        # print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        fc_feats = None
        att_feats = None
        att_masks = None
        ssg_data = None
        rela_data = None

        if getattr(opt, 'use_ssg', 0) == 1:
            if getattr(opt, 'use_isg', 0) == 1:
                tmp = [
                    data['fc_feats'], data['labels'], data['masks'],
                    data['att_feats'], data['att_masks'],
                    data['isg_rela_matrix'], data['isg_rela_masks'],
                    data['isg_obj'], data['isg_obj_masks'], data['isg_attr'],
                    data['isg_attr_masks'], data['ssg_rela_matrix'],
                    data['ssg_rela_masks'], data['ssg_obj'],
                    data['ssg_obj_masks'], data['ssg_attr'],
                    data['ssg_attr_masks']
                ]

                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, labels, masks, att_feats, att_masks, \
                isg_rela_matrix, isg_rela_masks, isg_obj, isg_obj_masks, isg_attr, isg_attr_masks, \
                ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks = tmp

                # image graph domain
                isg_data = {}
                isg_data['att_feats'] = att_feats
                isg_data['att_masks'] = att_masks

                isg_data['isg_rela_matrix'] = isg_rela_matrix
                isg_data['isg_rela_masks'] = isg_rela_masks
                isg_data['isg_obj'] = isg_obj
                isg_data['isg_obj_masks'] = isg_obj_masks
                isg_data['isg_attr'] = isg_attr
                isg_data['isg_attr_masks'] = isg_attr_masks
                # text graph domain
                ssg_data = {}
                ssg_data['ssg_rela_matrix'] = ssg_rela_matrix
                ssg_data['ssg_rela_masks'] = ssg_rela_masks
                ssg_data['ssg_obj'] = ssg_obj
                ssg_data['ssg_obj_masks'] = ssg_obj_masks
                ssg_data['ssg_attr'] = ssg_attr
                ssg_data['ssg_attr_masks'] = ssg_attr_masks
            else:
                tmp = [
                    data['fc_feats'], data['ssg_rela_matrix'],
                    data['ssg_rela_masks'], data['ssg_obj'],
                    data['ssg_obj_masks'], data['ssg_attr'],
                    data['ssg_attr_masks'], data['labels'], data['masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks, labels, masks = tmp
                ssg_data = {}
                ssg_data['ssg_rela_matrix'] = ssg_rela_matrix
                ssg_data['ssg_rela_masks'] = ssg_rela_masks
                ssg_data['ssg_obj'] = ssg_obj
                ssg_data['ssg_obj_masks'] = ssg_obj_masks
                ssg_data['ssg_attr'] = ssg_attr

                isg_data = None
                ssg_data['ssg_attr_masks'] = ssg_attr_masks
        else:
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

        if not sc_flag:
            # loss = crit(dp_model(fc_feats, labels, isg_data, ssg_data), labels[:, 1:], masks[:, 1:])
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   isg_data,
                                                   ssg_data,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, isg_data,
                                              ssg_data, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()
        if accumulate_iter % opt.accumulate_number == 0:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            accumulate_iter = 0
            train_loss = loss.item() * opt.accumulate_number
            end = time.time()

            if not sc_flag:
                print("{}/{}/{}|train_loss={:.3f}|time/batch={:.3f}" \
                      .format(opt.id_p, iteration, epoch, train_loss, end - start))
            else:
                print("{}/{}/{}|avg_reward={:.3f}|time/batch={:.3f}" \
                      .format(opt.id_p, iteration, epoch, np.mean(reward[:, 0]), end - start))

        torch.cuda.synchronize()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0) and (iteration != 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0) and (iteration != 0):
            # if (iteration % 100 == 0) and (iteration != 0):
            # eval model
            if use_rela:
                eval_kwargs = {
                    'split': 'val',
                    'dataset': opt.input_json,
                    'use_real': 1
                }
            else:
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split_fc(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                save_id = iteration / opt.save_checkpoint_every
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path_p = os.path.join(opt.checkpoint_path_p,
                                                 'model.pth')
                torch.save(model.state_dict(), checkpoint_path_p)
                print("model saved to {}".format(checkpoint_path_p))
                optimizer_path = os.path.join(opt.checkpoint_path_p,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscalleous informations
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path_p, 'infos.pkl'),
                          'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path_p, 'histories.pkl'),
                          'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path_p = os.path.join(opt.checkpoint_path_p,
                                                     'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path_p)
                    print("model saved to {}".format(checkpoint_path_p))
                    with open(
                            os.path.join(opt.checkpoint_path_p,
                                         'infos-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #4
0
def train(opt):
    """Run the captioning training loop until ``opt.max_epochs`` is reached.

    Builds the data loader, model, criteria and optimizer from ``opt``,
    optionally resumes bookkeeping (the ``infos``/``histories`` pickles and
    the optimizer state) from ``opt.start_from``, then trains one batch per
    loop iteration. Every ``opt.losses_log_every`` iterations the training
    statistics are recorded, and every ``opt.save_checkpoint_every``
    iterations the model is evaluated on the 'val' split and checkpointed
    under ``opt.checkpoint_path``.

    Args:
        opt: options namespace. It is mutated in place: ``vocab_size``,
            ``seq_length``, ``current_lr`` and (once scheduled sampling
            starts) ``ss_prob`` are written onto it, and the object itself
            is pickled into the saved ``infos``.

    Raises:
        AssertionError: when resuming and one of ``rnn_type``, ``rnn_size``
            or ``num_layers`` differs between ``opt`` and the saved options.
    """

    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        # Pickles are binary data: they are written with 'wb' below, so they
        # must be read back with 'rb' (text mode fails on Python 3).
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(
                opt
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind best_val_score: it is read at every checkpoint, so leaving
    # it unassigned when load_best_score != 1 would raise UnboundLocalError.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create model
    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)
    dp_model.train()

    # Loss function
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    # Optimizer and learning rate adjustment flag
    optimizer = utils.build_optimizer(model.parameters(), opt)
    update_lr_flag = True

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # Training loop
    while True:

        # Update learning rate once per epoch (and on the very first pass,
        # which is also where sc_flag gets its initial value).
        if update_lr_flag:

            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        # Restart the clock so total_time below excludes data loading.
        start = time.time()

        # Unpack data; entries that are None are passed through unchanged.
        torch.cuda.synchronize()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [x if x is None else torch.from_numpy(x).cuda() for x in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        # Forward pass and loss
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        # Backward pass
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Print
        total_time = time.time() - start
        # NOTE(review): with print_freq == 1 this condition is never true
        # (x % 1 is always 0) -- confirm whether that is intended.
        if iteration % opt.print_freq == 1:
            # Report the measured loading time; `start` was reset after the
            # batch was fetched, so time.time() - start is not the read time.
            print('Read data:', data_time)
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, data_time, total_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)
            # In self-critical mode the history stores the mean reward
            # instead of the loss.
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Validate and save model
        if (iteration % opt.save_checkpoint_every == 0):

            # Evaluate model; opt's own entries override the two defaults.
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Our metric is CIDEr if available, otherwise validation loss
            # (negated so that "larger is better" holds for both).
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format(
                    iteration, best_val_score)
                infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname)
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, infos_fname),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #5
0
파일: model.py 프로젝트: benjamink85/VSRN
    def __init__(self, opt):
        """Build the encoders, the captioning model, the losses and Adam.

        Args:
            opt: options namespace; the attributes read here are the ones
                passed to the constructors below, plus ``grad_clip``,
                ``finetune``, ``margin``, ``measure``, ``max_violation``
                and ``learning_rate``.
        """
        # Build Models
        self.grad_clip = opt.grad_clip
        self.img_enc = EncoderImage(opt.data_name,
                                    opt.img_dim,
                                    opt.embed_size,
                                    opt.finetune,
                                    opt.cnn_type,
                                    use_abs=opt.use_abs,
                                    no_imgnorm=opt.no_imgnorm)
        self.txt_enc = EncoderText(opt.vocab_size,
                                   opt.word_dim,
                                   opt.embed_size,
                                   opt.num_layers,
                                   use_abs=opt.use_abs)
        if torch.cuda.is_available():
            self.img_enc.cuda()
            self.txt_enc.cuda()
            cudnn.benchmark = True

        #####   captioning elements

        self.encoder = EncoderRNN(opt.dim_vid,
                                  opt.dim_hidden,
                                  bidirectional=opt.bidirectional,
                                  input_dropout_p=opt.input_dropout_p,
                                  rnn_cell=opt.rnn_type,
                                  rnn_dropout_p=opt.rnn_dropout_p)
        self.decoder = DecoderRNN(opt.vocab_size,
                                  opt.max_len,
                                  opt.dim_hidden,
                                  opt.dim_word,
                                  input_dropout_p=opt.input_dropout_p,
                                  rnn_cell=opt.rnn_type,
                                  rnn_dropout_p=opt.rnn_dropout_p,
                                  bidirectional=opt.bidirectional)

        self.caption_model = S2VTAttModel(self.encoder, self.decoder)

        self.crit = utils.LanguageModelCriterion()
        self.rl_crit = utils.RewardCriterion()

        if torch.cuda.is_available():
            self.caption_model.cuda()

        # Loss and Optimizer
        self.criterion = ContrastiveLoss(margin=opt.margin,
                                         measure=opt.measure,
                                         max_violation=opt.max_violation)
        params = list(self.txt_enc.parameters())
        params += list(self.img_enc.parameters())
        params += list(self.decoder.parameters())
        params += list(self.encoder.parameters())
        params += list(self.caption_model.parameters())

        if opt.finetune:
            params += list(self.img_enc.cnn.parameters())

        # The collections above can overlap: caption_model is built from
        # encoder/decoder (presumably registering them as sub-modules --
        # TODO confirm), and img_enc.cnn is reached through img_enc as well.
        # Keep only the first occurrence of each tensor (by identity, order
        # preserved) so the optimizer never sees the same parameter twice.
        seen_ids = set()
        unique_params = []
        for param in params:
            if id(param) not in seen_ids:
                seen_ids.add(id(param))
                unique_params.append(param)
        self.params = unique_params

        self.optimizer = torch.optim.Adam(self.params, lr=opt.learning_rate)

        self.Eiters = 0
예제 #6
0
def train(rank, model, opt, optimizer=None):
    torch.manual_seed(opt.seed + rank)
    if opt.use_cuda:
        torch.cuda.manual_seed(opt.seed + rank)

    loader = DataLoader(opt)
    index_2_word = loader.get_vocab()
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    infos = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(
                os.path.join(opt.start_from,
                             'infos_' + opt.load_model_id + '.pkl'),
                'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            # for checkme in need_be_same:
            #     assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})
    loss_history = infos.get('loss_history', {})
    lr_history = infos.get('lr_history', {})
    ss_prob_history = infos.get('ss_prob_history', {})

    sorted_lr = sorted(lr_history.items(), key=operator.itemgetter(1))
    if opt.load_lr and len(lr_history) > 0:
        opt.optim_rl_lr = sorted_lr[0][1] / opt.optim_rl_lr_ratio

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_image_id = infos.get('split_image_id', loader.split_image_id)

    entropy_reg = opt.entropy_reg
    best_val_score = 0
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    update_lr_flag = True

    if opt.caption_model == 'show_tell':
        crit = utils.LanguageModelCriterion(opt)
        rl_crit = utils.RewardCriterion(opt)

    elif opt.caption_model == 'review_net':
        crit = utils.ReviewNetCriterion(opt)
        rl_crit = utils.ReviewNetRewardCriterion(opt)

    elif opt.caption_model == 'recurrent_fusion_model':
        crit = utils.ReviewNetEnsembleCriterion(opt)
        rl_crit = utils.ReviewNetRewardCriterion(opt)

    else:
        raise Exception("caption_model not supported: {}".format(
            opt.caption_model))

    if optimizer is None:
        if opt.optim == 'adam':
            optimizer = optim.Adam(model.parameters(),
                                   lr=opt.optim_rl_lr,
                                   betas=(opt.optim_adam_beta1,
                                          opt.optim_adam_beta2),
                                   weight_decay=opt.optim_weight_decay)
        elif opt.optim == 'rmsprop':
            optimizer = optim.RMSprop(model.parameters(),
                                      lr=opt.optim_rl_lr,
                                      momentum=opt.optim_momentum,
                                      alpha=opt.optim_rmsprop_alpha,
                                      weight_decay=opt.weight_decay)
        elif opt.optim == 'sgd':
            optimizer = optim.SGD(model.parameters(),
                                  lr=opt.optim_rl_lr,
                                  momentum=opt.optim_momentum,
                                  weight_decay=opt.optim_weight_decay)
        elif opt.optim == 'adagrad':
            optimizer = optim.Adagrad(model.parameters(),
                                      lr=opt.optim_rl_lr,
                                      lr_decay=opt.optim_lr_decay,
                                      weight_decay=opt.optim_weight_decay)
        elif opt.optim == 'adadelta':
            optimizer = optim.Adadelta(model.parameters(),
                                       rho=opt.optim_rho,
                                       eps=opt.optim_epsilon,
                                       lr=opt.optim_rl_lr,
                                       weight_decay=opt.optim_weight_decay)
        else:
            raise Exception("optim not supported: {}".format(opt.feature_type))

        # Load the optimizer
        if opt.load_lr and vars(opt).get(
                'start_from', None) is not None and os.path.isfile(
                    os.path.join(opt.start_from,
                                 'optimizer_' + opt.load_model_id + '.pth')):
            optimizer.load_state_dict(
                torch.load(
                    os.path.join(opt.start_from,
                                 'optimizer_' + opt.load_model_id + '.pth')))
            utils.set_lr(optimizer, opt.optim_rl_lr)

    num_period_best = 0
    current_score = 0
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.optim_rl_lr * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.optim_rl_lr
            update_lr_flag = False

        start = time.time()
        data = loader.get_batch('train')

        if opt.use_cuda:
            torch.cuda.synchronize()

        if opt.feature_type == 'feat_array':
            fc_feat_array = data['fc_feats_array']
            att_feat_array = data['att_feats_array']
            assert (len(fc_feat_array) == len(att_feat_array))
            for feat_id in range(len(fc_feat_array)):
                if opt.use_cuda:
                    fc_feat_array[feat_id] = Variable(
                        torch.from_numpy(fc_feat_array[feat_id]),
                        requires_grad=False).cuda()
                    att_feat_array[feat_id] = Variable(
                        torch.from_numpy(att_feat_array[feat_id]),
                        requires_grad=False).cuda()
                else:
                    fc_feat_array[feat_id] = Variable(torch.from_numpy(
                        fc_feat_array[feat_id]),
                                                      requires_grad=False)
                    att_feat_array[feat_id] = Variable(torch.from_numpy(
                        att_feat_array[feat_id]),
                                                       requires_grad=False)

            tmp = [data['labels'], data['masks'], data['top_words']]
            if opt.use_cuda:
                tmp = [
                    Variable(torch.from_numpy(_), requires_grad=False).cuda()
                    for _ in tmp
                ]
            else:
                tmp = [
                    Variable(torch.from_numpy(_), requires_grad=False)
                    for _ in tmp
                ]
            labels, masks, top_words = tmp

        else:
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['top_words']
            ]
            if opt.use_cuda:
                tmp = [
                    Variable(torch.from_numpy(_), requires_grad=False).cuda()
                    for _ in tmp
                ]
            else:
                tmp = [
                    Variable(torch.from_numpy(_), requires_grad=False)
                    for _ in tmp
                ]
            fc_feats, att_feats, labels, masks, top_words = tmp

        optimizer.zero_grad()

        if opt.caption_model == 'show_tell':
            gen_result, sample_logprobs, logprobs_all = model.sample(
                fc_feats, att_feats, {'sample_max': 0})
            rewards = get_rewards.get_self_critical_reward(
                index_2_word, model, fc_feats, att_feats, data, gen_result,
                opt)
            sample_logprobs_old = Variable(sample_logprobs.data,
                                           requires_grad=False)

            if opt.use_cuda:
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(rewards).float().cuda(),
                             requires_grad=False), logprobs_all, entropy_reg,
                    sample_logprobs_old, opt)
            else:
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(rewards).float(),
                             requires_grad=False), logprobs_all, entropy_reg,
                    sample_logprobs_old, opt)

        elif opt.caption_model == 'recurrent_fusion_model':
            gen_result, sample_logprobs, logprobs_all, top_pred = model.sample(
                fc_feat_array, att_feat_array, {'sample_max': 0})
            rewards = get_rewards.get_self_critical_reward_feat_array(
                index_2_word, model, fc_feat_array, att_feat_array, data,
                gen_result, opt)
            sample_logprobs_old = Variable(sample_logprobs.data,
                                           requires_grad=False)

            if opt.use_cuda:
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(rewards).float().cuda(),
                             requires_grad=False), logprobs_all, entropy_reg,
                    top_pred, top_words, opt.reason_weight,
                    sample_logprobs_old, opt)
            else:
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(rewards).float(),
                             requires_grad=False), logprobs_all, entropy_reg,
                    top_pred, top_words, opt.reason_weight,
                    sample_logprobs_old, opt)

        elif opt.caption_model == 'review_net':
            gen_result, sample_logprobs, logprobs_all, top_pred = model.sample(
                fc_feats, att_feats, {'sample_max': 0})
            rewards = get_rewards.get_self_critical_reward(
                index_2_word, model, fc_feats, att_feats, data, gen_result,
                opt)
            sample_logprobs_old = Variable(sample_logprobs.data,
                                           requires_grad=False)

            if opt.use_cuda:
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(rewards).float().cuda(),
                             requires_grad=False), logprobs_all, entropy_reg,
                    top_pred, top_words, opt.reason_weight,
                    sample_logprobs_old, opt)
            else:
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(rewards).float(),
                             requires_grad=False), logprobs_all, entropy_reg,
                    top_pred, top_words, opt.reason_weight,
                    sample_logprobs_old, opt)

        else:
            raise Exception("caption_model not supported: {}".format(
                opt.caption_model))

        if opt.use_ppo and opt.ppo_k > 0:
            loss.backward(retain_graph=True)
        else:
            loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.data[0]
        if opt.use_ppo:
            for i in range(opt.ppo_k):
                print(i)
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()

        if opt.use_cuda:
            torch.cuda.synchronize()
        end = time.time()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            loss_history[iteration] = np.mean(rewards[:, 0])
            lr_history[iteration] = opt.current_lr

        # make evaluation on validation set, and save model
        if iteration % opt.save_checkpoint_every == 0:
            # eval model
            eval_kwargs = {
                'eval_split': 'val',
                'dataset': opt.input_json,
                'caption_model': opt.caption_model,
                'reason_weight': opt.reason_weight,
                'guiding_l1_penality': opt.guiding_l1_penality,
                'use_cuda': opt.use_cuda,
                'feature_type': opt.feature_type,
                'rank': rank
            }
            eval_kwargs.update(vars(opt))
            eval_kwargs['eval_split'] = 'val'
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, eval_kwargs)

            # Write validation result into summary
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }
            print("iter {} (epoch {}), val_loss = {:.3f}".format(
                iteration, epoch, val_loss))

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
                num_period_best = 1
            else:
                num_period_best = num_period_best + 1

            # Dump miscalleous informations
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_image_id'] = loader.split_image_id
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['val_result_history'] = val_result_history
            infos['loss_history'] = loss_history
            infos['lr_history'] = lr_history
            infos['ss_prob_history'] = ss_prob_history
            infos['vocab'] = loader.get_vocab()
            with open(
                    os.path.join(
                        opt.checkpoint_path,
                        'rl_infos_' + opt.id + '_' + str(rank) + '.pkl'),
                    'wb') as f:
                cPickle.dump(infos, f)

            if best_flag:
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'rl_model_' + opt.id + '_' + str(rank) + '-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                optimizer_path = os.path.join(
                    opt.checkpoint_path,
                    'rl_optimizer_' + opt.id + '_' + str(rank) + '-best.pth')
                torch.save(optimizer.state_dict(), optimizer_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(
                            opt.checkpoint_path, 'rl_infos_' + opt.id + '_' +
                            str(rank) + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

            if num_period_best >= opt.num_eval_no_improve:
                print('no improvement, exit')
                sys.exit()
        print("rank {}, iter {}, (epoch {}), avg_reward: {:.3f}, train_loss: {}, learning rate: {}, current cider: {:.3f}, best cider: {:.3f}, time: {:.3f}" \
              .format(rank, iteration, epoch, np.mean(rewards[:, 0]), train_loss, opt.current_lr, current_score, best_val_score, (end-start)))

        iteration += 1
        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #7
0
    def _add_losses(self, sigma_rpn=3.0):
        """Compute every training loss, record them, and return their sum.

        Builds the RPN classification/box losses, the RCNN classification/box
        losses, the mask loss and the caption loss from ``self._predictions``,
        ``self._anchor_targets`` and ``self._proposal_targets``. Each term and
        the total are written to ``self._losses`` and then mirrored into
        ``self._event_summaries``.

        Args:
            sigma_rpn: Value forwarded as ``sigma`` to ``self._smooth_l1_loss``
                for the RPN box-regression term.

        Returns:
            The total loss: the five detection/mask terms plus
            ``self._cap_loss_weight`` times the caption loss.
        """
        predictions = self._predictions
        anchor_targets = self._anchor_targets
        proposal_targets = self._proposal_targets

        # ---- RPN classification ----
        # Anchors whose label is -1 are filtered out before the cross entropy.
        anchor_scores = predictions['rpn_cls_score_reshape'].view(-1, 2)
        anchor_labels = anchor_targets['rpn_labels'].view(-1)
        kept = Variable((anchor_labels.data != -1).nonzero().view(-1))
        kept_scores = anchor_scores.index_select(0, kept).contiguous().view(
            -1, 2)
        kept_labels = anchor_labels.index_select(0, kept).contiguous().view(-1)
        rpn_cross_entropy = F.cross_entropy(kept_scores, kept_labels)

        # ---- RPN box regression ----
        rpn_loss_box = self._smooth_l1_loss(
            predictions['rpn_bbox_pred'],
            anchor_targets['rpn_bbox_targets'],
            anchor_targets['rpn_bbox_inside_weights'],
            anchor_targets['rpn_bbox_outside_weights'],
            sigma=sigma_rpn,
            dim=[1, 2, 3])

        # ---- RCNN classification ----
        class_scores = predictions['cls_score']
        # Per the original annotation: (n, ) ranging [0, num_class)
        label = proposal_targets['labels'].view(-1)
        cross_entropy = F.cross_entropy(
            class_scores.view(-1, self._num_classes), label)

        # ---- RCNN box regression ----
        loss_box = self._smooth_l1_loss(
            predictions['bbox_pred'],
            proposal_targets['bbox_targets'],
            proposal_targets['bbox_inside_weights'],
            proposal_targets['bbox_outside_weights'])

        # ---- Mask loss (foreground rois only) ----
        # Shapes per the original annotations:
        #   mask_targets (num_fg, 14, 14), mask_score (num_fg, num_classes, 14, 14)
        mask_targets = proposal_targets['mask_targets']
        mask_logits = predictions['mask_score']
        assert mask_targets.size(0) == mask_logits.size(0)
        num_fg = mask_targets.size(0)
        # The first num_fg labels are used as the foreground labels
        # (presumably foreground rois are ordered first - confirm with the
        # proposal-target layer).
        class_index = label[:num_fg].view(num_fg, 1, 1, 1).expand(
            num_fg, 1, cfg.MASK_SIZE, cfg.MASK_SIZE)
        # Pick, for every roi, the mask channel of its own class.
        fg_logits = torch.gather(mask_logits, 1, class_index).squeeze(1)
        loss_mask = F.binary_cross_entropy_with_logits(fg_logits, mask_targets)

        # ---- Caption loss ----
        conv_feats = self._head_to_tail(predictions['net_conv'])

        # Whole-image features: spatial mean and a 14x14 channels-last grid.
        fc_feats_all = conv_feats.mean(3).mean(2)
        att_feats_all = F.adaptive_avg_pool2d(
            conv_feats, [14, 14]).permute(0, 2, 3, 1).contiguous()

        gt_masks, cap_labels, cap_masks = [
            Variable(torch.from_numpy(array), requires_grad=False).cuda()
            for array in (self._gt_masks, self._cap_labels, self._cap_masks)
        ]

        # Resize the ground-truth masks to the feature map and binarise at 0.5.
        gt_masks = gt_masks.unsqueeze(1).float()
        gt_masks = F.adaptive_avg_pool2d(
            gt_masks, [conv_feats.size(2), conv_feats.size(3)])
        gt_masks[gt_masks < 0.5] = 0
        gt_masks[gt_masks >= 0.5] = 1

        # Same two feature views, restricted to the masked region.
        masked_feats = conv_feats * gt_masks
        fc_feats_mask = masked_feats.mean(3).mean(2)
        att_feats_mask = F.adaptive_avg_pool2d(
            masked_feats, [14, 14]).permute(0, 2, 3, 1)

        fc_feats = torch.cat((fc_feats_all, fc_feats_mask), 1)
        att_feats = torch.cat((att_feats_all, att_feats_mask), 3)

        crit = utils.LanguageModelCriterion()
        caption_output = self.caption_model(fc_feats, att_feats, cap_labels)
        # Targets and masks skip the first token of every caption.
        loss_caption = crit(caption_output, cap_labels[:, 1:],
                            cap_masks[:, 1:])

        # ---- Bookkeeping ----
        named_losses = (
            ('cross_entropy', cross_entropy),
            ('loss_box', loss_box),
            ('rpn_cross_entropy', rpn_cross_entropy),
            ('rpn_loss_box', rpn_loss_box),
            ('loss_mask', loss_mask),
            ('loss_caption', loss_caption),
        )
        for loss_name, loss_value in named_losses:
            self._losses[loss_name] = loss_value

        loss = (cross_entropy + loss_box + rpn_cross_entropy + rpn_loss_box +
                loss_mask + self._cap_loss_weight * loss_caption)
        self._losses['total_loss'] = loss

        for loss_name, loss_value in self._losses.items():
            self._event_summaries[loss_name] = loss_value

        return loss
def main():
    """Run one validation pass of the hard-coded 'topdown' captioning setup.

    Parses the command-line options, forces a fixed set of experiment
    settings on top of them, builds the data loader, the caption model and the
    CNN, moves both networks to the GPU, fetches one training batch and then
    calls ``eval_split`` on the 'val' split. Nothing is returned; the
    evaluation results are discarded.
    """
    import opts
    import misc.utils as utils

    opt = opts.parse_opt()

    # Settings forced on top of the parsed options, applied in this order.
    # Trailing notes keep the alternative values recorded next to the
    # original assignments.
    forced_options = (
        ('caption_model', 'topdown'),
        ('batch_size', 10),  # alternatives noted: 512, 32*4*4
        ('id', 'topdown'),
        ('learning_rate', 5e-4),
        ('learning_rate_decay_start', 0),
        ('scheduled_sampling_start', 0),
        ('save_checkpoint_every', 5000),  # alternatives noted: 450, 5000, 11500
        ('val_images_use', 5000),
        ('max_epochs', 50),  # alternative noted: 30
        ('start_from', 'save/rt'),  # alternatives noted: "save", None
        ('language_eval', 1),
        ('input_json', 'data/meta_coco_en.json'),
        ('input_label_h5', 'data/label_coco_en.h5'),
        ('finetune_cnn_after', 0),
        ('ccg', False),
        ('input_image_h5', 'data/coco_image_512.h5'),
    )
    for option_name, option_value in forced_options:
        setattr(opt, option_name, option_value)
    # Disabled alternatives from the original: the CCG data files
    # ('data/coco_ccg.json', 'data/coco_ccg_label.h5') and explicit
    # input_fc_dir / input_att_dir feature directories.

    opt.use_att = utils.if_use_att(opt.caption_model)

    from dataloader import DataLoader  # just-in-time generated features
    loader = DataLoader(opt)
    # Disabled alternative: dataloader_fixcnn.DataLoader, which loads
    # pre-processed features instead.

    # Copy the dataset-dependent sizes from the loader onto the options.
    for size_name in ('vocab_size', 'vocab_ccg_size', 'seq_length'):
        setattr(opt, size_name, getattr(loader, size_name))

    import models
    model = models.setup(opt)
    cnn_model = utils.build_cnn(opt)
    cnn_model.cuda()
    model.cuda()

    # One training batch is pulled from the loader; `images` is not consumed
    # below. A disabled variant used to push each image through cnn_model
    # here, pool the output into fc/attention features, stack them and repeat
    # them loader.seq_per_img times.
    batch = loader.get_batch('train')
    images = batch['images']

    crit = utils.LanguageModelCriterion()
    # Every option is forwarded to the evaluator; option values take
    # precedence over the three defaults when names collide.
    eval_kwargs = dict(
        {'split': 'val', 'dataset': opt.input_json, 'verbose': True},
        **vars(opt))
    val_loss, predictions, lang_stats = eval_split(cnn_model, model, crit,
                                                   loader, eval_kwargs, True)
# Build the ensemble: one GPU-resident, eval-mode model per entry of
# model_infos, with weights read from the matching entry of model_paths.
from models.AttEnsemble import AttEnsemble

_models = []
for idx, member_info in enumerate(model_infos):
    # start_from is cleared on the saved options before the model is built;
    # the weights are loaded explicitly right after.
    member_info['opt'].start_from = None
    member = models.setup(member_info['opt'])
    member.load_state_dict(torch.load(model_paths[idx]))
    member.cuda()
    member.eval()
    _models.append(member)

model = AttEnsemble(_models)
model.seq_length = opt.seq_length
model.eval()
crit = utils.LanguageModelCriterion(opt.XE_eps)

# Pick the data source: a raw image folder when one was given, otherwise the
# regular DataLoader driven by the options.
if len(opt.image_folder) > 0:
    loader = DataLoaderRaw({
        'folder_path': opt.image_folder,
        'coco_json': opt.coco_json,
        'batch_size': opt.batch_size,
        'cnn_model': opt.cnn_model
    })
else:
    loader = DataLoader(opt)
# A provided pretrained model may have been trained with a vocabulary that
# differs from the local cocotalk.json, so the vocabulary stored in the infos
# file is the one used for decoding.
loader.ix_to_word = infos['vocab']
예제 #10
0
def train(opt):
    """Train a Transformer captioner together with a Seq2Seq language model.

    Both networks are warm-started from hard-coded checkpoint files and share
    one optimizer. Every iteration combines, with equal weight, a
    cross-entropy loss on the ground-truth captions and an ``rl_crit`` loss on
    captions sampled from the language model, then steps the optimizer.
    Losses are logged every ``opt.losses_log_every`` iterations; every
    ``opt.save_checkpoint_every`` iterations both models and the optimizer are
    saved, the captioner is evaluated, and ``model.pth`` is rewritten when the
    CIDEr score improves.

    Args:
        opt: Option namespace. Read here: ``start_from``, ``id``,
            ``checkpoint_path``, ``load_best_score``, the learning-rate decay
            and scheduled-sampling schedule fields, ``self_critical_after``,
            ``cached_tokens``, ``print_freq``, ``losses_log_every``,
            ``save_checkpoint_every`` and ``input_json``. ``vocab_size``,
            ``seq_length`` and ``current_lr`` are written back onto it.

    Raises:
        AssertionError: if the saved options disagree with ``opt`` on
            ``rnn_type``, ``rnn_size`` or ``num_layers`` when resuming.

    Notes:
        The loop shown here has no exit condition; it runs until the process
        is interrupted.
    """
    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries; the writer is only created when `tb` is truthy.
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load info file and histories file of a previous run
    infos = {}
    histories = {}

    if opt.start_from is not None:
        # NOTE(review): unpickling executes arbitrary code - only resume from
        # directories you trust.
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        # Pickle streams must be read in binary mode.
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(
                opt
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind best_val_score: the checkpoint section compares against it
    # even when the stored score is not requested.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create the models (sizes and checkpoint paths are hard-coded)
    encoder = MemoryAugmentedEncoder(
        3,
        0,
        attention_module=ScaledDotProductAttentionMemory,
        attention_module_kwargs={'m': 40})
    decoder = MeshedDecoder(8668, 180, 3, 0)
    models = Transformer(8667, encoder, decoder)
    model = models.cuda()
    lang_model = Seq2Seq().cuda()
    model.load_state_dict(torch.load('log_meshed/all2model20000.pth'))
    lang_model.load_state_dict(torch.load('language_model/langmodel06000.pth'))
    # A single optimizer updates the captioner and the language model.
    optimizer = utils.build_optimizer_adam(
        list(models.parameters()) + list(lang_model.parameters()), opt)

    update_lr_flag = True

    while True:

        # Update learning rate once per epoch
        if update_lr_flag:

            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Scheduled sampling: the step count is computed but currently
            # not applied to any model (the assignment to ss_prob is disabled).
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every

            # If start self critical training. In this function sc_flag only
            # selects what gets logged; the loss below is the same either way.
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['dist'],
            data['masks'], data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, attmasks = tmp
        labels = labels.long()

        # Reference captions as plain strings, used for the reward below.
        captions = utils.decode_sequence(loader.get_vocab(),
                                         labels.view(fc_feats.size(0), -1),
                                         None)
        captions_all = [
            caption.replace('<start>', '').replace(' ,', '').replace('  ', ' ')
            for caption in captions
        ]

        batchsize = fc_feats.size(0)
        # Weight of the captioner's own probability in the ratio denominator.
        beta = 0.2

        # ---- Forward pass and losses ----
        model.train()
        optimizer.zero_grad()

        # Cross entropy: the output at position t is scored against the label
        # at position t + 1.
        wordact, _ = model(att_feats, labels.view(batchsize, -1))
        wordact_t = wordact[:, :-1, :]
        wordact_t = wordact_t.contiguous().view(
            wordact_t.size(0) * wordact_t.size(1), -1)
        labels_flat = labels.view(batchsize, -1)
        wordclass_v = labels_flat[:, 1:]
        wordclass_t = wordclass_v.contiguous().view(
            wordclass_v.size(0) * wordclass_v.size(1), -1)
        loss_xe = F.cross_entropy(wordact_t,
                                  wordclass_t.contiguous().view(-1))

        # Sample captions from the language model without tracking gradients.
        with torch.no_grad():
            outcap, sampled_ids, sample_logprobs, x_all_langauge, outputs, log_probs_all = lang_model.sample(
                labels.view(batchsize, -1).transpose(1, 0),
                att_feats.transpose(1, 0), loader.get_vocab())

        # Log-probabilities the captioner assigns to the sampled tokens.
        logprobs_input, _ = model(att_feats, sampled_ids.cuda().long())
        log_probs = F.log_softmax(logprobs_input, 2)
        sample_logprobs_true = log_probs.gather(
            2, sampled_ids.cuda().long().unsqueeze(2))

        with torch.no_grad():
            reward, cider_sample, cider_greedy, caps_sample, caps = get_self_critical_reward(
                batchsize, lang_model,
                labels.view(batchsize, -1).transpose(1, 0),
                att_feats.transpose(1, 0), outcap, captions_all,
                loader, 180)
            # Element-wise KL term between the captioner's log-probabilities
            # and the exponentiated log_probs_all returned by the sampler.
            kl_div = F.kl_div(log_probs.squeeze().cuda().detach(),
                              torch.exp(log_probs_all.transpose(
                                  1, 0)).cuda().detach(),
                              reduce=False)
            # Per-token ratio p_model / ((1 - beta) * p_sample + beta * p_model),
            # floored at 0.96 and multiplied along dim 1.
            ratio_no = sample_logprobs_true.squeeze().cpu().double()
            ratio_de = sample_logprobs.cpu().double()
            ratio_no_f = torch.exp(ratio_no)
            ratio_de_f = torch.exp(ratio_de)
            ratio = (ratio_no_f /
                     ((1 - beta) * ratio_de_f + beta * ratio_no_f))
            ratio = torch.clamp(ratio, min=0.96)
            ratio_prod = ratio.prod(1)
            # The reward is penalised by the mean KL term; it is converted to
            # a tensor exactly once.
            reward = torch.tensor(reward).cuda() - 0.05 * kl_div.mean()

        loss_rl1 = rl_crit(
            ratio_prod.cuda().unsqueeze(1).detach() *
            sample_logprobs_true.squeeze()[:, :-1],
            sampled_ids[:, 1:].cpu(),
            reward.float().cuda().detach())

        # Equal mix of the reward-weighted loss and the cross-entropy loss.
        lamb = 0.5
        train_loss = lamb * loss_rl1 + (1 - lamb) * loss_xe
        train_loss.backward()
        optimizer.step()
        # Time spent on this batch after data loading; reported as time/batch.
        total_time = time.time() - start

        # ---- Console logging ----
        if iteration % opt.print_freq == 1:
            print('Read data:', time.time() - start)
            if not sc_flag:
                print(ratio.mean())
                print(reward.mean())
                print(kl_div.mean())
                print("iter {} (epoch {}), train_loss = {:.4f}, xe_loss = {:.3f}, train_time = {:.3f}"
                      .format(iteration, epoch, train_loss.item(), loss_xe, data_time))
            else:
                # reward is a CUDA tensor here, so it is reduced with torch
                # rather than handed to numpy.
                print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, reward[:, 0].mean().item(), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            if sc_flag:
                avg_reward = reward[:, 0].mean().item()
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  avg_reward, iteration)
                loss_history[iteration] = avg_reward
            else:
                loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr

        # Validate and save model
        if iteration % opt.save_checkpoint_every == 0:
            checkpoint_path = os.path.join(
                opt.checkpoint_path,
                'all2model{:05d}.pth'.format(iteration))
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_path = os.path.join(
                opt.checkpoint_path,
                'lang_model{:05d}.pth'.format(iteration))
            torch.save(lang_model.state_dict(), checkpoint_path)

            optimizer_path = os.path.join(opt.checkpoint_path,
                                          'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Evaluate model (note: on the 'test' split)
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            crit = utils.LanguageModelCriterion()
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, eval_kwargs)

            # The selection metric is always CIDEr here; the fallback to the
            # validation loss is disabled, so lang_stats must be available.
            current_score = lang_stats['CIDEr']

            # Save model in checkpoint path when the score improves
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
예제 #11
0
파일: train.py 프로젝트: vplab-github/sacic
def train(opt):
    """Train a captioning model with periodic evaluation and checkpointing.

    Builds the data loader, model, criteria and optimizer from ``opt``,
    optionally restores bookkeeping and optimizer state from
    ``opt.start_from_path``, then iterates over training batches until
    ``epoch >= opt.max_epochs`` (``opt.max_epochs == -1`` disables the limit).

    Every ``opt.losses_log_every`` iterations the training statistics are
    written to the summary writer and to the in-memory histories.  Every
    ``opt.save_checkpoint_every`` iterations the model is evaluated on the
    ``'val'`` split and ``model.pth``, ``optimizer.pth``, ``infos_<id>.pkl``
    and ``histories_<id>.pkl`` are written to ``opt.checkpoint_path``; when
    the score improves, ``model-best.pth`` and ``infos_<id>-best.pkl`` are
    written as well.

    Args:
        opt: Option namespace, read through attributes and ``vars(opt)``.
            It is mutated in place (``use_att``, ``att_feat_size``,
            ``vocab_size``, ``seq_length``, ``current_lr``, ``ss_prob`` and,
            temporarily during evaluation, ``batch_size``).

    Raises:
        AssertionError: If a resumed checkpoint disagrees with ``opt`` on
            ``caption_model``, ``rnn_type``, ``rnn_size`` or ``num_layers``,
            or if ``opt.noamopt`` is set for a non-transformer model.

    Note:
        The model and every batch tensor are moved with ``.cuda()``, so a
        CUDA device is required.
    """
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from_path is not None:
        # Open old infos and check that the models are compatible.
        # The files are written with 'wb' below, so they must be read back in
        # binary mode; text mode breaks unpickling on Python 3.
        # NOTE(review): unpickling executes arbitrary code -- only resume from
        # checkpoint directories you trust.
        infos_path = os.path.join(opt.start_from_path,
                                  'infos_' + opt.id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same = [
            "caption_model", "rnn_type", "rnn_size", "num_layers"
        ]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(
                opt
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from_path,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always bind best_val_score: it is read at the first checkpoint even when
    # the previous best score is deliberately not restored.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters()
                           if p.requires_grad)
    print("Trainable Params:" + str(trainable_params))
    print("Total Params:" + str(total_params))
    dp_model = torch.nn.DataParallel(model)

    epoch_done = True
    # Assure in training mode
    dp_model.train()
    if opt.use_obj_mcl_loss == 1:
        mcl_crit = utils.MultiLabelClassification()
    if opt.label_smoothing > 0:
        crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
    else:
        crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from_path', None) is not None and os.path.isfile(
            os.path.join(opt.start_from_path, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from_path, 'optimizer.pth')))

    # Keys of the loader batch that have to be moved to the GPU.  The extra
    # object / segmentation entries are only requested when the multi-label
    # object loss is enabled, mirroring what the forward pass below consumes.
    use_mcl = opt.use_obj_mcl_loss != 0
    batch_keys = ['fc_feats', 'att_feats', 'labels', 'masks', 'att_masks']
    if use_mcl:
        batch_keys.append('obj_labels')
        if opt.use_obj_att:
            batch_keys += ['obj_att_feats', 'obj_att_masks']
        if opt.use_seg_feat:
            batch_keys += ['seg_feat_feats', 'seg_feat_masks']

    # time_epoch_start is never reset, so the "Epoch Time" printed below is
    # the time elapsed since training started (same for the two sums).
    time_epoch_start = time.time()
    data_time_sum = 0
    batch_time_sum = 0
    while True:
        if epoch_done:
            torch.cuda.synchronize()
            time_epoch_end = time.time()
            time_elapsed = time_epoch_end - time_epoch_start
            print('[DEBUG] Epoch Time: ' + str(time_elapsed))
            print('[DEBUG] Sum Data Time: ' + str(data_time_sum))
            print('[DEBUG] Sum Batch Time: ' + str(batch_time_sum))
            if not opt.noamopt and not opt.reduce_on_plateau:
                # Assign the learning rate
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            epoch_done = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)
        data_time_sum += time.time() - start
        torch.cuda.synchronize()
        start = time.time()

        # Move the numpy arrays of this batch to the GPU; entries the loader
        # left as None stay None.
        batch = {
            key: None if data[key] is None else
            torch.from_numpy(data[key]).cuda()
            for key in batch_keys
        }
        fc_feats = batch['fc_feats']
        att_feats = batch['att_feats']
        labels = batch['labels']
        masks = batch['masks']
        att_masks = batch['att_masks']

        optimizer.zero_grad()
        if sc_flag:
            # Self-critical step.  It only uses the plain attention features,
            # whichever auxiliary inputs are enabled.
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result,
                                              opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())
        elif not use_mcl:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            # Captioning loss plus multi-label object loss.  The auxiliary
            # features are passed alongside att_feats as parallel lists.
            if opt.use_obj_att and opt.use_seg_feat:
                feats = [
                    att_feats, batch['obj_att_feats'], batch['seg_feat_feats']
                ]
                feat_masks = [
                    att_masks, batch['obj_att_masks'], batch['seg_feat_masks']
                ]
            elif opt.use_seg_feat:
                feats = [att_feats, batch['seg_feat_feats']]
                feat_masks = [att_masks, batch['seg_feat_masks']]
            elif opt.use_obj_att:
                feats = [att_feats, batch['obj_att_feats']]
                feat_masks = [att_masks, batch['obj_att_masks']]
            else:
                feats = att_feats
                feat_masks = att_masks
            logits, out = dp_model(fc_feats, feats, labels, feat_masks)
            caption_loss = crit(logits, labels[:, 1:], masks[:, 1:])
            obj_loss = mcl_crit(out, batch['obj_labels'])
            if opt.use_obj_att and not opt.use_seg_feat:
                # NOTE(review): this configuration uses fixed weights instead
                # of opt.lambda_caption / opt.lambda_obj; kept as-is, confirm
                # whether that is intentional.
                loss = 0.1 * caption_loss + obj_loss
            else:
                loss = opt.lambda_caption * caption_loss + opt.lambda_obj * obj_loss
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        batch_time_sum += end - start
        if sc_flag:
            print(
                "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                .format(iteration, epoch, np.mean(reward[:, 0]), end - start))
        elif opt.use_obj_mcl_loss == 1:
            print(
                "iter {} (epoch {}), train_loss = {:.3f}, caption_loss = {:.3f}, object_loss = {:.3f}, time/batch = {:.3f}"
                .format(iteration, epoch, train_loss, caption_loss.item(),
                        obj_loss.item(), end - start))
        else:
            print(
                "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            epoch_done = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            # The component losses are only computed on non-self-critical
            # steps; during self-critical training they would be stale or
            # not defined at all.
            if opt.use_obj_mcl_loss == 1 and not sc_flag:
                add_summary_value(tb_summary_writer, 'obj_loss',
                                  obj_loss.item(), iteration)
                add_summary_value(tb_summary_writer, 'caption_loss',
                                  caption_loss.item(), iteration)
            if opt.noamopt:
                opt.current_lr = optimizer.rate()
            elif opt.reduce_on_plateau:
                opt.current_lr = optimizer.current_lr
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if iteration % opt.save_checkpoint_every == 0:
            # Evaluate with batch size 1, then restore the training batch
            # size even if evaluation raises.
            orig_batch_size = opt.batch_size
            opt.batch_size = 1
            try:
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                loader.batch_size = eval_kwargs.get('batch_size', 1)
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, crit, loader, eval_kwargs)
            finally:
                opt.batch_size = orig_batch_size
                loader.batch_size = orig_batch_size

            # Presumably eval_split can return no language stats when
            # language evaluation is off (TODO confirm); treat that as empty.
            stats = lang_stats if lang_stats is not None else {}

            if opt.reduce_on_plateau:
                if 'CIDEr' in stats:
                    optimizer.scheduler_step(-stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            for k, v in stats.items():
                add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Our metric is CIDEr when language evaluation is on, otherwise
            # the negated validation loss (higher is better in both cases).
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # The latest model and optimizer are saved at every checkpoint,
            # whether or not the score improved.
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path,
                                          'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #12
0
def main(opt):
    """Build the dataset, model, criteria and optimizer, then start training.

    Args:
        opt: Dict of options.  ``opt["model"]`` selects the architecture;
            the remaining keys read here configure the data loader, the
            encoder/decoder, the Adam optimizer and the StepLR schedule.
            ``opt["vocab_size"]`` is filled in from the dataset.

    Raises:
        ValueError: If ``opt["model"]`` is not a supported model name.  This
            is checked before the dataset is loaded.
    """
    # Models that wrap one EncoderRNN/DecoderRNN pair.  Each entry is a small
    # factory so that a model class name is only resolved when that model is
    # actually selected.
    enc_dec_factories = {
        "S2VTAttModel":
        lambda enc, dec: S2VTAttModel(enc, dec),
        "S2VT_GCN_AttModel_1feat":
        lambda enc, dec: S2VT_GCN_AttModel_1feat(enc, dec, opt['dim_vid']),
        "S2VT_GCN_AttModel_3feat":
        lambda enc, dec: S2VT_GCN_AttModel_3feat(enc, dec, opt['dim_vid']),
        "S2VT_GCN_SEAttn_3feat":
        lambda enc, dec: S2VT_GCN_SEAttn_3feat(enc, dec, opt['dim_vid']),
        "S2VT_GCN_SimpleAtt_3feat":
        lambda enc, dec: S2VT_GCN_SimpleAtt_3feat(enc, dec, opt['dim_vid']),
        "S2VT_GCN":
        lambda enc, dec: S2VT_GCN(enc, dec, opt['dim_vid']),
        "S2VT_GCN_Sub":
        lambda enc, dec: S2VT_GCN_Sub(enc, dec, opt['dim_vid']),
    }

    # Fail fast on a misspelled model name instead of hitting an unbound
    # ``model`` variable after the dataset has already been loaded.
    model_name = opt["model"]
    if model_name != 'S2VTModel' and model_name not in enc_dec_factories:
        raise ValueError("Unknown model {!r}; expected one of: {}".format(
            model_name,
            ", ".join(sorted(['S2VTModel'] + list(enc_dec_factories)))))

    dataset = VideoDataset(opt, 'train')
    dataloader = DataLoader(dataset,
                            batch_size=opt["batch_size"],
                            shuffle=True)
    opt["vocab_size"] = dataset.get_vocab_size()

    if model_name == 'S2VTModel':
        model = S2VTModel(opt["vocab_size"],
                          opt["max_len"],
                          opt["dim_hidden"],
                          opt["dim_word"],
                          opt['dim_vid'],
                          rnn_cell=opt['rnn_type'],
                          n_layers=opt['num_layers'],
                          rnn_dropout_p=opt["rnn_dropout_p"])
    else:
        # Every remaining model shares the same encoder/decoder setup.
        encoder = EncoderRNN(opt["dim_vid"],
                             opt["dim_hidden"],
                             bidirectional=opt["bidirectional"],
                             input_dropout_p=opt["input_dropout_p"],
                             rnn_cell=opt['rnn_type'],
                             rnn_dropout_p=opt["rnn_dropout_p"])
        decoder = DecoderRNN(opt["vocab_size"],
                             opt["max_len"],
                             opt["dim_hidden"],
                             opt["dim_word"],
                             input_dropout_p=opt["input_dropout_p"],
                             rnn_cell=opt['rnn_type'],
                             rnn_dropout_p=opt["rnn_dropout_p"],
                             bidirectional=opt["bidirectional"])
        model = enc_dec_factories[model_name](encoder, decoder)

    model = model.cuda()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    optimizer = optim.Adam(model.parameters(),
                           lr=opt["learning_rate"],
                           weight_decay=opt["weight_decay"])
    exp_lr_scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=opt["learning_rate_decay_every"],
        gamma=opt["learning_rate_decay_rate"])

    train(dataloader, model, crit, optimizer, exp_lr_scheduler, opt, rl_crit)
def train(opt):
    """Run the training loop for a captioning model configured by ``opt``.

    The checkpoint folder is resolved first: when ``opt.start_from`` is set
    it is reused as ``opt.checkpoint_path``, otherwise a timestamped
    sub-folder of ``opt.checkpoint_path`` is created.  The data loader and
    the model are then built and batches from ``opt.train_split`` are
    processed until ``opt.max_epochs`` is reached.  Every
    ``opt.save_checkpoint_every`` optimizer steps the model is evaluated on
    the ``val`` split and ``model.pth``, ``optimizer.pth``, ``infos.pkl`` and
    ``histories.pkl`` are written to ``opt.checkpoint_path`` (plus
    ``model-best.pth`` / ``infos-best.pkl`` when the score improves).

    Args:
        opt: Namespace-like options object (read through attribute access
            and ``vars(opt)``).  It is mutated in place: ``id``,
            ``checkpoint_path``, ``use_att``, ``att_feat_size``,
            ``vocab_size``, ``rela_dict_size``, ``seq_length``,
            ``current_lr`` and ``ss_prob`` are (re)assigned here.

    Returns:
        None.  The loop only ends once ``opt.max_epochs`` is not ``-1`` and
        the epoch counter has reached it.

    Raises:
        AssertionError: when resuming and the saved options disagree with
            ``opt`` on ``caption_model``, ``rnn_type``, ``rnn_size`` or
            ``num_layers``.
    """
    # Read the resume folder once so both uses below agree on the lookup.
    start_from = vars(opt).get('start_from', None)
    if start_from is not None:
        opt.checkpoint_path = opt.start_from
        opt.id = opt.checkpoint_path.split('/')[-1]
        print('Point to folder: {}'.format(opt.checkpoint_path))
    else:
        opt.id = datetime.datetime.now().strftime(
            '%Y%m%d_%H%M%S') + '_' + opt.caption_model
        opt.checkpoint_path = os.path.join(opt.checkpoint_path, opt.id)

        if not os.path.exists(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        print('Create folder: {}'.format(opt.checkpoint_path))

    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader_UP(opt)
    # Resolve the optional flag before it is first used so a missing
    # attribute falls back to 0 instead of raising AttributeError.
    use_rela = getattr(opt, 'use_rela', 0)
    opt.vocab_size = loader.vocab_size
    if use_rela == 1:
        opt.rela_dict_size = loader.rela_dict_size
    opt.seq_length = loader.seq_length

    try:
        tb_summary_writer = tf and tf.compat.v1.summary.FileWriter(
            opt.checkpoint_path)
    except Exception:
        # Bind the name so continuing past the debugger below does not end
        # in a NameError at the first summary write.
        tb_summary_writer = None
        print('Set tensorboard error!')
        pdb.set_trace()

    infos = {}
    histories = {}
    if start_from is not None:
        # open old infos and check if models are compatible
        # NOTE(review): unpickling can execute arbitrary code; only resume
        # from checkpoint folders you trust.
        # The files are written with 'wb' below, so read them back as bytes.
        with open(os.path.join(opt.checkpoint_path, 'infos.pkl'), 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same = [
            "caption_model", "rnn_type", "rnn_size", "num_layers"
        ]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(
                opt
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.checkpoint_path, 'histories.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always bind the best score; it is only restored from the checkpoint
    # when explicitly requested.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    # dp_model = torch.nn.DataParallel(model)
    # dp_model = torch.nn.DataParallel(model, [0,2,3])
    dp_model = model

    print('### Model summary below###\n {}\n'.format(str(model)))
    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('model parameter:{}'.format(model_params))

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), opt)

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    reward = np.zeros([1, 1])

    isg_keys = ('isg_rela_matrix', 'isg_rela_masks', 'isg_obj',
                'isg_obj_masks', 'isg_attr', 'isg_attr_masks')
    ssg_keys = ('ssg_rela_matrix', 'ssg_rela_masks', 'ssg_obj',
                'ssg_obj_masks', 'ssg_attr', 'ssg_attr_masks')

    def _to_cuda(array):
        # Batch entries may be None; pass those through untouched.
        return array if array is None else torch.from_numpy(array).cuda()

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)

        torch.cuda.synchronize()
        start = time.time()

        fc_feats = _to_cuda(data['fc_feats'])
        labels = _to_cuda(data['labels'])
        masks = _to_cuda(data['masks'])
        att_feats = None
        att_masks = None
        # Both graph inputs default to None so every branch below leaves
        # them bound for the model call.
        isg_data = None
        ssg_data = None

        if getattr(opt, 'use_ssg', 0) == 1:
            if getattr(opt, 'use_isg', 0) == 1:
                att_feats = _to_cuda(data['att_feats'])
                att_masks = _to_cuda(data['att_masks'])
                # image graph domain
                isg_data = {'att_feats': att_feats, 'att_masks': att_masks}
                for key in isg_keys:
                    isg_data[key] = _to_cuda(data[key])
            # text graph domain
            ssg_data = {}
            for key in ssg_keys:
                ssg_data[key] = _to_cuda(data[key])
        else:
            # NOTE(review): these are moved to the GPU but not forwarded to
            # the model call below -- confirm whether that is intended.
            att_feats = _to_cuda(data['att_feats'])
            att_masks = _to_cuda(data['att_masks'])

        if not sc_flag:
            loss = crit(dp_model(fc_feats, labels, isg_data, ssg_data),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   isg_data,
                                                   ssg_data,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, isg_data,
                                              ssg_data, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()
        # True only on the micro-batch that completes an optimizer step.
        stepped = accumulate_iter % opt.accumulate_number == 0
        if stepped:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            accumulate_iter = 0
            # Undo the 1/accumulate_number scaling for reporting.
            train_loss = loss.item() * opt.accumulate_number
            end = time.time()

            if not sc_flag:
                print("{}/{}/{}|train_loss={:.3f}|time/batch={:.3f}" \
                      .format(opt.id, iteration, epoch, train_loss, end - start))
            else:
                print("{}/{}/{}|avg_reward={:.3f}|time/batch={:.3f}" \
                      .format(opt.id, iteration, epoch, np.mean(reward[:, 0]), end - start))

        torch.cuda.synchronize()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary.  Gated on `stepped` so the same
        # iteration is not logged once per accumulated micro-batch (or right
        # after resuming, before any step has been taken).
        if stepped and (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model (again only on
        # the pass that actually advanced `iteration`)
        if stepped and (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            if use_rela:
                # NOTE(review): the key is spelled 'use_real'; possibly meant
                # 'use_rela' -- confirm against eval_utils before changing.
                eval_kwargs['use_real'] = 1
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path,
                                          'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path, 'infos.pkl'),
                      'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories.pkl'),
                      'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #14
0
def train(opt):
    torch.cuda.set_device(opt.device)
    # opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_att = True
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    print(opt.seq_length)
    print(opt.checkpoint_path)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    critic_loss_history = histories.get('critic_loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    variance_history = histories.get('variance_history', {})
    time_history = histories.get('time_history', {})
    pseudo_num_history = histories.get('pseudo_num_history', {})
    pseudo_num_length_history = histories.get('pseudo_num_length_history', {})
    pseudo_num_batch_history = histories.get('pseudo_num_batch_history', {})
    sum_logits_history = histories.get('sum_logits_history', {})
    reward_main_history = histories.get('reward_main_history', {})
    first_order = histories.get('first_order_history', np.zeros(1))
    second_order = histories.get('second_order_history', np.zeros(1))
    first_order = torch.from_numpy(first_order).float().cuda()
    second_order = torch.from_numpy(second_order).float().cuda()

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = model

    target_actor = models.setup(opt).cuda()

    ####################### Critic pretrain #####################################################################
    ##### Critic with state as input
    # if opt.critic_model == 'state_critic':
    #     critic_model = CriticModel(opt)
    # else:
    critic_model = AttCriticModel(opt)
    target_critic = AttCriticModel(opt)
    if vars(opt).get('start_from_critic', None) is not None and True:
        # check if all necessary files exist
        assert os.path.isdir(opt.start_from_critic), " %s must be a a path" % opt.start_from_critic
        print(os.path.join(opt.start_from_critic, opt.critic_model + '_model.pth'))
        critic_model.load_state_dict(torch.load(os.path.join(opt.start_from_critic, opt.critic_model + '_model.pth')))
        target_critic.load_state_dict(torch.load(os.path.join(opt.start_from_critic, opt.critic_model + '_model.pth')))
    critic_model = critic_model.cuda()
    target_critic = target_critic.cuda()
    critic_optimizer = utils.build_optimizer(critic_model.parameters(), opt)
    dp_model.eval()
    critic_iter = 0
    init_scorer(opt.cached_tokens)
    critic_model.train()
    error_sum = 0
    loss_vector_sum = 0
    while opt.pretrain_critic == 1:
        if critic_iter > opt.pretrain_critic_steps:
            print('****************Finished critic training!')
            break
        data = loader.get_batch('train')
        torch.cuda.synchronize()
        start = time.time()
        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        critic_model.train()
        critic_optimizer.zero_grad()
        # assert opt.critic_model == 'att_critic_vocab'
        # crit_loss, reward, std = critic_loss_fun(fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data)
        crit_loss, reward, std = target_critic_loss_fun_mask(fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data, target_critic, target_actor)
        crit_loss.backward()
        critic_optimizer.step()
        #TODO update target.
        for cp, tp in zip(critic_model.parameters(), target_critic.parameters()):
            tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data)
        crit_train_loss = crit_loss.item()
        torch.cuda.synchronize()
        end = time.time()
        error_sum += crit_train_loss**0.5-std
        if (critic_iter % opt.losses_log_every == 0):
            print("iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f}, time/batch = {:.3f}" \
                .format(critic_iter, crit_train_loss**0.5, crit_train_loss**0.5-std, error_sum, end - start))
            print(opt.checkpoint_path)
            opt.importance_sampling = 1
            critic_model.eval()
            _, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model, test_critic=True)

        critic_iter += 1

        # make evaluation on validation set, and save model
        if (critic_iter % opt.save_checkpoint_every == 0):
            if not os.path.isdir(opt.checkpoint_path):
                os.mkdir(opt.checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path, opt.critic_model + '_model.pth')
            torch.save(critic_model.state_dict(), checkpoint_path)

    ######################### Actor-critic Training #####################################################################

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # first_order = 0
    # second_order = 0
    while True:
        if update_lr_flag:
                # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch('train')
        # if data['bounds']['it_pos_now'] > 5000:
        #     loader.reset_iterator('train')
        #     continue
        dp_model.train()
        critic_model.eval()

        torch.cuda.synchronize()
        start = time.time()
        gen_result = None
        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:])
        else:
            if opt.rl_type == 'sc':
                gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
                reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda())
                pseudo_num = 0
                pseudo_num_length = 0
            elif opt.rl_type == 'reinforce':
                gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
                reward = get_reward(data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda())
                pseudo_num_length = 0
                pseudo_num = 0

            elif opt.rl_type == 'arsm':
                loss, pseudo_num, pseudo_num_length, pseudo_num_batch, rewards_main, sum_logits = get_arm_loss_daniel(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                #print(loss)
                reward = np.zeros([2,2])
            elif opt.rl_type == 'rf4':
                loss,_,_,_ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                # print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'importance_sampling':
                opt.importance_sampling = 1
                loss, gen_result, reward, sample_logprobs_total = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1], 1)
                std = np.std(reward)
            elif opt.rl_type == 'importance_sampling_critic':
                opt.importance_sampling = 1
                loss, gen_result, reward, sample_logprobs_total = get_rf_loss(target_actor, fc_feats, att_feats, att_masks, data, opt, loader, target_critic)
                reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1], 1)
                std = np.std(reward)
            elif opt.rl_type == 'ar':
                loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = np.zeros([2,2])
            elif opt.rl_type == 'mct':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats,
                                                                                att_masks, data,
                                                                                opt, loader)
                reward = get_reward(data, gen_result, opt)
                pseudo_num = 0
                pseudo_num_length = 0
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                final_reward = torch.cat([mct_baseline[:, 1:], reward_cuda[:, 0:1]], 1)
                final_reward = final_reward - torch.mean(final_reward)
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, final_reward)
            elif opt.rl_type == 'mct_sc':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats,
                                                                                att_masks, data,
                                                                                opt, loader)
                reward = get_reward(data, gen_result, opt)
                pseudo_num = 0
                pseudo_num_length = 0
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                final_reward = torch.cat([mct_baseline[:, 1:], reward_cuda[:, 0:1]], 1)
                gen_result_sc, sample_logprobs_sc = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 1},
                                                       mode='sample')
                reward = get_reward(data, gen_result_sc, opt)
                final_reward = final_reward - torch.from_numpy(reward).float().cuda()
                loss = rl_crit(sample_logprobs, gen_result.data, final_reward)
            elif opt.rl_type == 'mct_critic':
                #TODO change the critic to attention
                if opt.critic_model == 'state_critic':
                    opt.rf_demean = 0
                    gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats,
                                                                                    att_masks, data,
                                                                                    opt, loader)
                    gen_result_pad = torch.cat(
                        [gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result], 1)
                    critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks).squeeze(2)
                    reward, std = get_reward(data, gen_result, opt, critic=True)
                    pseudo_num = 0
                    pseudo_num_length = 0
                    reward_cuda = torch.from_numpy(reward).float().cuda()
                    mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                    final_reward = torch.cat([mct_baseline[:, 1:], reward_cuda[:, 0:1]], 1)
                    print(critic_value.shape)
                    loss = rl_crit(sample_logprobs, gen_result.data, final_reward - critic_value)




                    critic_value, gen_result, sample_logprobs = critic_model(dp_model, fc_feats, att_feats, opt, att_masks)
                    reward, std = get_reward(data, gen_result, opt, critic=True)
                    loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - critic_value[:,:-1].data)
                elif opt.critic_model == 'att_critic':
                    opt.rf_demean = 0
                    gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats,
                                                                                    att_masks, data,
                                                                                    opt, loader)
                    gen_result_pad = torch.cat(
                        [gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result], 1)
                    critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks).squeeze(2)
                    reward, std = get_reward(data, gen_result, opt, critic=True)
                    pseudo_num = 0
                    pseudo_num_length = 0
                    reward_cuda = torch.from_numpy(reward).float().cuda()
                    mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                    final_reward = torch.cat([mct_baseline[:, 1:], reward_cuda[:, 0:1]], 1)
                    print(critic_value.shape)
                    loss = rl_crit(sample_logprobs, gen_result.data, final_reward - critic_value)
            elif opt.rl_type =='mct_baseline':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats, att_masks, data,
                                                                         opt, loader)
                reward = get_reward(data, gen_result, opt)
                pseudo_num = 0
                pseudo_num_length = 0
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - mct_baseline)
            elif opt.rl_type == 'arsm_baseline':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0]
                if opt.arm_step_sample == 'greedy' and False:
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda - arm_baseline)
            elif opt.rl_type == 'ars_indicator':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda * arm_baseline)
            elif opt.rl_type == 'arsm_baseline_critic':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model)
                reward, std = get_reward(data, gen_result, opt, critic=True)
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - arm_baseline)
            elif opt.rl_type == 'arsm_critic':
                #print(opt.critic_model)
                tic = time.time()
                loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model)
                #print('arm_loss time', str(time.time()-tic))
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'critic_vocab_sum':
                assert opt.critic_model == 'att_critic_vocab'
                tic = time.time()
                gen_result, sample_logprobs_total = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, total_probs=True,
                                                       mode='sample') #batch, seq, vocab
                #print('generation time', time.time()-tic)
                gen_result_pad = torch.cat(
                    [gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result], 1)
                tic = time.time()
                critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks) #batch, seq, vocab
                #print('critic time', time.time() - tic)
                probs = torch.sum(F.softmax(sample_logprobs_total, 2) * critic_value.detach(), 2)
                mask = (gen_result > 0).float()
                mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)
                loss = -torch.sum(probs * mask) / torch.sum(mask)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'reinforce_critic':
                #TODO change the critic to attention
                if opt.critic_model == 'state_critic':
                    critic_value, gen_result, sample_logprobs = critic_model(dp_model, fc_feats, att_feats, opt, att_masks)
                    reward, std = get_reward(data, gen_result, opt, critic=True)
                    loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - critic_value[:,:-1].data)
                elif opt.critic_model == 'att_critic':
                    gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0},
                                                           mode='sample')
                    gen_result_pad = torch.cat(
                        [gen_result.new_zeros(gen_result.size(0), 1, dtype=torch.long), gen_result], 1)
                    critic_value = critic_model(gen_result_pad, fc_feats, att_feats, True, opt, att_masks).squeeze(2)

                    reward, std = get_reward(data, gen_result, opt, critic=True)
                    loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - critic_value.data)
        if opt.mle_weights != 0:
            loss += opt.mle_weights * crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:])
        #TODO make sure all sampling replaced by greedy for critic
        #### update the actor
        loss.backward()
        # with open(os.path.join(opt.checkpoint_path, 'best_embed.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f)
        # with open(os.path.join(opt.checkpoint_path, 'best_logit.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.logit.parameters())[0].data.cpu().numpy(), f)
        ## compute variance
        gradient = torch.zeros([0]).cuda()
        for i in model.parameters():
            gradient = torch.cat((gradient, i.grad.view(-1)), 0)
        first_order = 0.9999 * first_order + 0.0001 * gradient
        second_order = 0.9999 * second_order + 0.0001 * gradient.pow(2)
        # print(torch.max(torch.abs(gradient)))
        variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item()
        if opt.rl_type != 'arsm' or not sc_flag:
            utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # ### update the critic
        if 'critic' in opt.rl_type:
            dp_model.eval()
            critic_model.train()
            utils.set_lr(critic_optimizer, opt.critic_learning_rate)
            critic_optimizer.zero_grad()
            #assert opt.critic_model == 'att_critic_vocab'
            crit_loss, reward, std = target_critic_loss_fun_mask(fc_feats, att_feats, att_masks, dp_model, critic_model, opt,
                                                            data, target_critic, target_actor, gen_result=gen_result, sample_logprobs_total=sample_logprobs_total, reward=reward)
            crit_loss.backward()
            critic_optimizer.step()
            for cp, tp in zip(critic_model.parameters(), target_critic.parameters()):
                tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data)
            for cp, tp in zip(dp_model.parameters(), target_actor.parameters()):
                tp.data = tp.data + opt.gamma_actor * (cp.data - tp.data)
            crit_train_loss = crit_loss.item()
            error_sum += crit_train_loss ** 0.5 - std
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if (iteration % opt.losses_log_every == 0):
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
                print(opt.checkpoint_path)
            elif 'critic' in opt.rl_type:
                print(
                    "iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f},variance = {:g}, time/batch = {:.3f}" \
                    .format(iteration, crit_train_loss ** 0.5, crit_train_loss ** 0.5 - std, error_sum, variance, end - start))
                print(opt.checkpoint_path)
                critic_model.eval()
                _, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader, critic_model, test_critic=True)
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start))
                print("pseudo num: ", pseudo_num)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward), iteration)
                add_summary_value(tb_summary_writer, 'variance', variance, iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward)
            critic_loss_history[iteration] = crit_train_loss if 'critic' in opt.rl_type else 0
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
            variance_history[iteration] = variance
            pseudo_num_history[iteration] = pseudo_num
            reward_main_history[iteration] = rewards_main
            #print(pseudo_num_length)
            #print(type(pseudo_num_length).__module__)
            if type(pseudo_num_length).__module__ != 'torch':
                print('not right')
                pseudo_num_length_history[iteration] = pseudo_num_length
                pseudo_num_batch_history[iteration] = pseudo_num_batch
                sum_logits_history[iteration] = sum_logits
            else:
                pseudo_num_length_history[iteration] = pseudo_num_length.data.cpu().numpy()
                pseudo_num_batch_history[iteration] = pseudo_num_batch.data.cpu().numpy()
                sum_logits_history[iteration] = sum_logits.data.cpu().numpy()
            time_history[iteration] = end - start


        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            if lang_stats is not None:
                for k,v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True: # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path, opt.critic_model + '_model.pth')
                torch.save(critic_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscalleous informations
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['critic_loss_history'] = critic_loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                histories['variance_history'] = variance_history
                histories['pseudo_num_history'] = pseudo_num_history
                histories['pseudo_num_length_history'] = pseudo_num_length_history
                histories['pseudo_num_batch_history'] = pseudo_num_batch_history
                histories['sum_logits_history'] = sum_logits_history
                histories['reward_main_history'] = reward_main_history
                histories['time'] = time_history
                histories['first_order_history'] = first_order.data.cpu().numpy()
                histories['second_order_history'] = second_order.data.cpu().numpy()
                # histories['variance'] = 0
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #15
0
def train(opt):
    """Run the cross-entropy training loop for the ``transformer.convcap`` model.

    The function builds the data loader, optionally restores bookkeeping
    (``infos`` / ``histories`` pickles) from ``opt.start_from``, loads the
    weights stored at ``opt.model`` (non-strict), and then loops forever over
    training batches until ``opt.max_epochs`` is reached. Every
    ``opt.save_checkpoint_every`` iterations the model and optimizer state
    dicts are written under ``opt.checkpoint_path``.

    Only the cross-entropy objective is implemented. When
    ``opt.self_critical_after`` is reached the scorer is still initialised
    (as before), but no reward is computed, so the loop keeps logging the
    cross-entropy loss instead of referencing a non-existent ``reward``.

    Args:
        opt: option namespace. Attributes read here include ``start_from``,
            ``id``, ``checkpoint_path``, ``model``, ``load_best_score``,
            ``learning_rate`` and the ``learning_rate_decay_*`` settings,
            ``self_critical_after``, ``cached_tokens``, ``print_freq``,
            ``losses_log_every``, ``save_checkpoint_every`` and
            ``max_epochs``. ``vocab_size``, ``seq_length`` and ``current_lr``
            are written back onto it.

    Raises:
        AssertionError: if the resumed ``infos['opt']`` disagrees with ``opt``
            on ``rnn_type``, ``rnn_size`` or ``num_layers``.
    """
    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries; when `tb` is falsy no writer is created.
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load info file and histories file when resuming
    infos = {}
    histories = {}
    if opt.start_from is not None:
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f, encoding='latin1')
        saved_model_opt = infos['opt']
        for checkme in ("rnn_type", "rnn_size", "num_layers"):
            assert vars(saved_model_opt)[checkme] == vars(
                opt
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f, encoding='bytes')

    iteration = infos.get('iter', 0)
    # NOTE(review): the epoch counter restarts at 0 even when `iteration` is
    # resumed, so the learning-rate decay schedule restarts too -- confirm
    # this is intended.
    epoch = 0
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # The three names below are only consumed by the disabled evaluation block
    # further down; they are kept so that block can be re-enabled as is.
    # `best_val_score` is now always bound (it used to exist only when
    # opt.load_best_score == 1).
    val_result_history = histories.get('val_result_history', {})
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    crit = utils.LanguageModelCriterion()

    # Create model and load the pretrained weights (missing/unexpected keys
    # are tolerated because of strict=False).
    model = transformer.convcap(opt).cuda()
    pretrained_dict = torch.load(opt.model)
    model.load_state_dict(pretrained_dict, strict=False)

    dp_model = model
    dp_model.train()

    # Optimizer and learning rate adjustment flag
    optimizer = utils.build_optimizer_adam(model.parameters(), opt)
    update_lr_flag = True
    # NOTE(review): the optimizer state is looked up in a hard-coded
    # directory rather than opt.start_from / opt.checkpoint_path.
    optimizer_state_path = os.path.join('log_cvpr/', "optimizer.pth")
    if os.path.isfile(optimizer_state_path):
        optimizer.load_state_dict(torch.load(optimizer_state_path))
        print('optimiser loaded')

    # Placeholder reported in the progress line; no entropy is computed.
    entropy = 0

    # Training loop
    while True:

        # Update learning rate once per epoch
        if update_lr_flag:

            # Assign the (step-decayed) learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Scheduled sampling is disabled in this loop (model.ss_prob is
            # never set), so nothing is computed for it.

            # Self-critical switch: the scorer is initialised as before, but
            # this loop has no reward/REINFORCE step, so training continues
            # with the cross-entropy objective.
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                init_scorer(opt.cached_tokens)
                print('warning: self-critical training was requested but this '
                      'loop only optimises cross-entropy; continuing with the '
                      'cross-entropy loss')

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Move the tensors that the loss actually needs onto the GPU.
        # data['dist'], data['masks'] and data['att_masks'] are not used by
        # this objective and are therefore no longer transferred.
        torch.cuda.synchronize()
        fc_feats, att_feats, labels = (
            None if arr is None else torch.from_numpy(arr).cuda()
            for arr in (data['fc_feats'], data['att_feats'], data['labels'])
        )
        batchsize = fc_feats.size(0)
        labels = labels.long()

        # Forward pass and loss
        optimizer.zero_grad()
        # The trailing positional arguments (30, 6) are passed through
        # unchanged; their meaning is defined by the model -- TODO confirm.
        wordact, _ = dp_model(fc_feats, att_feats, labels, 30, 6)

        # Drop the last prediction step and flatten to one row of scores per
        # target token; targets are the labels shifted by one position.
        labels = labels.view(batchsize, -1)
        wordact = wordact[:, :, :-1]
        scores = wordact.permute(0, 2, 1).contiguous()
        scores = scores.view(scores.size(0) * scores.size(1), -1)
        targets = labels[:, 1:].contiguous().view(-1)

        # NOTE(review): the loss is averaged over every position; the batch
        # masks are not applied -- confirm padded positions should count.
        loss = F.cross_entropy(scores, targets)

        loss.backward()
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Print
        total_time = time.time() - start
        if iteration % opt.print_freq == 0:
            print('Read data:', data_time)
            print(
                "iter {} (epoch {}), train_loss = {:.3f}, entropy = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, entropy, data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr

        # Save model (validation is currently disabled, see below)
        if (iteration % opt.save_checkpoint_every == 0):
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           str(iteration) + 'model2.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)
            # Disabled evaluation / best-model bookkeeping, kept as an inert
            # string so it can be restored. While it is disabled, the infos
            # and histories pickles are never written by this function.
            '''
            eval_kwargs = {'split': 'test',
                           'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscalleous informations
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            # histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format(iteration, best_val_score)
                infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration)
                checkpoint_path = os.path.join(opt.checkpoint_path, model_fname)
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, infos_fname), 'wb') as f:
                    cPickle.dump(infos, f)
            '''
        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #16
0
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)

    from dataloader import DataLoader
    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.vocab_ccg_size = loader.vocab_ccg_size
    opt.seq_length = loader.seq_length

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    cnn_model = utils.build_cnn(opt)
    cnn_model.cuda()

    model = models.setup(opt)
    model.cuda()
    # model = DataParallel(model)

    if vars(opt).get('start_from', None) is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            opt.start_from), " %s must be a a path" % opt.start_from
        assert os.path.isfile(
            os.path.join(opt.start_from, "infos_" + opt.id + ".pkl")
        ), "infos.pkl file does not exist in path %s" % opt.start_from
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    update_lr_flag = True
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    multilabel_crit = nn.MultiLabelSoftMarginLoss().cuda()
    #    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)
    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
    if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
        print('finetune mode')
        cnn_optimizer = optim.Adam([\
            {'params': module.parameters()} for module in cnn_model._modules.values()[5:]\
            ], lr=opt.cnn_learning_rate, weight_decay=opt.cnn_weight_decay)

    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        if os.path.isfile(os.path.join(opt.start_from, 'optimizer.pth')):
            optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            if os.path.isfile(os.path.join(opt.start_from,
                                           'optimizer-cnn.pth')):
                cnn_optimizer.load_state_dict(
                    torch.load(
                        os.path.join(opt.start_from, 'optimizer-cnn.pth')))

    eval_kwargs = {'split': 'val', 'dataset': opt.input_json, 'verbose': True}
    eval_kwargs.update(vars(opt))
    val_loss, predictions, lang_stats = eval_utils.eval_split(
        cnn_model, model, crit, loader, eval_kwargs, True)
    epoch_start = time.time()
    while True:
        if update_lr_flag:
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
                #model.module.ss_prob = opt.ss_prob
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
            else:
                sc_flag = False

            # Update the training stage of cnn
            for p in cnn_model.parameters():
                p.requires_grad = True
            # Fix the first few layers:
            for module in cnn_model._modules.values()[:5]:
                for p in module.parameters():
                    p.requires_grad = False
            cnn_model.train()
            update_lr_flag = False

        cnn_model.apply(utils.set_bn_fix)
        cnn_model.apply(utils.set_bn_eval)

        start = time.time()
        torch.cuda.synchronize()
        data = loader.get_batch('train')
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:

            multilabels = [
                data['detection_infos'][i]['label']
                for i in range(len(data['detection_infos']))
            ]

            tmp = [
                data['labels'], data['masks'],
                np.array(multilabels, dtype=np.int16)
            ]
            tmp = [
                Variable(torch.from_numpy(_), requires_grad=False).cuda()
                for _ in tmp
            ]
            labels, masks, multilabels = tmp
            images = data[
                'images']  # it cannot be turned into tensor since different sizes.
            _fc_feats_2048 = []
            _fc_feats_81 = []
            _att_feats = []
            for i in range(loader.batch_size):
                x = Variable(torch.from_numpy(images[i]),
                             requires_grad=False).cuda()
                x = x.unsqueeze(0)
                att_feats, fc_feats_81 = cnn_model(x)
                fc_feats_2048 = att_feats.mean(3).mean(2).squeeze()
                att_feats = F.adaptive_avg_pool2d(att_feats,
                                                  [14, 14]).squeeze().permute(
                                                      1, 2, 0)  #(0, 2, 3, 1)
                _fc_feats_2048.append(fc_feats_2048)
                _fc_feats_81.append(fc_feats_81)
                _att_feats.append(att_feats)
            _fc_feats_2048 = torch.stack(_fc_feats_2048)
            _fc_feats_81 = torch.stack(_fc_feats_81)
            _att_feats = torch.stack(_att_feats)
            att_feats = _att_feats.unsqueeze(1).expand(*((_att_feats.size(0), loader.seq_per_img,) + \
                                                           _att_feats.size()[1:])).contiguous().view(*((_att_feats.size(0) * loader.seq_per_img,) + \
                                                           _att_feats.size()[1:]))
            fc_feats_2048 = _fc_feats_2048.unsqueeze(1).expand(*((_fc_feats_2048.size(0), loader.seq_per_img,) + \
                                                          _fc_feats_2048.size()[1:])).contiguous().view(*((_fc_feats_2048.size(0) * loader.seq_per_img,) + \
                                                          _fc_feats_2048.size()[1:]))
            fc_feats_81 = _fc_feats_81
            #
            cnn_optimizer.zero_grad()
        else:

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks']
            ]
            tmp = [
                Variable(torch.from_numpy(_), requires_grad=False).cuda()
                for _ in tmp
            ]
            fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()

        if not sc_flag:
            loss1 = crit(model(fc_feats_2048, att_feats, labels),
                         labels[:, 1:], masks[:, 1:])
            loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double())
            loss = 0.8 * loss1 + 0.2 * loss2.float()
        else:
            gen_result, sample_logprobs = model.sample(fc_feats_2048,
                                                       att_feats,
                                                       {'sample_max': 0})
            reward = get_self_critical_reward(model, fc_feats_2048, att_feats,
                                              data, gen_result)
            loss1 = rl_crit(
                sample_logprobs, gen_result,
                Variable(torch.from_numpy(reward).float().cuda(),
                         requires_grad=False))
            loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double())
            loss3 = crit(model(fc_feats_2048, att_feats, labels),
                         labels[:, 1:], masks[:, 1:])
            loss = 0.995 * loss1 + 0.005 * (loss2.float() + loss3)
        loss.backward()

        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.data[0]
        mle_loss = loss1.data[0]
        multilabel_loss = loss2.data[0]
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag and iteration % 2500 == 0:
            print("iter {} (epoch {}), mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, mle_loss, multilabel_loss, train_loss, end - start))

        if sc_flag and iteration % 2500 == 0:
            print("iter {} (epoch {}), avg_reward = {:.3f}, mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), mle_loss, multilabel_loss, train_loss, end - start))
        iteration += 1
        if (iteration % opt.losses_log_every == 0):
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        if (iteration % opt.save_checkpoint_every == 0):
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'verbose': True
            }
            eval_kwargs.update(vars(opt))

            if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    cnn_model, model, crit, loader, eval_kwargs, True)
            else:
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    cnn_model, model, crit, loader, eval_kwargs, False)

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))

                cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-cnn.pth')
                torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                print("cnn model saved to {}".format(cnn_checkpoint_path))

                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                    cnn_optimizer_path = os.path.join(opt.checkpoint_path,
                                                      'optimizer-cnn.pth')
                    torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path)

                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))

                    cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                       'model-cnn-best.pth')
                    torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                    print("cnn model saved to {}".format(cnn_checkpoint_path))

                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True
            print("epoch: " + str(epoch) + " during: " +
                  str(time.time() - epoch_start))
            epoch_start = time.time()

        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #17
0
def train(opt):
    """Run the training loop for the captioning model configured by ``opt``.

    The function builds the data loader and the model, optionally resumes
    from the numbered checkpoint ``opt.start_from`` stored in
    ``opt.checkpoint_path``, and then iterates over training batches until
    ``opt.max_epochs`` is reached (``-1`` disables the limit).

    Per batch the loss is the language loss (``crit`` on the model logits,
    or ``rl_crit`` on sampled captions once self-critical training is
    active) plus ``opt.sim_lambda`` times the ``cw_logits`` loss while
    ``epoch < opt.step2_train_after``.  Gradients are accumulated over
    ``opt.accumulate_number`` batches before each optimizer step, and only
    optimizer steps advance ``iteration``.

    Args:
        opt: Options namespace.  It is mutated in place (``use_att``,
            ``vocab_size``, ``seq_length``, ``current_lr``, ``ss_prob`` and,
            when self-critical training starts, the learning-rate settings)
            and is pickled into every ``infos_*.pkl`` file.

    Returns:
        None.  Model and optimizer weights, ``infos`` and ``histories`` are
        written to ``opt.checkpoint_path`` every
        ``opt.save_checkpoint_every`` iterations.

    Raises:
        AssertionError: If a resumed checkpoint disagrees with ``opt`` on
            ``caption_model``, ``rnn_type``, ``rnn_size`` or ``num_layers``.
    """
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    ac = 0

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    # Suffix shared by every file of the checkpoint being resumed:
    # "<id><zero-padded checkpoint number>".  None means a fresh run.
    resume_tag = None
    if opt.start_from is not None:
        resume_tag = opt.id + format(int(opt.start_from), '04')

        # NOTE: pickle executes arbitrary code on load; only resume from
        # checkpoint files you trust.  Binary mode is required on Python 3.
        infos_path = os.path.join(opt.checkpoint_path,
                                  'infos_' + resume_tag + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)

        # Check that the saved model is compatible with the current options.
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(
                opt
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.checkpoint_path,
                                      'histories_' + resume_tag + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind best_val_score: the checkpoint code below reads it even
    # when the previous best score is not being restored.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()

    # Single-GPU training; wrap `model` in torch.nn.DataParallel here to
    # train on several devices.
    dp_model = model

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    for name, _ in model.named_parameters():
        print(name)

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    CE_ac = utils.CE_ac()

    optim_para = model.parameters()
    optimizer = utils.build_optimizer(optim_para, opt)

    # Load the optimizer state that belongs to the resumed checkpoint.
    if resume_tag is not None:
        resume_optimizer_path = os.path.join(
            opt.checkpoint_path, 'optimizer' + resume_tag + '.pth')
        if os.path.isfile(resume_optimizer_path):
            optimizer.load_state_dict(torch.load(resume_optimizer_path))

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    reward = np.zeros([1, 1])
    sim_lambda = opt.sim_lambda
    # True until the learning-rate options have been switched to their
    # self-critical (RL) values; the switch happens exactly once.
    rl_options_pending = True
    while True:
        if opt.self_critical_after != -1 and epoch >= opt.self_critical_after and rl_options_pending:
            opt.learning_rate_decay_start = opt.self_critical_after
            opt.learning_rate_decay_rate = opt.learning_rate_decay_rate_rl
            opt.learning_rate = opt.learning_rate_rl
            rl_options_pending = False

        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        # Move every array of the batch to the GPU; entries the loader left
        # as None are passed through unchanged.
        tmp = [data['labels'], data['masks'], data['mods']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        labels, masks, mods = tmp

        feature_keys = [
            'att_feats', 'att_masks', 'attr_feats', 'attr_masks',
            'rela_feats', 'rela_masks'
        ]
        rs_data = {}
        for key in feature_keys:
            value = data[key]
            rs_data[key] = value if value is None else torch.from_numpy(
                value).cuda()

        if not sc_flag:
            logits, cw_logits = dp_model(rs_data, labels)
            ac = CE_ac(logits, labels[:, 1:], masks[:, 1:])
            print('ac :{0}'.format(ac))
            loss_lan = crit(logits, labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs, cw_logits = dp_model(
                rs_data, opt={'sample_max': 0}, mode='sample')
            reward = get_self_critical_reward(dp_model, rs_data, data,
                                              gen_result, opt)
            loss_lan = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda())

        loss_cw = crit(cw_logits, mods[:, 1:], masks[:, 1:])
        ac2 = CE_ac(cw_logits, mods[:, 1:], masks[:, 1:])
        print('ac :{0}'.format(ac2))
        if epoch < opt.step2_train_after:
            loss = loss_lan + sim_lambda * loss_cw
        else:
            loss = loss_lan

        # Gradient accumulation: scale each batch loss so the summed
        # gradient matches one batch that is accumulate_number times larger.
        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()
        if accumulate_iter % opt.accumulate_number == 0:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            # Resetting to 0 is what makes the logging and checkpoint
            # conditions below fire only right after an optimizer step.
            accumulate_iter = 0
            train_loss = loss.item() * opt.accumulate_number
            train_loss_lan = loss_lan.item()
            train_loss_cw = loss_cw.item()
            end = time.time()

            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, end - start))
                print("train_loss_lan = {:.3f}, train_loss_cw = {:.3f}" \
                      .format(train_loss_lan, train_loss_cw))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), end - start))
                print("train_loss_lan = {:.3f}, train_loss_cw = {:.3f}" \
                      .format(train_loss_lan, train_loss_cw))
            print('lr:{0}'.format(opt.current_lr))

        torch.cuda.synchronize()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every
                == 0) and (accumulate_iter % opt.accumulate_number == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'train_loss_lan',
                              train_loss_lan, iteration)
            add_summary_value(tb_summary_writer, 'train_loss_cw',
                              train_loss_cw, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            add_summary_value(tb_summary_writer, 'ac', ac, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Save a numbered checkpoint.
        if (iteration % opt.save_checkpoint_every
                == 0) and (accumulate_iter % opt.accumulate_number == 0):
            # NOTE: validation (eval_utils_rs3.eval_split) is disabled in
            # this script, so every checkpoint gets the same score and
            # 'model-best.pth' is only written while best_val_score is
            # still None or below 0.
            current_score = 0

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # Checkpoint files are numbered by how many save intervals have
            # elapsed, zero-padded to four digits.
            save_tag = opt.id + format(
                int(iteration // opt.save_checkpoint_every), '04')
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model' + save_tag + '.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path,
                                          'optimizer' + save_tag + '.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + save_tag + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + save_tag + '.pkl'),
                    'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #18
0
 def __init__(self, model, opt):
     """Keep references to ``model`` and ``opt`` and create the loss criterion.

     Args:
         model: Object stored as ``self.model``.
         opt: Options object stored as ``self.opt``.
     """
     super(LossWrapper, self).__init__()
     self.opt = opt
     self.model = model
     # Language-model criterion provided by the project's ``utils`` module.
     self.crit = utils.LanguageModelCriterion()
def train(opt):
    """Jointly train a forward captioning model and a reversed "twin" model.

    Two models are built with ``models.setup``: ``model`` and ``back_model``
    (``reverse=True``), the latter fed the time-flipped labels.  Both are
    optimised by one Adam optimizer on

        ``all_loss = loss + 1.5 * l2_loss + back_loss``

    where ``loss`` and ``back_loss`` are the ``crit`` values of the two
    models and ``l2_loss`` is the mean squared difference between the
    forward states and the detached, time-reversed backward states.

    Training resumes from ``opt.start_from`` when it is set, logs to
    ``log_result/twin_show_attend_tell.txt`` (and to the TensorFlow summary
    writer when ``tf`` is available), and evaluates on the ``'val'`` split
    and checkpoints every ``opt.save_checkpoint_every`` iterations.  Only
    the forward ``model`` and the optimizer are written to the checkpoint.

    Args:
        opt: Options namespace.  It is mutated in place (``use_att``,
            ``vocab_size``, ``seq_length``, ``current_lr``, ``ss_prob``) and
            is pickled into the ``infos_*.pkl`` files.

    Returns:
        None.  The loop ends once ``epoch >= opt.max_epochs`` (``-1``
        disables the limit).

    Raises:
        AssertionError: If a resumed checkpoint disagrees with ``opt`` on
            ``caption_model``, ``rnn_type``, ``rnn_size`` or ``num_layers``.
    """
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # NOTE: pickle executes arbitrary code on load; only resume from
        # checkpoint files you trust.  Binary mode is required on Python 3.
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = cPickle.load(f)

        # Check that the saved model is compatible with the current options.
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(
                opt
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind best_val_score: the checkpoint code below reads it even
    # when the previous best score is not being restored.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()
    back_model = models.setup(opt, reverse=True)  # True for twin-net
    back_model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()
    back_model.train()

    crit = utils.LanguageModelCriterion()  # define the loss criterion
    all_param = chain(model.parameters(), back_model.parameters())
    optimizer = optim.Adam(all_param,
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer state if the resumed run saved one.
    if opt.start_from is not None:
        resume_optimizer_path = os.path.join(opt.start_from, 'optimizer.pth')
        if os.path.isfile(resume_optimizer_path):
            optimizer.load_state_dict(torch.load(resume_optimizer_path))

    # Plain-text training log; the directory is created on demand and the
    # file is closed when the loop ends, normally or through an exception.
    folder_id = 'log_result'
    file_id = 'twin_show_attend_tell'
    if not os.path.isdir(folder_id):
        os.makedirs(folder_id)
    log_file_name = os.path.join(folder_id, file_id + '.txt')

    with open(log_file_name, 'w') as log_file:
        while True:
            if update_lr_flag:
                # Assign the learning rate
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                else:
                    opt.current_lr = opt.learning_rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob
                update_lr_flag = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            # flip the masks and labels for twin-net; .copy() because
            # np.flip returns a view that torch.from_numpy cannot wrap.
            reverse_labels = np.flip(data['labels'], 1).copy()
            reverse_masks = np.flip(data['masks'], 1).copy()

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                reverse_labels, data['masks'], reverse_masks
            ]
            tmp = [
                Variable(torch.from_numpy(_), requires_grad=False).cuda()
                for _ in tmp
            ]
            fc_feats, att_feats, labels, reverse_labels, masks, reverse_masks = tmp

            optimizer.zero_grad()
            out, states = model(fc_feats, att_feats, labels)
            back_out, back_states = back_model(fc_feats, att_feats,
                                               reverse_labels)

            # Reverse the backward states along dim 1 so they can be
            # compared element-wise with the forward states.
            idx = list(range(back_states.size()[1] - 1, -1, -1))
            idx = torch.LongTensor(idx)
            idx = Variable(idx).cuda()
            invert_backstates = back_states.index_select(1, idx)

            loss = crit(out, labels[:, 1:],
                        masks[:, 1:])  # compute using the defined criterion

            back_loss = crit(back_out, reverse_labels[:, :-1],
                             reverse_masks[:, :-1])

            # Detached: the L2 term sends gradient to the forward model
            # only, not back into the backward model.
            invert_backstates = invert_backstates.detach()
            l2_loss = ((states - invert_backstates)**2).mean()

            all_loss = loss + 1.5 * l2_loss + back_loss

            all_loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()

            # store the relevant values
            # NOTE(review): `.data[0]` on a scalar loss relies on pre-0.4
            # PyTorch behaviour; newer releases need `.item()` - confirm
            # the target version before changing.
            train_l2_loss = l2_loss.data[0]
            train_loss = loss.data[0]
            train_all_loss = all_loss.data[0]
            train_back_loss = back_loss.data[0]
            torch.cuda.synchronize()
            end = time.time()
            print("iter {} (epoch {}), train_loss = {:.3f}, l2_loss = {:.3f}, back_loss = {:.3f}, all_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, train_l2_loss, train_back_loss, train_all_loss, end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                update_lr_flag = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                if tf is not None:
                    add_summary_value(tf_summary_writer, 'train_loss',
                                      train_loss, iteration)
                    add_summary_value(tf_summary_writer, 'l2_loss',
                                      train_l2_loss, iteration)
                    add_summary_value(tf_summary_writer, 'all_loss',
                                      train_all_loss, iteration)
                    add_summary_value(tf_summary_writer, 'back_loss',
                                      train_back_loss, iteration)
                    add_summary_value(tf_summary_writer, 'learning_rate',
                                      opt.current_lr, iteration)
                    add_summary_value(tf_summary_writer,
                                      'scheduled_sampling_prob',
                                      model.ss_prob, iteration)
                    tf_summary_writer.flush()

                # NOTE(review): time.clock() was removed in Python 3.8;
                # kept so the logged value keeps its meaning on the
                # interpreter this script targets - confirm before porting.
                log_line = 'Epoch [%d], Step [%d], all loss: %f,back_loss %f,train_l2_loss %f, train_loss %f, time %f ' % (
                    epoch, iteration, train_all_loss, train_back_loss,
                    train_l2_loss, train_loss, time.clock())
                log_file.write(log_line + '\n')

                loss_history[iteration] = train_loss
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    model, crit, loader, eval_kwargs)

                # Write validation result into summary
                if tf is not None:
                    add_summary_value(tf_summary_writer, 'validation loss',
                                      val_loss, iteration)
                    # lang_stats is only relied on when language_eval == 1
                    # (see the score selection below), so tolerate None.
                    if lang_stats is not None:
                        for k, v in lang_stats.items():
                            add_summary_value(tf_summary_writer, k, v,
                                              iteration)
                    tf_summary_writer.flush()
                val_result_history[iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Higher is better: CIDEr when language evaluation is on,
                # otherwise the negated validation loss.
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # The latest checkpoint is always written; only the forward
                # model's weights are stored.
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
예제 #20
0
def train(opt):
    """Train two ``Att2inModel`` captioners, ``modelS`` and ``modelT``, together.

    Both models are built from the same ``opt``. When ``opt.start_from`` is
    set, both are initialised from that directory's ``model.pth``. Each
    iteration runs one of two phases, selected by the ``sc_flag`` returned by
    ``update_lr``:

    * ``sc_flag`` false: a ``LanguageModelCriterion`` loss is computed on the
      output of ``modelS`` only.
    * ``sc_flag`` true: each model samples captions and is trained with
      ``RewardCriterion``; its reward comes from
      ``get_self_critical_reward_forTS``, which receives the *other* model as
      its first argument.

    Every ``opt.save_checkpoint_every`` iterations ``modelS`` is evaluated on
    the ``val`` split, and both models, both optimizers, ``infos`` and
    ``histories`` are written to ``opt.checkpoint_path``. When the score
    improves, ``*-best`` copies are written as well.

    Args:
        opt: options namespace. It is mutated in place (``use_att``,
            ``vocab_size`` and ``seq_length`` are set here) and is pickled
            into the ``infos`` file.

    Returns:
        None. The loop ends once ``epoch >= opt.max_epochs`` unless
        ``opt.max_epochs`` is ``-1``.
    """
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Falsy when ``tf`` is unavailable; every summary write below is guarded
    # by ``tf is not None``.
    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # NOTE: unpickling can execute arbitrary code -- only resume from
        # checkpoint directories you trust.
        # The pickles are written with 'wb' below, so read them back as binary.
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
        # Check that the saved model is compatible with the requested one.
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always bind best_val_score: the checkpoint code below reads it even
    # when the saved best score is not requested.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    modelT = Att2inModel(opt)
    modelS = Att2inModel(opt)
    if vars(opt).get('start_from', None) is not None:
        assert os.path.isdir(
            opt.start_from), " %s must be a a path" % opt.start_from
        assert os.path.isfile(
            os.path.join(opt.start_from, "infos_" + opt.id + ".pkl")
        ), "infos.pkl file does not exist in path %s" % opt.start_from
        # Both models start from the same pretrained weights.
        pretrained_path = os.path.join(opt.start_from, 'model.pth')
        modelT.load_state_dict(torch.load(pretrained_path))
        modelS.load_state_dict(torch.load(pretrained_path))
    # Move the models to the GPU whether or not weights were loaded: every
    # input batch below is placed on the GPU with ``.cuda()``.
    modelT.cuda()
    modelS.cuda()

    logger = Logger(opt)

    update_lr_flag = True
    # Assure in training mode
    modelT.train()
    modelS.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer_S = optim.Adam(modelS.parameters(),
                             lr=opt.learning_rate,
                             weight_decay=opt.weight_decay)
    optimizer_T = optim.Adam(modelT.parameters(),
                             lr=opt.learning_rate,
                             weight_decay=opt.weight_decay)

    # Load the optimizer. Only optimizer_S is restored, and from
    # 'optimizer.pth' rather than the 'S_optimizer.pth' / 'T_optimizer.pth'
    # files written at checkpoint time.
    # NOTE(review): looks intended for starting from a single pretrained
    # model rather than resuming this function's own checkpoints -- confirm.
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer_S.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # update_lr is applied to each model/optimizer pair in turn; the
            # opt, sc_flag and update_lr_flag kept are those of the 2nd call.
            opt, sc_flag, update_lr_flag, modelS, optimizer_S = update_lr(
                opt, epoch, modelS, optimizer_S)
            opt, sc_flag, update_lr_flag, modelT, optimizer_T = update_lr(
                opt, epoch, modelT, optimizer_T)

        # Load data from train split (0)
        data = loader.get_batch('train', seq_per_img=opt.seq_per_img)

        # Synchronize so the wall-clock timing covers finished GPU work.
        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, masks = tmp

        optimizer_S.zero_grad()
        optimizer_T.zero_grad()
        if not sc_flag:
            # Only modelS contributes to the loss in this phase.
            # NOTE(review): att_feats is not passed to modelS here, unlike
            # the sample() calls below -- confirm against Att2inModel.forward.
            loss = crit(modelS(fc_feats, labels), labels[:, 1:], masks[:, 1:])
            loss.backward()
        else:
            gen_result_S, sample_logprobs_S = modelS.sample(
                fc_feats, att_feats, {'sample_max': 0})
            reward_S = get_self_critical_reward_forTS(modelT, modelS, fc_feats,
                                                      att_feats, data,
                                                      gen_result_S, logger)

            gen_result_T, sample_logprobs_T = modelT.sample(
                fc_feats, att_feats, {'sample_max': 0})
            reward_T = get_self_critical_reward_forTS(modelS, modelT, fc_feats,
                                                      att_feats, data,
                                                      gen_result_T, logger)

            loss_S = rl_crit(
                sample_logprobs_S, gen_result_S,
                Variable(torch.from_numpy(reward_S).float().cuda(),
                         requires_grad=False))
            loss_T = rl_crit(
                sample_logprobs_T, gen_result_T,
                Variable(torch.from_numpy(reward_T).float().cuda(),
                         requires_grad=False))

            loss_S.backward()
            loss_T.backward()

            # The sum is only used for reporting; gradients were already
            # produced by the two separate backward() calls above.
            loss = loss_S + loss_T

        utils.clip_gradient(optimizer_S, opt.grad_clip)
        utils.clip_gradient(optimizer_T, opt.grad_clip)
        optimizer_S.step()
        optimizer_T.step()
        train_loss = loss.data[0]

        torch.cuda.synchronize()
        end = time.time()

        if not sc_flag:
            log = "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start)
            logger.write(log)
        else:
            log = "iter {} (epoch {}), S_avg_reward = {:.3f}, T_avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration,  epoch, np.mean(reward_S[:,0]), np.mean(reward_T[:,0]), end - start)
            logger.write(log)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  modelS.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward_S',
                                      np.mean(reward_S[:, 0]), iteration)
                    add_summary_value(tf_summary_writer, 'avg_reward_T',
                                      np.mean(reward_T[:, 0]), iteration)
                tf_summary_writer.flush()

            # In the reward phase the "loss" history stores the mean of the
            # summed rewards instead of the training loss.
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward_S[:, 0] + reward_T[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = modelS.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model (only modelS is evaluated)
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))

            val_loss, predictions, lang_stats = eval_utils.eval_split(
                modelS, crit, loader, logger, eval_kwargs)
            logger.write_dict(lang_stats)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k, v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'modelS.pth')
                torch.save(modelS.state_dict(), checkpoint_path)
                print("modelS saved to {}".format(checkpoint_path))
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'modelT.pth')
                # Save modelT's own weights under modelT.pth (matches the
                # '-best' branch below).
                torch.save(modelT.state_dict(), checkpoint_path)
                print("modelT saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'S_optimizer.pth')
                torch.save(optimizer_S.state_dict(), optimizer_path)
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'T_optimizer.pth')
                torch.save(optimizer_T.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'modelS-best.pth')
                    torch.save(modelS.state_dict(), checkpoint_path)
                    print("modelS saved to {}".format(checkpoint_path))
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'modelT-best.pth')
                    torch.save(modelT.state_dict(), checkpoint_path)
                    print("modelT saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
def train(opt):
    """Train a ``CascadeNetModel`` that wraps two pretrained captioners.

    ``model1`` is a ``ShowTellModel`` loaded from
    ``opt.start_from_T/model.pth`` and ``model2`` is an ``Att2inModel`` loaded
    from ``opt.start_from_S/model.pth``. Both are handed to
    ``CascadeNetModel``, whose parameters are optimised with a single Adam
    optimizer.

    Each iteration runs one of two phases, selected by the ``sc_flag``
    returned by ``update_lr``:

    * ``sc_flag`` false: the model returns two outputs and the loss is the
      sum of ``LanguageModelCriterion`` on each.
    * ``sc_flag`` true: the model is sampled with ``mode='sc'`` and only the
      second output is trained, with ``RewardCriterion`` and a reward from
      ``get_self_critical_reward_forCascade``.

    Every ``opt.save_checkpoint_every`` iterations the model is evaluated on
    the ``val`` split and the model, optimizer, ``infos`` and ``histories``
    are written to ``opt.checkpoint_path`` (plus ``-best`` copies when the
    score improves).

    Args:
        opt: options namespace. It is mutated in place (``use_att``,
            ``vocab_size`` and ``seq_length`` are set here) and is pickled
            into the ``infos`` file.

    Returns:
        None. The loop ends once ``epoch >= opt.max_epochs`` unless
        ``opt.max_epochs`` is ``-1``.
    """
    # Attention features are forced on here instead of being derived from
    # opt.caption_model via utils.if_use_att.
    opt.use_att = True
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Falsy when ``tf`` is unavailable; every summary write below is guarded
    # by ``tf is not None``.
    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    # ``infos`` is never loaded from disk in this variant, so the
    # ``infos.get`` calls below always fall back to their defaults.
    infos = {}
    histories = {}
    if opt.start_from_S is not None:
        histories_path = os.path.join(opt.start_from_S,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            # NOTE: unpickling can execute arbitrary code -- only load from
            # directories you trust. The file is written with 'wb' below, so
            # read it back as binary.
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always bind best_val_score: the checkpoint code below reads it even
    # when the saved best score is not requested.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Build the two sub-models from their pretrained weights.
    model1 = ShowTellModel(opt)
    model2 = Att2inModel(opt)
    model1.load_state_dict(torch.load(os.path.join(opt.start_from_T, 'model.pth')))
    model2.load_state_dict(torch.load(os.path.join(opt.start_from_S, 'model.pth')))
    model1.cuda()
    model2.cuda()

    model = CascadeNetModel(opt, model1, model2)
    model.cuda()
    logger = Logger(opt)

    update_lr_flag = True
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    while True:
        if update_lr_flag:
            opt, sc_flag, update_lr_flag, model, optimizer = update_lr(opt, epoch, model, optimizer)

        # Load data from train split (0)
        data = loader.get_batch('train', seq_per_img=opt.seq_per_img)

        # Synchronize so the wall-clock timing covers finished GPU work.
        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()

        if not sc_flag:
            out1, out2 = model(fc_feats, att_feats, labels)
            loss_1, loss_2 = crit(out1, labels[:,1:], masks[:,1:]), crit(out2, labels[:,1:], masks[:,1:])
            loss = loss_1 + loss_2
            loss.backward()
        else:
            gen_result_1, sample_logprobs_1, gen_result_2, sample_logprobs_2 = model.sample(fc_feats, att_feats, {'sample_max': 0}, mode='sc')
            reward_2 = get_self_critical_reward_forCascade(model, fc_feats, att_feats, data, gen_result_1, gen_result_2, logger)

            loss_2 = rl_crit(sample_logprobs_2, gen_result_2, Variable(torch.from_numpy(reward_2).float().cuda(), requires_grad=False))

            # Only the second output is optimised in this phase; no loss is
            # computed for the first output.
            loss = loss_2
            loss.backward()

        utils.clip_gradient_2(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()

        if not sc_flag:
            log = "iter {} (epoch {}), loss_1 = {:.3f}, loss_2 = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, loss_1.data[0], loss_2.data[0], end - start)
            logger.write(log)
        else:
            log = "iter {} (epoch {}), loss_2(sc) = {:.3f}, avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration,  epoch, loss_2.data[0], np.mean(reward_2[:,0]), end - start)
            logger.write(log)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward', np.mean(reward_2[:,0]), iteration)
                tf_summary_writer.flush()

            # In the reward phase the "loss" history stores the mean reward
            # instead of the training loss.
            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward_2[:,0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val','dataset': opt.input_json}
            eval_kwargs.update(vars(opt))

            val_loss, predictions, lang_stats = eval_utils_for_FNet.eval_split(model, crit, loader, logger, eval_kwargs)
            logger.write_dict(lang_stats)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                for k,v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True: # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #22
0
def train(opt):
    """Run a cross-entropy training loop for a fixed number of epochs.

    The model comes from ``models.setup(opt)`` and is trained with
    ``LanguageModelCriterion`` and Adam. The learning rate and the scheduled
    sampling probability are recomputed from ``opt`` at the start and every
    time the data loader reports that it wrapped around (one epoch).

    Unlike the sibling ``train`` variants, this one does no validation and
    writes no checkpoints; it only logs to stdout and, when ``tf`` is
    available, to a summary writer in ``opt.checkpoint_path``.

    Args:
        opt: options namespace. It is mutated in place (``use_att``,
            ``vocab_size``, ``seq_length``, ``current_lr`` and possibly
            ``ss_prob`` are set here).

    Returns:
        None. The loop ends once ``epoch`` reaches 8 (hard-coded;
        ``opt.max_epochs`` is not consulted).
    """
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Falsy when ``tf`` is unavailable; the summary writes below are guarded
    # by ``tf is not None``.
    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # NOTE: unpickling can execute arbitrary code -- only resume from
        # checkpoint directories you trust.
        # Checkpoint pickles are binary, so open them in 'rb' mode.
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
        # Check that the saved model is compatible with the requested one.
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer state when one was saved next to the model; a
    # missing file just means starting with a fresh optimizer.
    if vars(opt).get('start_from', None) is not None:
        optimizer_path = os.path.join(opt.start_from, 'optimizer.pth')
        if os.path.isfile(optimizer_path):
            optimizer.load_state_dict(torch.load(optimizer_path))

    while True:
        if update_lr_flag:
            # Assign the learning rate: step decay once past the start epoch;
            # a negative learning_rate_decay_start disables decay.
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob: grows in steps and is
            # capped at scheduled_sampling_max_prob.
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        # Synchronize so the wall-clock timing covers finished GPU work.
        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:],
                    masks[:, 1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Stop after a fixed number of epochs (hard-coded in this variant).
        if epoch >= 8:
            break
예제 #23
0
def train(opt):
    """Train a caption model on raw images, optionally fine-tuning the CNN.

    A CNN from ``utils.build_cnn(opt)`` turns each image batch into attention
    features; the caption model from ``models.setup(opt)`` is trained on them
    with ``LanguageModelCriterion``. The CNN is frozen and kept in eval mode
    while ``opt.finetune_cnn_after`` is ``-1`` or the current epoch is below
    it; afterwards its parameters are trained by a separate Adam optimizer.

    Every ``opt.save_checkpoint_every`` iterations the models are evaluated
    on the ``val`` split, and both models, both optimizers and ``infos`` are
    written to ``opt.checkpoint_path`` (plus ``-best`` copies when the score
    improves). In this variant the history dicts are stored inside ``infos``
    rather than in a separate histories file.

    Args:
        opt: options namespace. It is mutated in place (``vocab_size``,
            ``seq_length``, ``current_lr`` and possibly ``ss_prob`` are set
            here) and is pickled into the ``infos`` file.

    Returns:
        None. The loop ends once ``epoch >= opt.max_epochs`` unless
        ``opt.max_epochs`` is ``-1``.
    """
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    infos = {}
    if opt.start_from is not None:
        # NOTE: unpickling can execute arbitrary code -- only resume from
        # checkpoint directories you trust.
        # The pickle is written with 'wb' below, so read it back as binary.
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
        # Check that the saved model is compatible with the requested one.
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})
    loss_history = infos.get('loss_history', {})
    lr_history = infos.get('lr_history', {})
    ss_prob_history = infos.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    # Always bind best_val_score: the checkpoint code below reads it even
    # when the saved best score is not requested.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    cnn_model = utils.build_cnn(opt)
    cnn_model.cuda()
    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
    cnn_optimizer = optim.Adam(cnn_model.parameters(),
                               lr=opt.cnn_learning_rate,
                               weight_decay=opt.cnn_weight_decay)

    # Load the optimizers, each only if its state file exists.
    if vars(opt).get('start_from', None) is not None:
        if os.path.isfile(os.path.join(opt.start_from, 'optimizer.pth')):
            optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
        if os.path.isfile(os.path.join(opt.start_from, 'optimizer-cnn.pth')):
            cnn_optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer-cnn.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate: step decay once past the start epoch;
            # a negative learning_rate_decay_start disables decay.
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob: grows in steps and is
            # capped at scheduled_sampling_max_prob.
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            # Update the training stage of cnn: frozen and in eval mode until
            # the fine-tuning epoch is reached, trainable afterwards.
            if opt.finetune_cnn_after == -1 or epoch < opt.finetune_cnn_after:
                for p in cnn_model.parameters():
                    p.requires_grad = False
                cnn_model.eval()
            else:
                for p in cnn_model.parameters():
                    p.requires_grad = True
                cnn_model.train()
            update_lr_flag = False

        torch.cuda.synchronize()
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        torch.cuda.synchronize()
        print('Read data:', time.time() - start)

        # Synchronize so the wall-clock timing covers finished GPU work.
        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['images'], data['labels'], data['masks']]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        images, labels, masks = tmp

        att_feats = cnn_model(images)
        # Average over dims 2 and 3 to get one feature vector per image.
        # NOTE(review): this chain only works if mean() keeps the reduced
        # dimension (older PyTorch behaviour) -- confirm the torch version.
        fc_feats = att_feats.mean(2).mean(3).squeeze(2).squeeze(2)

        # Repeat every image's features opt.seq_per_img times along the batch
        # axis; presumably this lines them up with seq_per_img captions per
        # image in ``labels`` -- verify against DataLoader.
        att_feats = att_feats.unsqueeze(1).expand(*((
            att_feats.size(0),
            opt.seq_per_img,
        ) + att_feats.size()[1:])).contiguous().view(
            *((att_feats.size(0) * opt.seq_per_img, ) + att_feats.size()[1:]))
        fc_feats = fc_feats.unsqueeze(1).expand(*((
            fc_feats.size(0),
            opt.seq_per_img,
        ) + fc_feats.size()[1:])).contiguous().view(
            *((fc_feats.size(0) * opt.seq_per_img, ) + fc_feats.size()[1:]))

        optimizer.zero_grad()
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            cnn_optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:],
                    masks[:, 1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # The CNN is only stepped once fine-tuning has been enabled.
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            utils.clip_gradient(cnn_optimizer, opt.grad_clip)
            cnn_optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                cnn_model, model, crit, loader, eval_kwargs)

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-cnn.pth')
                torch.save(model.state_dict(), checkpoint_path)
                torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                print("cnn model saved to {}".format(cnn_checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                cnn_optimizer_path = os.path.join(opt.checkpoint_path,
                                                  'optimizer-cnn.pth')
                torch.save(optimizer.state_dict(), optimizer_path)
                torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['val_result_history'] = val_result_history
                infos['loss_history'] = loss_history
                infos['lr_history'] = lr_history
                infos['ss_prob_history'] = ss_prob_history
                infos['vocab'] = loader.get_vocab()
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                       'model-cnn-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    print("cnn model saved to {}".format(cnn_checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #24
0
def train(opt):
    """Train the captioning model with an auxiliary saliency loss on attention.

    The loop runs until ``opt.max_epochs`` epochs have been completed (forever
    when ``opt.max_epochs == -1``). Each iteration draws a batch from the
    ``'train'`` split, adds ``opt.salicy_alpha * salicy_loss`` to the language
    model loss, and takes one Adam step. Every ``opt.losses_log_every``
    iterations the loss, learning rate and scheduled-sampling probability are
    recorded. Every ``opt.save_checkpoint_every`` iterations the model is
    evaluated with ``eval_utils.eval_split`` and the weights, optimizer state,
    ``infos`` and ``histories`` are written under ``opt.checkpoint_path``; an
    extra "best" checkpoint is written for every (split, metric) pair whose
    score improved.

    Args:
        opt: option namespace. It is mutated in place (``use_att``,
            ``vocab_size``, ``seq_length``, ``inc_seg``, ``seg_ix``,
            ``current_lr`` and, once scheduled sampling starts, ``ss_prob``).

    Raises:
        AssertionError: if the options stored in a resumed checkpoint disagree
            with ``opt`` on the model architecture.
        ValueError: if ``opt.salicy_hard`` / ``opt.salicy_loss_type`` do not
            select one of the implemented saliency losses.
    """
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.maxlen_sen
    opt.inc_seg = loader.inc_seg
    opt.seg_ix = loader.seg_ix
    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    score_list = []
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        # Pickle streams are bytes: the files are written with 'wb' below, so
        # they must be read with 'rb' (text mode fails on Python 3 / Windows).
        # NOTE(review): unpickling executes arbitrary code; only resume from
        # checkpoint directories you trust.
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # One running best per (split, metric); 0.0 means "nothing seen yet".
    score_splits = ['val', 'test']
    score_type = ['Bleu_4', 'METEOR', 'CIDEr']
    best_val_score = {
        split_i: {score_item: 0.0
                  for score_item in score_type}
        for split_i in score_splits
    }
    if opt.load_best_score == 1:
        saved_best = infos.get('best_val_score')
        # Keep the zero-initialised table when the checkpoint stored nothing.
        if saved_best is not None:
            best_val_score = saved_best

    model = models.setup(opt)
    # NOTE(review): the GPU ids are hard-coded; two visible devices are needed.
    device_ids = [0, 1]

    torch.cuda.set_device(device_ids[0])
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    update_lr_flag = True
    # Assure in training mode
    model.module.train()
    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.module.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.module.ss_prob = opt.ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['labels'], data['x_phrase_mask_0'],
            data['x_phrase_mask_1'], data['label_masks'], data['salicy_seg'],
            data['seg_mask']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, seq, phrase_mask_0, phrase_mask_1, masks, salicy_seg, seg_mask = tmp

        optimizer.zero_grad()
        # The first `remove_len` positions of seq/masks are excluded from the
        # language-model loss; the same value is handed to the model and eval.
        remove_len = 2
        outputs, alphas = model.module(fc_feats, seq, phrase_mask_0,
                                       phrase_mask_1, masks, seg_mask,
                                       remove_len)
        loss = crit(outputs, seq[remove_len:, :].permute(1, 0),
                    masks[remove_len:, :].permute(1, 0))
        alphas = alphas.permute(1, 0, 2)
        salicy_seg = salicy_seg[:, :, :]
        seg_mask = seg_mask[:, :]
        # `== False` / `== True` are kept on purpose so that 0/1 integer flags
        # behave exactly as before.
        if opt.salicy_hard == False:
            if opt.salicy_loss_type == 'l2':
                salicy_loss = (((((salicy_seg * seg_mask[:, :, None] -
                                   alphas * seg_mask[:, :, None])**2).sum(0)
                                 ).sum(-1))**(0.5)).mean()
            elif opt.salicy_loss_type == 'kl':
                #alphas: len_sen, batch_size, num_frame
                salicy_loss = kullback_leibler2(
                    alphas * seg_mask[:, :, None],
                    salicy_seg * seg_mask[:, :, None])
                salicy_loss = (((salicy_loss *
                                 seg_mask[:, :, None]).sum(-1)).sum(0)).mean()
            else:
                # Previously this fell through and crashed later with an
                # UnboundLocalError on `salicy_loss`.
                raise ValueError("unsupported salicy_loss_type: %r" %
                                 (opt.salicy_loss_type, ))
        elif opt.salicy_hard == True:
            #salicy len_sen, batch_size, num_frame
            salicy_loss = -torch.log((alphas * salicy_seg).sum(-1) + 1e-8)
            #salicy_loss len_sen, batch_size
            salicy_loss = ((salicy_loss * seg_mask).sum(0)).mean()
        else:
            raise ValueError("unsupported salicy_hard: %r" %
                             (opt.salicy_hard, ))
        loss = loss + opt.salicy_alpha * salicy_loss
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.module.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.module.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.dataset,
                'remove_len': remove_len
            }
            # Options of the same name in `opt` override the defaults above.
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats, score_list_i = eval_utils.eval_split(
                model.module, crit, loader, eval_kwargs)
            score_list.append(score_list_i)
            # NOTE(review): hard-coded output path, independent of
            # opt.checkpoint_path; the ./save directory must already exist.
            np.savetxt('./save/train_valid_test.txt', score_list, fmt='%.3f')
            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k, split_stats in lang_stats.items():
                    for v, value in split_stats.items():
                        add_summary_value(tf_summary_writer, k + v, value,
                                          iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result: one flag per
            # (split, metric), raised when this evaluation beats the best.
            best_flag = {
                split_i: {score_item: False
                          for score_item in score_type}
                for split_i in score_splits
            }
            for split_i in score_splits:
                for score_item in score_type:
                    score = lang_stats[split_i][score_item]
                    if score > best_val_score[split_i][score_item]:
                        best_val_score[split_i][score_item] = score
                        best_flag[split_i][score_item] = True

            # The latest model and optimizer state are always saved.
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.module.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path,
                                          'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            for split_i in score_splits:
                for score_item in score_type:
                    if best_flag[split_i][score_item]:
                        checkpoint_path = os.path.join(
                            opt.checkpoint_path, 'model-best_' + split_i +
                            '_' + score_item + '.pth')
                        torch.save(model.module.state_dict(),
                                   checkpoint_path)
                        print("model saved to {}".format(checkpoint_path))
                        with open(
                                os.path.join(
                                    opt.checkpoint_path,
                                    'infos_' + split_i + '_' + score_item +
                                    '_' + opt.id + '-best.pkl'),
                                'wb') as f:
                            cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #25
0
def train(opt):
    """Train the caption generator with cross-entropy or distance rewards.

    ``update_lr`` decides at every epoch boundary whether the loop is in the
    cross-entropy phase (``sc_flag`` false) or the REINFORCE phase (``sc_flag``
    true). In the REINFORCE phase the reward is the sum of two
    ``get_distance_reward`` calls (matched and mismatched pairs) computed with
    the pre-trained ``Distance`` model. Every ``opt.losses_log_every``
    iterations the training statistics are recorded; every
    ``opt.save_checkpoint_every`` iterations the model is evaluated with
    ``eval_utils.eval_split`` and the weights, optimizer state, ``infos`` and
    ``histories`` are written under ``opt.checkpoint_path`` (plus a
    ``model-best.pth`` copy whenever the validation score improves). The loop
    stops once ``opt.max_epochs`` is reached (never when it is ``-1``).

    Args:
        opt: option namespace. It is mutated in place (``use_att``,
            ``vocab_size``, ``seq_length``) and is also passed through and
            rebound by ``update_lr``.

    Raises:
        AssertionError: if the options stored in a resumed checkpoint disagree
            with ``opt`` on the model architecture.
    """
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        # Pickle streams are bytes: the files are written with 'wb' below, so
        # they must be read with 'rb' (text mode fails on Python 3 / Windows).
        # NOTE(review): unpickling executes arbitrary code; only resume from
        # checkpoint directories you trust.
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always bind best_val_score: it used to be assigned only when
    # load_best_score == 1, so the first checkpoint raised NameError otherwise.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    #model_D = Discriminator(opt)
    #model_D.load_state_dict(torch.load('save/model_D.pth'))
    #model_D.cuda()
    #criterion_D = nn.CrossEntropyLoss(size_average=True)

    model_E = Distance(opt)
    # NOTE(review): the hard-coded weight path ends in what looks like stray
    # characters ('.pthsfdasdfadf'); left untouched because the real file name
    # cannot be confirmed from here -- please verify.
    model_E.load_state_dict(
        torch.load('save/model_E_NCE/model_E_10epoch.pthsfdasdfadf'))
    model_E.cuda()
    criterion_E = nn.CosineEmbeddingLoss(margin=0, size_average=True)
    #criterion_E = nn.CosineSimilarity()

    logger = Logger(opt)

    update_lr_flag = True
    # Assure in training mode
    model.train()
    #model_D.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer_G = optim.Adam(model.parameters(),
                             lr=opt.learning_rate,
                             weight_decay=opt.weight_decay)
    #optimizer_D = optim.Adam(model_D.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer_G.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # update_lr also returns sc_flag, which selects the loss below.
            opt, sc_flag, update_lr_flag, model, optimizer_G = update_lr(
                opt, epoch, model, optimizer_G)

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        #print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        #tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [data['fc_feats'], data['labels'], data['masks']]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        #fc_feats, att_feats, labels, masks = tmp
        fc_feats, labels, masks = tmp

        ############################################################################################################
        ############################################ REINFORCE TRAINING ############################################
        ############################################################################################################
        if 1:  #iteration % opt.D_scheduling != 0:
            optimizer_G.zero_grad()
            if not sc_flag:
                loss = crit(model(fc_feats, labels), labels[:, 1:],
                            masks[:, 1:])
            else:
                gen_result, sample_logprobs = model.sample(
                    fc_feats, {'sample_max': 0})
                #reward = get_self_critical_reward(model, fc_feats, att_feats, data, gen_result)
                # NOTE(review): sc_reward is computed but not added to the
                # reward below; the call is kept in case it has side effects
                # (it receives the logger).
                sc_reward = get_self_critical_reward(model, fc_feats, data,
                                                     gen_result, logger)
                #gan_reward = get_gan_reward(model, model_D, criterion_D, fc_feats, data, logger)                 # Criterion_D = nn.XEloss()
                distance_loss_reward1 = get_distance_reward(
                    model,
                    model_E,
                    criterion_E,
                    fc_feats,
                    data,
                    logger,
                    is_mismatched=False)  # criterion_E = nn.CosEmbedLoss()
                distance_loss_reward2 = get_distance_reward(
                    model,
                    model_E,
                    criterion_E,
                    fc_feats,
                    data,
                    logger,
                    is_mismatched=True)  # criterion_E = nn.CosEmbedLoss()
                #cosine_reward = get_distance_reward(model, model_E, criterion_E, fc_feats, data, logger)         # criterion_E = nn.CosSim()
                reward = distance_loss_reward1 + distance_loss_reward2
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(reward).float().cuda(),
                             requires_grad=False))

            # Back-propagate in BOTH phases. backward() used to sit inside the
            # self-critical branch only, so the cross-entropy phase clipped and
            # stepped without ever computing gradients.
            loss.backward()
            utils.clip_gradient(optimizer_G, opt.grad_clip)
            optimizer_G.step()
            train_loss = loss.data[0]
            torch.cuda.synchronize()
            end = time.time()

            if not sc_flag:
                log = "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start)
                logger.write(log)
            else:
                log = "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration,  epoch, np.mean(reward[:,0]), end - start)
                logger.write(log)

        ######################################################################################################
        ############################################ GAN TRAINING ############################################
        ######################################################################################################
        # NOTE(review): this branch is unreachable (`if 1:` above) and refers
        # to model_D / optimizer_D / criterion_D, whose definitions are
        # commented out earlier in this function. It is kept as reference only.
        else:  #elif iteration % opt.D_scheduling == 0: # gan training
            model_D.zero_grad()
            optimizer_D.zero_grad()

            fc_feats_temp = Variable(fc_feats.data.cpu(), volatile=True).cuda()
            labels = Variable(labels.data.cpu()).cuda()

            sample_res, sample_logprobs = model.sample(
                fc_feats_temp, {'sample_max': 0})  #640, 16
            greedy_res, greedy_logprobs = model.sample(
                fc_feats_temp, {'sample_max': 1})  #640, 16
            gt_res = labels  # 640, 18

            sample_res_embed = model.embed(Variable(sample_res))
            greedy_res_embed = model.embed(Variable(greedy_res))
            gt_res_embed = model.embed(gt_res)

            f_label = Variable(
                torch.FloatTensor(data['fc_feats'].shape[0]).cuda())
            r_label = Variable(
                torch.FloatTensor(data['fc_feats'].shape[0]).cuda())
            f_label.data.fill_(0)
            r_label.data.fill_(1)

            f_D_output = model_D(sample_res_embed.detach(), fc_feats.detach())
            f_loss = criterion_D(f_D_output, f_label.long())
            f_loss.backward()

            r_D_output = model_D(gt_res_embed.detach(), fc_feats.detach())
            r_loss = criterion_D(r_D_output, r_label.long())
            r_loss.backward()

            D_loss = f_loss + r_loss
            optimizer_D.step()
            torch.cuda.synchronize()

            log = 'iter {} (epoch {}),  Discriminator loss : {}'.format(
                iteration, epoch,
                D_loss.data.cpu().numpy()[0])
            logger.write(log)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward',
                                      np.mean(reward[:, 0]), iteration)
                tf_summary_writer.flush()

            # In the REINFORCE phase the mean reward is stored in place of the
            # loss value.
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            # Options of the same name in `opt` override the defaults above.
            eval_kwargs.update(vars(opt))

            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, logger, eval_kwargs)
            logger.write_dict(lang_stats)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k, v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # The latest model and optimizer state are always saved.
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path,
                                          'optimizer.pth')
            torch.save(optimizer_G.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #26
0
파일: train.py 프로젝트: cxqj/ECHR
def train(opt):
    exclude_opt = [
        'training_mode', 'tap_epochs', 'cg_epochs', 'tapcg_epochs', 'lr',
        'learning_rate_decay_start', 'learning_rate_decay_every',
        'learning_rate_decay_rate', 'self_critical_after',
        'save_checkpoint_every', 'id', "pretrain", "pretrain_path", "debug",
        "save_all_checkpoint", "min_epoch_when_save"
    ]

    save_folder, logger, tf_writer = build_floder_and_create_logger(opt)
    saved_info = {'best': {}, 'last': {}, 'history': {}}
    is_continue = opt.start_from != None

    if is_continue:
        infos_path = os.path.join(save_folder, 'info.pkl')
        with open(infos_path) as f:
            logger.info('load info from {}'.format(infos_path))
            saved_info = cPickle.load(f)
            pre_opt = saved_info[opt.start_from_mode]['opt']
            if vars(opt).get("no_exclude_opt", False):
                exclude_opt = []
            for opt_name in vars(pre_opt).keys():
                if (not opt_name in exclude_opt):
                    vars(opt).update({opt_name: vars(pre_opt).get(opt_name)})
                if vars(pre_opt).get(opt_name) != vars(opt).get(opt_name):
                    print('change opt: {} from {} to {}'.format(
                        opt_name,
                        vars(pre_opt).get(opt_name),
                        vars(opt).get(opt_name)))

    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.CG_vocab_size = loader.vocab_size
    opt.CG_seq_length = loader.seq_length

    # init training option
    epoch = saved_info[opt.start_from_mode].get('epoch', 0)
    iteration = saved_info[opt.start_from_mode].get('iter', 0)
    best_val_score = saved_info[opt.start_from_mode].get('best_val_score', 0)
    val_result_history = saved_info['history'].get('val_result_history', {})
    loss_history = saved_info['history'].get('loss_history', {})
    lr_history = saved_info['history'].get('lr_history', {})
    loader.iterators = saved_info[opt.start_from_mode].get(
        'iterators', loader.iterators)
    loader.split_ix = saved_info[opt.start_from_mode].get(
        'split_ix', loader.split_ix)
    opt.current_lr = vars(opt).get('current_lr', opt.lr)
    opt.m_batch = vars(opt).get('m_batch', 1)

    # create a tap_model,fusion_model,cg_model

    tap_model = models.setup_tap(opt)
    lm_model = CaptionGenerator(opt)
    cg_model = lm_model

    if is_continue:
        if opt.start_from_mode == 'best':
            model_pth = torch.load(os.path.join(save_folder, 'model-best.pth'))
        elif opt.start_from_mode == 'last':
            model_pth = torch.load(
                os.path.join(save_folder,
                             'model_iter_{}.pth'.format(iteration)))
        assert model_pth['iteration'] == iteration
        logger.info('Loading pth from {}, iteration:{}'.format(
            save_folder, iteration))
        tap_model.load_state_dict(model_pth['tap_model'])
        cg_model.load_state_dict(model_pth['cg_model'])

    elif opt.pretrain:
        print('pretrain {} from {}'.format(opt.pretrain, opt.pretrain_path))
        model_pth = torch.load(opt.pretrain_path)
        if opt.pretrain == 'tap':
            tap_model.load_state_dict(model_pth['tap_model'])
        elif opt.pretrain == 'cg':
            cg_model.load_state_dict(model_pth['cg_model'])
        elif opt.pretrain == 'tap_cg':
            tap_model.load_state_dict(model_pth['tap_model'])
            cg_model.load_state_dict(model_pth['cg_model'])
        else:
            assert 1 == 0, 'opt.pretrain error'

    tap_model.cuda()
    tap_model.train()  # Assure in training mode

    tap_crit = utils.TAPModelCriterion()

    tap_optimizer = optim.Adam(tap_model.parameters(),
                               lr=opt.lr,
                               weight_decay=opt.weight_decay)

    cg_model.cuda()
    cg_model.train()
    cg_optimizer = optim.Adam(cg_model.parameters(),
                              lr=opt.lr,
                              weight_decay=opt.weight_decay)
    cg_crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    cg_optimizer = optim.Adam(cg_model.parameters(),
                              lr=opt.lr,
                              weight_decay=opt.weight_decay)

    allmodels = [tap_model, cg_model]
    optimizers = [tap_optimizer, cg_optimizer]

    if is_continue:
        tap_optimizer.load_state_dict(model_pth['tap_optimizer'])
        cg_optimizer.load_state_dict(model_pth['cg_optimizer'])

    update_lr_flag = True
    loss_sum = np.zeros(5)
    bad_video_num = 0
    best_epoch = epoch
    start = time.time()

    print_opt(opt, allmodels, logger)
    logger.info('\nStart training')

    # set a var to indicate what to train in current iteration: "tap", "cg" or "tap_cg"
    flag_training_whats = get_training_list(opt, logger)

    # Iteration begin
    while True:
        if update_lr_flag:
            if (epoch > opt.learning_rate_decay_start
                    and opt.learning_rate_decay_start >= 0):
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.lr * decay_factor
            else:
                opt.current_lr = opt.lr
            for optimizer in optimizers:
                utils.set_lr(optimizer, opt.current_lr)
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(None)
            else:
                sc_flag = False
            update_lr_flag = False

        flag_training_what = flag_training_whats[epoch]
        if opt.training_mode == "alter2":
            flag_training_what = flag_training_whats[iteration]

        # get data
        data = loader.get_batch('train')

        if opt.debug:
            print('vid:', data['vid'])
            print('info:', data['infos'])

        torch.cuda.synchronize()

        if (data["proposal_num"] <= 0) or (data['fc_feats'].shape[0] <= 1):
            bad_video_num += 1  # print('vid:{} has no good proposal.'.format(data['vid']))
            continue

        ind_select_list, soi_select_list, cg_select_list, sampled_ids, = data[
            'ind_select_list'], data['soi_select_list'], data[
                'cg_select_list'], data['sampled_ids']

        if flag_training_what == 'cg' or flag_training_what == 'gt_tap_cg':
            ind_select_list = data['gts_ind_select_list']
            soi_select_list = data['gts_soi_select_list']
            cg_select_list = data['gts_cg_select_list']

        tmp = [
            data['fc_feats'], data['att_feats'], data['lda_feats'],
            data['tap_labels'], data['tap_masks_for_loss'],
            data['cg_labels'][cg_select_list],
            data['cg_masks'][cg_select_list], data['w1']
        ]

        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]

        c3d_feats, att_feats, lda_feats, tap_labels, tap_masks_for_loss, cg_labels, cg_masks, w1 = tmp

        if (iteration - 1) % opt.m_batch == 0:
            tap_optimizer.zero_grad()
            cg_optimizer.zero_grad()

        tap_feats, pred_proposals = tap_model(c3d_feats)
        tap_loss = tap_crit(pred_proposals, tap_masks_for_loss, tap_labels, w1)

        loss_sum[0] = loss_sum[0] + tap_loss.item()

        # Backward Propagation
        if flag_training_what == 'tap':
            tap_loss.backward()
            utils.clip_gradient(tap_optimizer, opt.grad_clip)
            if iteration % opt.m_batch == 0:
                tap_optimizer.step()
        else:
            if not sc_flag:
                pred_captions = cg_model(tap_feats,
                                         c3d_feats,
                                         lda_feats,
                                         cg_labels,
                                         ind_select_list,
                                         soi_select_list,
                                         mode='train')
                cg_loss = cg_crit(pred_captions, cg_labels[:, 1:],
                                  cg_masks[:, 1:])

            else:
                gen_result, sample_logprobs, greedy_res = cg_model(
                    tap_feats,
                    c3d_feats,
                    lda_feats,
                    cg_labels,
                    ind_select_list,
                    soi_select_list,
                    mode='train_rl')
                sentence_info = data['sentences_batch'] if (
                    flag_training_what != 'cg'
                    and flag_training_what != 'gt_tap_cg'
                ) else data['gts_sentences_batch']

                reward = get_self_critical_reward2(
                    greedy_res, (data['vid'], sentence_info),
                    gen_result,
                    vocab=loader.get_vocab(),
                    opt=opt)
                cg_loss = rl_crit(sample_logprobs, gen_result,
                                  torch.from_numpy(reward).float().cuda())

            loss_sum[1] = loss_sum[1] + cg_loss.item()

            if flag_training_what == 'cg' or flag_training_what == 'gt_tap_cg' or flag_training_what == 'LP_cg':
                cg_loss.backward()

                utils.clip_gradient(cg_optimizer, opt.grad_clip)
                if iteration % opt.m_batch == 0:
                    cg_optimizer.step()
                if flag_training_what == 'gt_tap_cg':
                    utils.clip_gradient(tap_optimizer, opt.grad_clip)
                    if iteration % opt.m_batch == 0:
                        tap_optimizer.step()
            elif flag_training_what == 'tap_cg':
                total_loss = opt.lambda1 * tap_loss + opt.lambda2 * cg_loss
                total_loss.backward()
                utils.clip_gradient(tap_optimizer, opt.grad_clip)
                utils.clip_gradient(cg_optimizer, opt.grad_clip)
                if iteration % opt.m_batch == 0:
                    tap_optimizer.step()
                    cg_optimizer.step()

                loss_sum[2] = loss_sum[2] + total_loss.item()

        torch.cuda.synchronize()

        # Updating epoch num
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Print losses, Add to summary
        if iteration % opt.losses_log_every == 0:
            end = time.time()
            losses = np.round(loss_sum / opt.losses_log_every, 3)
            logger.info(
                "iter {} (epoch {}, lr {}), avg_iter_loss({}) = {}, time/batch = {:.3f}, bad_vid = {:.3f}" \
                    .format(iteration, epoch, opt.current_lr, flag_training_what, losses,
                            (end - start) / opt.losses_log_every,
                            bad_video_num))

            tf_writer.add_scalar('lr', opt.current_lr, iteration)
            tf_writer.add_scalar('train_tap_loss', losses[0], iteration)
            tf_writer.add_scalar('train_tap_prop_loss', losses[3], iteration)
            tf_writer.add_scalar('train_tap_bound_loss', losses[4], iteration)
            tf_writer.add_scalar('train_cg_loss', losses[1], iteration)
            tf_writer.add_scalar('train_total_loss', losses[2], iteration)
            if sc_flag and (not flag_training_what == 'tap'):
                tf_writer.add_scalar('avg_reward', np.mean(reward[:, 0]),
                                     iteration)
            loss_history[iteration] = losses
            lr_history[iteration] = opt.current_lr
            loss_sum = np.zeros(5)
            start = time.time()
            bad_video_num = 0

        # Evaluation, and save model
        if (iteration % opt.save_checkpoint_every
                == 0) and (epoch >= opt.min_epoch_when_save):
            eval_kwargs = {
                'split': 'val',
                'val_all_metrics': 0,
                'topN': 100,
            }

            eval_kwargs.update(vars(opt))

            # eval_kwargs['num_vids_eval'] = int(491)
            eval_kwargs['topN'] = 100

            eval_kwargs2 = {
                'split': 'val',
                'val_all_metrics': 1,
                'num_vids_eval': 4917,
            }
            eval_kwargs2.update(vars(opt))

            if not opt.num_vids_eval:
                eval_kwargs['num_vids_eval'] = int(4917.)
                eval_kwargs2['num_vids_eval'] = 4917

            crits = [tap_crit, cg_crit]
            pred_json_path_T = os.path.join(save_folder, 'pred_sent',
                                            'pred_num{}_iter{}.json')

            # if 'alter' in opt.training_mode:
            if flag_training_what == 'tap':
                eval_kwargs['topN'] = 1000
                predictions, eval_score, val_loss = eval_utils.eval_split(
                    allmodels,
                    crits,
                    loader,
                    pred_json_path_T.format(eval_kwargs['num_vids_eval'],
                                            iteration),
                    eval_kwargs,
                    flag_eval_what='tap')
            else:
                if vars(opt).get('fast_eval_cg', False) == False:
                    predictions, eval_score, val_loss = eval_utils.eval_split(
                        allmodels,
                        crits,
                        loader,
                        pred_json_path_T.format(eval_kwargs['num_vids_eval'],
                                                iteration),
                        eval_kwargs,
                        flag_eval_what='tap_cg')

                predictions2, eval_score2, val_loss2 = eval_utils.eval_split(
                    allmodels,
                    crits,
                    loader,
                    pred_json_path_T.format(eval_kwargs2['num_vids_eval'],
                                            iteration),
                    eval_kwargs2,
                    flag_eval_what='cg')

                if (not vars(opt).get('fast_eval_cg', False)
                        == False) or (not vars(opt).get(
                            'fast_eval_cg_top10', False) == False):
                    eval_score = eval_score2
                    val_loss = val_loss2
                    predictions = predictions2

            # else:
            #    predictions, eval_score, val_loss = eval_utils.eval_split(allmodels, crits, loader, pred_json_path,
            #                                                              eval_kwargs,
            #                                                              flag_eval_what=flag_training_what)

            f_f1 = lambda x, y: 2 * x * y / (x + y)
            f1 = f_f1(eval_score['Recall'], eval_score['Precision']).mean()
            if flag_training_what != 'tap':  # if only train tap, use the mean of precision and recall as final score
                current_score = np.array(eval_score['METEOR']).mean() * 100
            else:  # if train tap_cg, use avg_meteor as final score
                current_score = f1

            for model in allmodels:
                for name, param in model.named_parameters():
                    tf_writer.add_histogram(name,
                                            param.clone().cpu().data.numpy(),
                                            iteration,
                                            bins=10)
                    if param.grad is not None:
                        tf_writer.add_histogram(
                            name + '_grad',
                            param.grad.clone().cpu().data.numpy(),
                            iteration,
                            bins=10)

            tf_writer.add_scalar('val_tap_loss', val_loss[0], iteration)
            tf_writer.add_scalar('val_cg_loss', val_loss[1], iteration)
            tf_writer.add_scalar('val_tap_prop_loss', val_loss[3], iteration)
            tf_writer.add_scalar('val_tap_bound_loss', val_loss[4], iteration)
            tf_writer.add_scalar('val_total_loss', val_loss[2], iteration)
            tf_writer.add_scalar('val_score', current_score, iteration)
            if flag_training_what != 'tap':
                tf_writer.add_scalar('val_score_gt_METEOR',
                                     np.array(eval_score2['METEOR']).mean(),
                                     iteration)
                tf_writer.add_scalar('val_score_gt_Bleu_4',
                                     np.array(eval_score2['Bleu_4']).mean(),
                                     iteration)
                tf_writer.add_scalar('val_score_gt_CIDEr',
                                     np.array(eval_score2['CIDEr']).mean(),
                                     iteration)
            tf_writer.add_scalar('val_recall', eval_score['Recall'].mean(),
                                 iteration)
            tf_writer.add_scalar('val_precision',
                                 eval_score['Precision'].mean(), iteration)
            tf_writer.add_scalar('f1', f1, iteration)

            val_result_history[iteration] = {
                'val_loss': val_loss,
                'eval_score': eval_score
            }

            if flag_training_what == 'tap':
                logger.info(
                    'Validation the result of iter {}, score(f1/meteor):{},\n all:{}'
                    .format(iteration, current_score, eval_score))
            else:
                mean_score = {
                    k: np.array(v).mean()
                    for k, v in eval_score.items()
                }
                gt_mean_score = {
                    k: np.array(v).mean()
                    for k, v in eval_score2.items()
                }

                metrics = ['Bleu_4', 'CIDEr', 'METEOR', 'ROUGE_L']
                gt_avg_score = np.array([
                    v for metric, v in gt_mean_score.items()
                    if metric in metrics
                ]).sum()
                logger.info(
                    'Validation the result of iter {}, score(f1/meteor):{},\n all:{}\n mean:{} \n\n gt:{} \n mean:{}\n avg_score: {}'
                    .format(iteration, current_score, eval_score, mean_score,
                            eval_score2, gt_mean_score, gt_avg_score))

            # Save model .pth
            saved_pth = {
                'iteration': iteration,
                'cg_model': cg_model.state_dict(),
                'tap_model': tap_model.state_dict(),
                'cg_optimizer': cg_optimizer.state_dict(),
                'tap_optimizer': tap_optimizer.state_dict(),
            }

            if opt.save_all_checkpoint:
                checkpoint_path = os.path.join(
                    save_folder, 'model_iter_{}.pth'.format(iteration))
            else:
                checkpoint_path = os.path.join(save_folder, 'model.pth')
            torch.save(saved_pth, checkpoint_path)
            logger.info('Save model at iter {} to checkpoint file {}.'.format(
                iteration, checkpoint_path))

            # save info.pkl
            if current_score > best_val_score:
                best_val_score = current_score
                best_epoch = epoch
                saved_info['best'] = {
                    'opt': opt,
                    'iter': iteration,
                    'epoch': epoch,
                    'iterators': loader.iterators,
                    'flag_training_what': flag_training_what,
                    'split_ix': loader.split_ix,
                    'best_val_score': best_val_score,
                    'vocab': loader.get_vocab(),
                }

                best_checkpoint_path = os.path.join(save_folder,
                                                    'model-best.pth')
                torch.save(saved_pth, best_checkpoint_path)
                logger.info(
                    'Save Best-model at iter {} to checkpoint file.'.format(
                        iteration))

            saved_info['last'] = {
                'opt': opt,
                'iter': iteration,
                'epoch': epoch,
                'iterators': loader.iterators,
                'flag_training_what': flag_training_what,
                'split_ix': loader.split_ix,
                'best_val_score': best_val_score,
                'vocab': loader.get_vocab(),
            }
            saved_info['history'] = {
                'val_result_history': val_result_history,
                'loss_history': loss_history,
                'lr_history': lr_history,
            }
            with open(os.path.join(save_folder, 'info.pkl'), 'w') as f:
                cPickle.dump(saved_info, f)
                logger.info('Save info to info.pkl')

            # Stop criterion
            if epoch >= len(flag_training_whats):
                tf_writer.close()
                break
def train(opt):
    """Train a captioning model, evaluating and checkpointing periodically.

    The loop runs until ``opt.max_epochs`` epochs have completed (forever when
    ``opt.max_epochs == -1``). Every ``opt.save_checkpoint_every`` iterations
    the model is evaluated on the ``val`` split and the following files are
    written under ``opt.checkpoint_path``: ``model.pth``, ``optimizer.pth``,
    ``infos_<id>.pkl`` and ``histories_<id>.pkl``; when the validation score
    improves, ``model-best.pth`` and ``infos_<id>-best.pkl`` are written too.

    Args:
        opt: Option namespace. It is mutated in place (``use_att``,
            ``att_feat_size``, ``vocab_size``, ``seq_length``, ``current_lr``,
            ``ss_prob``) and pickled into the infos file.

    Raises:
        AssertionError: If resuming from ``opt.start_from`` and the saved
            options disagree with ``opt`` on rnn_type, rnn_size or num_layers.
    """
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Open old infos and check if models are compatible.
        # NOTE: unpickling executes arbitrary code; only resume from
        # checkpoint directories you trust.
        # The files are written with 'wb' below, so read them back in binary
        # mode as well (text mode breaks pickle on Python 3 and on Windows).
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["rnn_type", "rnn_size",
                        "num_layers"]  # removed caption_model temporarily
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(
                opt
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always bind best_val_score: it is read at the first checkpoint even when
    # the saved best score is deliberately not restored.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()

    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    optimizer = utils.build_optimizer(model.parameters(), opt)
    if opt.reduce_on_plateau:
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            if not opt.reduce_on_plateau:
                # Assign the learning rate (step decay per epoch)
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch('train')

        torch.cuda.synchronize()
        # Time only the forward/backward/step, not the data loading.
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        # Entries may be None; those are passed through untouched.
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True
            print("Iter: {}, Loss: {}, Epoc: {}, LR: {}".format(
                iteration, train_loss, epoch, opt.current_lr))

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            if opt.reduce_on_plateau:
                opt.current_lr = optimizer.current_lr
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)
            if opt.reduce_on_plateau:
                # lang_stats may be None (see the guard below), in which case
                # the scheduler falls back to the validation loss.
                if lang_stats is not None and 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # The latest weights and optimizer state are always saved.
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path,
                                          'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'),
                    'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
예제 #28
0
def main(opt):
    """Build the dataset, model, loss and optimizer from ``opt`` and train.

    Args:
        opt: Dict of options. ``opt["model"]`` selects the architecture and
            must be ``'S2VTModel'`` or ``'S2VTAttModel'``. ``opt["vocab_size"]``
            is filled in from the dataset as a side effect.

    Raises:
        ValueError: If ``opt["model"]`` names an unsupported architecture.
    """
    model_name = opt["model"]
    # Fail fast: previously an unknown name left `model` unbound and crashed
    # later with an unhelpful UnboundLocalError, after the dataset was loaded.
    if model_name not in ('S2VTModel', 'S2VTAttModel'):
        raise ValueError(
            "Unsupported model '{}'; expected 'S2VTModel' or 'S2VTAttModel'"
            .format(model_name))

    dataset = VideoDataset(opt, 'train')
    dataloader = DataLoader(dataset,
                            batch_size=opt["batch_size"],
                            shuffle=True)
    opt["vocab_size"] = dataset.get_vocab_size()

    if model_name == 'S2VTModel':
        model = S2VTModel(opt["vocab_size"],
                          opt["max_len"],
                          opt["dim_hidden"],
                          opt["dim_word"],
                          opt['dim_vid'],
                          rnn_cell=opt['rnn_type'],
                          n_layers=opt['num_layers'],
                          rnn_dropout_p=opt["rnn_dropout_p"])
    else:  # 'S2VTAttModel'
        encoder = EncoderRNN(opt["dim_vid"],
                             opt["dim_hidden"],
                             bidirectional=opt["bidirectional"],
                             input_dropout_p=opt["input_dropout_p"],
                             rnn_cell=opt['rnn_type'],
                             rnn_dropout_p=opt["rnn_dropout_p"])
        # A separate audio (voice) encoder over opt["dim_voice"] features was
        # prototyped here but is currently disabled.

        # Optional sign-language (hand) encoder; when enabled, the decoder's
        # hidden size grows by the hand encoder's hidden size.
        if opt['with_hand'] == 1:
            encoder_hand = EncoderRNN(opt["dim_hand"],
                                      opt["dim_hand_hidden"],
                                      bidirectional=opt["bidirectional"],
                                      input_dropout_p=opt["input_dropout_p"],
                                      rnn_cell=opt['rnn_type'],
                                      rnn_dropout_p=opt["rnn_dropout_p"])
            decoder_dim_hidden = opt["dim_hidden"] + opt["dim_hand_hidden"]
        else:
            encoder_hand = None
            decoder_dim_hidden = opt["dim_hidden"]

        decoder = DecoderRNN(opt["vocab_size"],
                             opt["max_len"],
                             decoder_dim_hidden,
                             opt["dim_word"],
                             input_dropout_p=opt["input_dropout_p"],
                             rnn_cell=opt['rnn_type'],
                             rnn_dropout_p=opt["rnn_dropout_p"],
                             bidirectional=opt["bidirectional"])
        model = S2VTAttModel(encoder, encoder_hand, decoder)

    model = model.cuda()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    optimizer = optim.Adam(model.parameters(),
                           lr=opt["learning_rate"],
                           weight_decay=opt["weight_decay"])
    exp_lr_scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=opt["learning_rate_decay_every"],
        gamma=opt["learning_rate_decay_rate"])

    train(dataloader, model, crit, optimizer, exp_lr_scheduler, opt, rl_crit)
예제 #29
0
    if k not in ignore:
        if k in vars(opt):
            assert vars(opt)[k] == vars(
                infos['opt'])[k], k + ' option not consistent'
        else:
            vars(opt).update({k: vars(infos['opt'])[k]
                              })  # copy over options from model

# Evaluation setup: take the vocabulary from the loaded infos, restore the
# model weights, and build the data loader used for evaluation.
vocab = infos['vocab']  # ix -> word mapping

# Setup the model: build it from the options, load the weights stored at
# opt.model, move it to the GPU and switch it to evaluation mode.
model = models.setup(opt)
model.load_state_dict(torch.load(opt.model))
model.cuda()
model.eval()
crit = utils.LanguageModelCriterion()

# Create the Data Loader instance.
# An empty opt.image_folder selects the regular DataLoader; otherwise
# DataLoaderRaw is built from the folder and related options.
if len(opt.image_folder) == 0:
    loader = DataLoader(opt)
else:
    loader = DataLoaderRaw({
        'folder_path': opt.image_folder,
        'coco_json': opt.coco_json,
        'batch_size': opt.batch_size,
        'cnn_model': opt.cnn_model
    })
# When evaluating with a provided pretrained model, the vocab may differ from
# the one in your cocotalk.json, so always use the vocab stored in the infos
# file.
loader.ix_to_word = infos['vocab']
def _load_checkpoint_infos(info_path, opt):
    """Load a pickled ``infos`` dict and check it is compatible with ``opt``.

    Args:
        info_path: path of the ``infos_*.pkl`` file to read.
        opt: current options namespace; its ``caption_model``, ``rnn_type``,
            ``rnn_size`` and ``num_layers`` must equal the saved ones.

    Returns:
        The unpickled ``infos`` dict.

    Raises:
        AssertionError: if one of the checked options differs from the
            options stored in the checkpoint.
    """
    # NOTE(security): unpickling can execute arbitrary code; only resume from
    # checkpoints you trust.
    # Pickles must be read in binary mode (text mode fails on Python 3).
    with open(info_path, 'rb') as f:
        infos = cPickle.load(f)
    saved_model_opt = infos['opt']
    need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
    for checkme in need_be_same:
        assert vars(saved_model_opt)[checkme] == vars(
            opt
        )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
    return infos


def _redirect_stdout_to_log_file(checkpoint_path):
    """Point ``sys.stdout`` at ``<checkpoint_path>/log`` (or a suffixed copy).

    If the log file already exists, a timestamp suffix is appended to the file
    name so the previous log is not overwritten. The file is intentionally
    left open: it replaces ``sys.stdout`` for the remainder of the process.
    """
    log_path = os.path.join(checkpoint_path, 'log')
    if os.path.exists(log_path):
        suffix = time.strftime("%Y-%m-%d %X", time.localtime())
        print('Warning !!! %s already exists ! use suffix ! ' % log_path)
        sys.stdout = open(os.path.join(checkpoint_path, 'log' + suffix), "w")
    else:
        print('logging to file %s' % log_path)
        sys.stdout = open(log_path, "w")


def train(opt):
    """Train a paired-caption model, checkpointing and evaluating periodically.

    Seeds all RNGs with 0, builds the paired data loader and the model,
    optionally resumes from ``opt.start_from``, then loops over training
    batches until ``opt.max_epochs`` is reached (``-1`` means no limit).
    Every ``opt.save_checkpoint_every`` iterations the model is evaluated on
    the ``val`` split and the model, optimizer, infos and histories are
    written under ``opt.checkpoint_path``.

    Args:
        opt: options namespace. It is mutated in place (``use_att``,
            ``att_feat_size``, ``vocab_size``, ``seq_length``, ``current_lr``
            and ``ss_prob`` are written) and pickled into the saved infos.

    Raises:
        NotImplementedError: when self-critical training would start
            (``opt.self_critical_after != -1`` and the epoch has been reached).
        AssertionError: when resuming from a checkpoint whose core model
            options differ from ``opt``.
    """
    import random
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(0)
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    from dataloader_pair import DataLoader

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # `tb` is falsy when tensorboard support is unavailable; the writer is
    # then falsy too.
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    if opt.log_to_file:
        _redirect_stdout_to_log_file(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        if os.path.isfile(opt.start_from):
            # start_from names a single file: the infos path is given
            # explicitly through opt.infos.
            info_path = opt.infos
        elif opt.load_best != 0:
            print('loading best info')
            info_path = os.path.join(opt.start_from,
                                     'infos_' + opt.id + '-best.pkl')
        else:
            info_path = os.path.join(opt.start_from,
                                     'infos_' + opt.id + '.pkl')
        infos = _load_checkpoint_infos(info_path, opt)

        history_path = os.path.join(opt.start_from,
                                    'histories_' + opt.id + '.pkl')
        if os.path.isfile(history_path):
            with open(history_path, 'rb') as f:
                try:
                    histories = cPickle.load(f)
                except Exception:
                    # Best effort: a corrupt history must not block resuming,
                    # but Ctrl-C / SystemExit are no longer swallowed.
                    print('load history error!')
                    histories = {}

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always bind best_val_score: it is read at the first checkpoint even when
    # the saved score is deliberately not restored.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    if opt.caption_model == 'att2in2p':
        # Partition parameters by name and print the partition.
        # NOTE(review): the two lists are built but not used later in this
        # function; only the printed listing has an effect.
        optimized = [
            'logit2', 'ctx2att2', 'core2', 'prev_sent_emb', 'prev_sent_wrap'
        ]
        optimized_param = []
        optimized_param1 = []

        for name, param in model.named_parameters():
            second = False
            for n in optimized:
                if n in name:
                    print('second', name)
                    optimized_param.append(param)
                    second = True
            if 'embed' in name:
                print('all', name)
                optimized_param1.append(param)
                optimized_param.append(param)
            elif not second:
                print('first', name)
                optimized_param1.append(param)

    while True:
        if opt.val_only:
            # Evaluation-only mode: run the val split once and terminate.
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            print('start evaluating')
            val_loss, predictions, lang_stats = eval_utils_pair.eval_split(
                dp_model, crit, loader, eval_kwargs)
            sys.exit(0)
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['pair_fc_feats'], data['pair_att_feats'], data['pair_labels'],
            data['pair_masks'], data['pair_att_masks']
        ]

        # Entries that are None are passed through untouched.
        tmp = [x if x is None else torch.from_numpy(x).cuda() for x in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        masks = masks.float()

        optimizer.zero_grad()

        if sc_flag:
            # Our DCVP model does not support self-critical sequence training
            # We found that RL(SCST) with CIDEr reward will improve conventional metrics (BLEU, CIDEr, etc.)
            # but harm diversity and descriptiveness
            # Please refer to the paper for the details
            raise NotImplementedError(
                'self-critical sequence training is not supported')

        if opt.onlysecond:
            # only using the second sentence from a visual paraphrase pair. opt.caption_model should be a one-stage decoding model
            loss = crit(
                dp_model(fc_feats, att_feats, labels[:, 1, :], att_masks),
                labels[:, 1, 1:], masks[:, 1, 1:])
            loss1 = loss2 = loss / 2
        elif opt.first:
            # using the first sentence
            tmp = [
                data['first_fc_feats'], data['first_att_feats'],
                data['first_labels'], data['first_masks'],
                data['first_att_masks']
            ]
            tmp = [
                x if x is None else torch.from_numpy(x).cuda() for x in tmp
            ]
            fc_feats, att_feats, labels, masks, att_masks = tmp
            masks = masks.float()
            loss = crit(
                dp_model(fc_feats, att_feats, labels[:, :], att_masks),
                labels[:, 1:], masks[:, 1:])
            loss1 = loss2 = loss / 2
        elif opt.onlyfirst:
            # only using the first sentence from a visual paraphrase pair
            loss = crit(
                dp_model(fc_feats, att_feats, labels[:, 0, :], att_masks),
                labels[:, 0, 1:], masks[:, 0, 1:])
            loss1 = loss2 = loss / 2
        else:
            # proposed DCVP model, opt.caption_model should be att2inp
            output1, output2 = dp_model(fc_feats, att_feats, labels,
                                        att_masks, masks[:, 0, 1:])
            loss1 = crit(output1, labels[:, 0, 1:], masks[:, 0, 1:])
            loss2 = crit(output2, labels[:, 1, 1:], masks[:, 1, 1:])
            loss = loss1 + loss2

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        # sc_flag is always False past this point (the True case raised
        # above), so only the cross-entropy report is needed.
        print("iter {} (epoch {}), train_loss = {:.3f}, loss1 = {:.3f}, loss2 = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, loss.item(), loss1.item(), loss2.item(), end - start))

        sys.stdout.flush()
        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils_pair.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Score used to pick the best checkpoint: CIDEr when language
            # evaluation is enabled, otherwise the negated validation loss
            # (so that higher is always better).
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # Latest checkpoint (overwritten on every save).
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path,
                                          'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'),
                    'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)

            # Per-iteration snapshot, kept alongside the latest/best files.
            checkpoint_path = os.path.join(
                opt.checkpoint_path, 'model' + str(iteration) + '.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            with open(
                    os.path.join(
                        opt.checkpoint_path,
                        'infos_' + opt.id + '_' + str(iteration) + '.pkl'),
                    'wb') as f:
                cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break