Example #1
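# Evaluation-time setup (excerpt): restore the options saved alongside a
# trained checkpoint, rebuild the CNN and captioning model, and create a
# data loader for scoring.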
if len(opt.id) == 0:
    opt.id = infos['opt'].id
ignore = ["id", "batch_size", "beam_size", "start_from"]  # , "language_eval"
for k in vars(infos['opt']).keys():
    if k not in ignore:
        if k in vars(opt):
            pass
            # assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent'
        else:
            vars(opt).update({k: vars(infos['opt'])[k]})  # copy over options from the saved model

vocab = infos['vocab']  # ix -> word mapping

# Setup the model
cnn_model = utils.build_cnn(opt)
cnn_model.load_state_dict(torch.load(opt.cnn_model_path))
cnn_model.cuda()
cnn_model.eval()
model = models.setup(opt)
model.load_state_dict(torch.load(opt.model_path))
model.cuda()
model.eval()
crit = utils.LanguageModelCriterion()
opt.seq_per_img = 1
# opt.sentence_embed = './data/data_news_compact_lda_label.h5'
# opt.sentence_embed = './data/sen_embed/articles_compact_avg.h5'
# Create the Data Loader instance
if len(opt.image_folder) == 0:
    loader = DataLoader(opt)
else:
    # (snippet truncated here in the source; this branch would build a loader
    # over the raw images in opt.image_folder)
Example #2
def train(opt):
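    # Joint CNN + captioning-model training loop: optionally resumes from a
    # saved checkpoint, decays the learning rate and scheduled-sampling
    # probability per epoch, and periodically evaluates on the validation
    # split to checkpoint the best model.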
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})
    loss_history = infos.get('loss_history', {})
    lr_history = infos.get('lr_history', {})
    ss_prob_history = infos.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    cnn_model = utils.build_cnn(opt)
    cnn_model.cuda()
    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
    cnn_optimizer = optim.Adam(cnn_model.parameters(),
                               lr=opt.cnn_learning_rate,
                               weight_decay=opt.cnn_weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        if os.path.isfile(os.path.join(opt.start_from, 'optimizer.pth')):
            optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
        if os.path.isfile(os.path.join(opt.start_from, 'optimizer-cnn.pth')):
            cnn_optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer-cnn.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            # Update the training stage of cnn
            if opt.finetune_cnn_after == -1 or epoch < opt.finetune_cnn_after:
                for p in cnn_model.parameters():
                    p.requires_grad = False
                cnn_model.eval()
            else:
                for p in cnn_model.parameters():
                    p.requires_grad = True
                cnn_model.train()
            update_lr_flag = False

        torch.cuda.synchronize()
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        torch.cuda.synchronize()
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['images'], data['labels'], data['masks']]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        images, labels, masks = tmp

        att_feats = cnn_model(images)
        fc_feats = att_feats.mean(2).mean(3).squeeze(2).squeeze(2)

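        # Replicate each image's features seq_per_img times so they line up
        # with the seq_per_img ground-truth captions per image:
        # (N, ...) -> (N, seq_per_img, ...) -> (N * seq_per_img, ...)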
        att_feats = att_feats.unsqueeze(1).expand(*((
            att_feats.size(0),
            opt.seq_per_img,
        ) + att_feats.size()[1:])).contiguous().view(
            *((att_feats.size(0) * opt.seq_per_img, ) + att_feats.size()[1:]))
        fc_feats = fc_feats.unsqueeze(1).expand(*((
            fc_feats.size(0),
            opt.seq_per_img,
        ) + fc_feats.size()[1:])).contiguous().view(
            *((fc_feats.size(0) * opt.seq_per_img, ) + fc_feats.size()[1:]))

        optimizer.zero_grad()
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            cnn_optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:],
                    masks[:, 1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            utils.clip_gradient(cnn_optimizer, opt.grad_clip)
            cnn_optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tf_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                cnn_model, model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tf_summary_writer, 'validation loss', val_loss,
                              iteration)
            for k, v in lang_stats.iteritems():
                add_summary_value(tf_summary_writer, k, v, iteration)
            tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # always checkpoint; best-model files are additionally written when best_flag is set
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-cnn.pth')
                torch.save(model.state_dict(), checkpoint_path)
                torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                print("cnn model saved to {}".format(cnn_checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                cnn_optimizer_path = os.path.join(opt.checkpoint_path,
                                                  'optimizer-cnn.pth')
                torch.save(optimizer.state_dict(), optimizer_path)
                torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['val_result_history'] = val_result_history
                infos['loss_history'] = loss_history
                infos['lr_history'] = lr_history
                infos['ss_prob_history'] = ss_prob_history
                infos['vocab'] = loader.get_vocab()
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                       'model-cnn-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    print("cnn model saved to {}".format(cnn_checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Example #3
def main():
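    # Example driver: hard-codes a top-down captioning configuration, builds
    # the loader and the models, and runs a single validation-split
    # evaluation.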
    import opts
    import misc.utils as utils
    opt = opts.parse_opt()
    opt.caption_model = 'topdown'
    opt.batch_size = 10  #512#32*4*4
    opt.id = 'topdown'
    opt.learning_rate = 5e-4
    opt.learning_rate_decay_start = 0
    opt.scheduled_sampling_start = 0
    opt.save_checkpoint_every = 5000  #450#5000#11500
    opt.val_images_use = 5000
    opt.max_epochs = 50  #30
    opt.start_from = 'save/rt'  #"save" #None
    opt.language_eval = 1
    opt.input_json = 'data/meta_coco_en.json'
    opt.input_label_h5 = 'data/label_coco_en.h5'
    #    opt.input_json='data/coco_ccg.json' #'data/meta_coco_en.json'
    #    opt.input_label_h5='data/coco_ccg_label.h5' #'data/label_coco_en.h5'
    #    opt.input_fc_dir='/nlp/andyweizhao/self-critical.pytorch-master/data/cocotalk_fc'
    #    opt.input_att_dir='/nlp/andyweizhao/self-critical.pytorch-master/data/cocotalk_att'
    opt.finetune_cnn_after = 0
    opt.ccg = False
    opt.input_image_h5 = 'data/coco_image_512.h5'

    opt.use_att = utils.if_use_att(opt.caption_model)

    from dataloader import DataLoader  # just-in-time generated features
    loader = DataLoader(opt)

    #    from dataloader_fixcnn import DataLoader # load pre-processed features
    #    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.vocab_ccg_size = loader.vocab_ccg_size
    opt.seq_length = loader.seq_length

    import models
    from eval_utils import eval_split
    model = models.setup(opt)
    cnn_model = utils.build_cnn(opt)
    cnn_model.cuda()
    model.cuda()

    data = loader.get_batch('train')
    images = data['images']

    #    _fc_feats_2048 = []
    #    _fc_feats_81 = []
    #    _att_feats = []
    #    for i in range(loader.batch_size):
    #        x = Variable(torch.from_numpy(images[i]), volatile=True).cuda()
    #        x = x.unsqueeze(0)
    #        att_feats, fc_feats_81 = cnn_model(x)
    #        fc_feats_2048 = att_feats.mean(3).mean(2).squeeze()
    #        att_feats = F.adaptive_avg_pool2d(att_feats,[14,14]).squeeze().permute(1, 2, 0)#(0, 2, 3, 1)
    #        _fc_feats_2048.append(fc_feats_2048)
    #        _fc_feats_81.append(fc_feats_81)
    #        _att_feats.append(att_feats)
    #    _fc_feats_2048 = torch.stack(_fc_feats_2048)
    #    _fc_feats_81 = torch.stack(_fc_feats_81)
    #    _att_feats = torch.stack(_att_feats)
    #    att_feats = _att_feats.unsqueeze(1).expand(*((_att_feats.size(0), loader.seq_per_img,) + \
    #                                                   _att_feats.size()[1:])).contiguous().view(*((_att_feats.size(0) * loader.seq_per_img,) + \
    #                                                   _att_feats.size()[1:]))
    #    fc_feats_2048 = _fc_feats_2048.unsqueeze(1).expand(*((_fc_feats_2048.size(0), loader.seq_per_img,) + \
    #                                                  _fc_feats_2048.size()[1:])).contiguous().view(*((_fc_feats_2048.size(0) * loader.seq_per_img,) + \
    #                                                  _fc_feats_2048.size()[1:]))
    #    fc_feats_81 = _fc_feats_81
    #
    #    att_feats = Variable(att_feats, requires_grad=False).cuda()
    #    Variable(fc_feats_81)

    crit = utils.LanguageModelCriterion()
    eval_kwargs = {'split': 'val', 'dataset': opt.input_json, 'verbose': True}
    eval_kwargs.update(vars(opt))
    val_loss, predictions, lang_stats = eval_split(cnn_model, model, crit,
                                                   loader, eval_kwargs, True)
Example #4
def train(opt):
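    # Training-loop variant that adds an auxiliary multi-label loss on the
    # CNN's classification head and switches to self-critical (RL) training
    # after a configured epoch.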
    opt.use_att = utils.if_use_att(opt.caption_model)

    from dataloader import DataLoader
    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.vocab_ccg_size = loader.vocab_ccg_size
    opt.seq_length = loader.seq_length

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    cnn_model = utils.build_cnn(opt)
    cnn_model.cuda()

    model = models.setup(opt)
    model.cuda()
    # model = DataParallel(model)

    if vars(opt).get('start_from', None) is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            opt.start_from), "%s must be a path" % opt.start_from
        assert os.path.isfile(
            os.path.join(opt.start_from, "infos_" + opt.id + ".pkl")
        ), "infos.pkl file does not exist in path %s" % opt.start_from
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    update_lr_flag = True
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    multilabel_crit = nn.MultiLabelSoftMarginLoss().cuda()
    #    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)
    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
    if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
        print('finetune mode')
        cnn_optimizer = optim.Adam(
            [{'params': module.parameters()} for module in cnn_model._modules.values()[5:]],
            lr=opt.cnn_learning_rate, weight_decay=opt.cnn_weight_decay)

    if vars(opt).get('start_from', None) is not None:
        if os.path.isfile(os.path.join(opt.start_from, 'optimizer.pth')):
            optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            if os.path.isfile(os.path.join(opt.start_from,
                                           'optimizer-cnn.pth')):
                cnn_optimizer.load_state_dict(
                    torch.load(
                        os.path.join(opt.start_from, 'optimizer-cnn.pth')))

    eval_kwargs = {'split': 'val', 'dataset': opt.input_json, 'verbose': True}
    eval_kwargs.update(vars(opt))
    val_loss, predictions, lang_stats = eval_utils.eval_split(
        cnn_model, model, crit, loader, eval_kwargs, True)
    epoch_start = time.time()
    while True:
        if update_lr_flag:
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
                #model.module.ss_prob = opt.ss_prob
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
            else:
                sc_flag = False

            # Update the training stage of cnn
            for p in cnn_model.parameters():
                p.requires_grad = True
            # Fix the first few layers:
            for module in cnn_model._modules.values()[:5]:
                for p in module.parameters():
                    p.requires_grad = False
            cnn_model.train()
            update_lr_flag = False

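        # Keep BatchNorm layers fixed (utils.set_bn_fix / set_bn_eval) even
        # while the rest of the CNN is being finetuned.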
        cnn_model.apply(utils.set_bn_fix)
        cnn_model.apply(utils.set_bn_eval)

        start = time.time()
        torch.cuda.synchronize()
        data = loader.get_batch('train')
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:

            multilabels = [
                data['detection_infos'][i]['label']
                for i in range(len(data['detection_infos']))
            ]

            tmp = [
                data['labels'], data['masks'],
                np.array(multilabels, dtype=np.int16)
            ]
            tmp = [
                Variable(torch.from_numpy(_), requires_grad=False).cuda()
                for _ in tmp
            ]
            labels, masks, multilabels = tmp
            # The raw images can't be stacked into a single tensor because they differ in size.
            images = data['images']
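            # Run the CNN one image at a time, collecting spatial attention
            # features, mean-pooled fc features, and the auxiliary multi-label
            # head's output (fc_feats_81).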
            _fc_feats_2048 = []
            _fc_feats_81 = []
            _att_feats = []
            for i in range(loader.batch_size):
                x = Variable(torch.from_numpy(images[i]),
                             requires_grad=False).cuda()
                x = x.unsqueeze(0)
                att_feats, fc_feats_81 = cnn_model(x)
                fc_feats_2048 = att_feats.mean(3).mean(2).squeeze()
                att_feats = F.adaptive_avg_pool2d(att_feats,
                                                  [14, 14]).squeeze().permute(
                                                      1, 2, 0)  #(0, 2, 3, 1)
                _fc_feats_2048.append(fc_feats_2048)
                _fc_feats_81.append(fc_feats_81)
                _att_feats.append(att_feats)
            _fc_feats_2048 = torch.stack(_fc_feats_2048)
            _fc_feats_81 = torch.stack(_fc_feats_81)
            _att_feats = torch.stack(_att_feats)
            att_feats = _att_feats.unsqueeze(1).expand(*((_att_feats.size(0), loader.seq_per_img,) + \
                                                           _att_feats.size()[1:])).contiguous().view(*((_att_feats.size(0) * loader.seq_per_img,) + \
                                                           _att_feats.size()[1:]))
            fc_feats_2048 = _fc_feats_2048.unsqueeze(1).expand(*((_fc_feats_2048.size(0), loader.seq_per_img,) + \
                                                          _fc_feats_2048.size()[1:])).contiguous().view(*((_fc_feats_2048.size(0) * loader.seq_per_img,) + \
                                                          _fc_feats_2048.size()[1:]))
            fc_feats_81 = _fc_feats_81
            #
            cnn_optimizer.zero_grad()
        else:

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks']
            ]
            tmp = [
                Variable(torch.from_numpy(_), requires_grad=False).cuda()
                for _ in tmp
            ]
            fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()

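        # Note: fc_feats_2048, fc_feats_81 and multilabels are only defined on
        # the CNN-finetuning path above, so the losses below assume that
        # finetune_cnn_after has been reached; also note that cnn_optimizer is
        # zeroed on that path but never stepped in this loop.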
        if not sc_flag:
            loss1 = crit(model(fc_feats_2048, att_feats, labels),
                         labels[:, 1:], masks[:, 1:])
            loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double())
            loss = 0.8 * loss1 + 0.2 * loss2.float()
        else:
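            # Self-critical sequence training: sample captions, turn their
            # scores from get_self_critical_reward into per-token rewards for
            # the RL criterion, and mix in the multi-label and MLE losses.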
            gen_result, sample_logprobs = model.sample(fc_feats_2048,
                                                       att_feats,
                                                       {'sample_max': 0})
            reward = get_self_critical_reward(model, fc_feats_2048, att_feats,
                                              data, gen_result)
            loss1 = rl_crit(
                sample_logprobs, gen_result,
                Variable(torch.from_numpy(reward).float().cuda(),
                         requires_grad=False))
            loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double())
            loss3 = crit(model(fc_feats_2048, att_feats, labels),
                         labels[:, 1:], masks[:, 1:])
            loss = 0.995 * loss1 + 0.005 * (loss2.float() + loss3)
        loss.backward()

        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.data[0]
        mle_loss = loss1.data[0]
        multilabel_loss = loss2.data[0]
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag and iteration % 2500 == 0:
            print("iter {} (epoch {}), mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, mle_loss, multilabel_loss, train_loss, end - start))

        if sc_flag and iteration % 2500 == 0:
            print("iter {} (epoch {}), avg_reward = {:.3f}, mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), mle_loss, multilabel_loss, train_loss, end - start))
        iteration += 1
        if (iteration % opt.losses_log_every == 0):
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        if (iteration % opt.save_checkpoint_every == 0):
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'verbose': True
            }
            eval_kwargs.update(vars(opt))

            if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    cnn_model, model, crit, loader, eval_kwargs, True)
            else:
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    cnn_model, model, crit, loader, eval_kwargs, False)

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # always checkpoint; best-model files are additionally written when best_flag is set
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))

                cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-cnn.pth')
                torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                print("cnn model saved to {}".format(cnn_checkpoint_path))

                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                    cnn_optimizer_path = os.path.join(opt.checkpoint_path,
                                                      'optimizer-cnn.pth')
                    torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path)

                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))

                    cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                       'model-cnn-best.pth')
                    torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                    print("cnn model saved to {}".format(cnn_checkpoint_path))

                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True
            print("epoch: " + str(epoch) + " during: " +
                  str(time.time() - epoch_start))
            epoch_start = time.time()

        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Example #5
def train(opt):
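    # Training-loop variant with TensorBoard logging, an optional sentence
    # embedding fed to the decoder, and checkpoint paths keyed by the caption
    # model name.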
    np.random.seed(42)
    warnings.filterwarnings('ignore')

    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    # for debug purposes
    # a=get_batch_one(opt, [loader.split_ix, loader.shuffle, loader.iterators, loader.label_start_ix, loader.label_end_ix])
    # loader.get_batch('train')
    if not os.path.exists(opt.checkpoint_path+'tensorboard/'):
        os.makedirs(opt.checkpoint_path+'tensorboard/')

    else:
        for path in os.listdir(opt.checkpoint_path+'tensorboard/'):
            os.remove(opt.checkpoint_path+'tensorboard/'+path)
    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path+'tensorboard/')
    np.random.seed(42)
    infos = {}
    histories = {}

    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    # iteration = 26540
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    cnn_model = utils.build_cnn(opt)
    cnn_model.cuda()
    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
    if opt.finetune_cnn_after != -1:
        # only finetune the layer2 to layer4
        cnn_optimizer = optim.Adam(
            [{'params': module.parameters()} for module in cnn_model._modules.values()[5:]],
            lr=opt.cnn_learning_rate, weight_decay=opt.cnn_weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        if os.path.isfile(os.path.join(opt.start_from, 'optimizer.pth')):
            optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            if os.path.isfile(os.path.join(opt.start_from, 'optimizer-cnn.pth')):
                cnn_optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer-cnn.pth')))
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            # Update the training stage of cnn
            if opt.finetune_cnn_after == -1 or epoch < opt.finetune_cnn_after:
                for p in cnn_model.parameters():
                    p.requires_grad = False
                cnn_model.eval()
            else:
                for p in cnn_model.parameters():
                    p.requires_grad = True
                # Fix the first few layers:
                for module in cnn_model._modules.values()[:5]:
                    for p in module.parameters():
                        p.requires_grad = False
                cnn_model.train()
            update_lr_flag = False
        # torch.cuda.synchronize()
        start = time.time()
        # Load data from train split (0)
        # for validation training change the split to 'val'
        # data = loader.get_batch('val')
        data = loader.get_batch('train')

        data['images'] = utils.prepro_images(data['images'], True)
        # torch.cuda.synchronize()
        print('Read data:', time.time() - start)

        # torch.cuda.synchronize()
        start = time.time()
        tmp = [data['images'], data['labels'], data['masks']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        images, labels, masks = tmp

        att_feats = cnn_model(images).permute(0, 2, 3, 1)
        fc_feats = att_feats.mean(2).mean(1)

        if not opt.use_att:
            att_feats = Variable(torch.FloatTensor(1, 1, 1, 1).cuda())

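        # Replicate per-image features seq_per_img times to match the number
        # of captions per image.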
        att_feats = att_feats.unsqueeze(1).expand(*((att_feats.size(0), opt.seq_per_img,) +
                                                    att_feats.size()[1:])).contiguous().view(*((att_feats.size(0) * opt.seq_per_img,)
                                                                                               + att_feats.size()[1:]))
        fc_feats = fc_feats.unsqueeze(1).expand(*((fc_feats.size(0), opt.seq_per_img,) +
                                                  fc_feats.size()[1:])).contiguous().view(*((fc_feats.size(0) * opt.seq_per_img,) +
                                                                                            fc_feats.size()[1:]))
        model.zero_grad()
        optimizer.zero_grad()
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            cnn_optimizer.zero_grad()

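        # Optionally condition the decoder on precomputed sentence/article
        # embeddings supplied by the data loader.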
        if opt.sentence_embed:
            sen_embed = Variable(torch.from_numpy(np.array(data['sen_embed'])).cuda())
            out = model(fc_feats, att_feats, labels, sen_embed)
            loss = crit(out, labels[:, 1:], masks[:, 1:])
            # loss += cov
        else:
            loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:], masks[:, 1:])
            # - 0.001 * crit(model(torch.zeros(fc_feats.size()).cuda(), torch.zeros(att_feats.size()).cuda(), labels), labels[:,1:], masks[:,1:])
        loss.backward()
        # utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            utils.clip_gradient(cnn_optimizer, opt.grad_clip)
            cnn_optimizer.step()
        # train_loss = loss.data[0]
        train_loss = loss.item()
        # torch.cuda.synchronize()

        end = time.time()
        print("Step [{}/{}], Epoch [{}/{}],  train_loss = {:.3f}, time/batch = {:.3f}" \
            .format((iteration+1)%int(len(loader)/vars(opt)['batch_size']), int(len(loader)/vars(opt)['batch_size']),
                     epoch, vars(opt)['max_epochs'], train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                           'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(cnn_model, model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                for k,v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # always checkpoint; best-model files are additionally written when best_flag is set
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path + opt.caption_model, 'model.pth')
                cnn_checkpoint_path = os.path.join(opt.checkpoint_path + opt.caption_model, 'model-cnn.pth')
                torch.save(model.state_dict(), checkpoint_path)
                torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                print("cnn model saved to {}".format(cnn_checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path + opt.caption_model, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)
                if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                    cnn_optimizer_path = os.path.join(opt.checkpoint_path + opt.caption_model, 'optimizer-cnn.pth')
                    torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path + opt.caption_model, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path + opt.caption_model, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path + opt.caption_model, 'model-best.pth')
                    cnn_checkpoint_path = os.path.join(opt.checkpoint_path + opt.caption_model, 'model-cnn-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    print("cnn model saved to {}".format(cnn_checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path + opt.caption_model, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break