def test(opt):
    logger = Logger(opt)
    dataset = VISTDataset(opt)
    opt.vocab_size = dataset.get_vocab_size()
    opt.seq_length = dataset.get_story_length()

    dataset.test()
    test_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers)
    evaluator = Evaluator(opt, 'test')
    model = models.setup(opt)
    model.cuda()
    predictions, metrics = evaluator.test_story(model, dataset, test_loader, opt)
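
A minimal sketch of how test() might be invoked from the command line. The opts.parse_opt() helper name is an assumption about the project's option parser, not a confirmed API.

# Hedged sketch only; `opts.parse_opt` is an assumed helper, not confirmed by the code above.
if __name__ == "__main__":
    opt = opts.parse_opt()  # parse command-line options (assumed helper)
    test(opt)
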
Code Example #2
def train(opt):
    logger = Logger(opt)
    flag = Flag(D_iters=opt.D_iter, G_iters=opt.G_iter, always=opt.always)
    ################### set up dataset and dataloader ########################
    dataset = VISTDataset(opt)
    opt.vocab_size = dataset.get_vocab_size()
    opt.seq_length = dataset.get_story_length()

    dataset.set_option(data_type={
        'whole_story': False,
        'split_story': True,
        'caption': False
    })

    dataset.train()
    train_loader = DataLoader(dataset,
                              batch_size=opt.batch_size,
                              shuffle=opt.shuffle,
                              num_workers=opt.workers)
    dataset.val()
    val_loader = DataLoader(dataset,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            num_workers=opt.workers)

    ##################### set up model, criterion and optimizer ######
    bad_valid = 0

    # set up evaluator
    evaluator = Evaluator(opt, 'val')

    # set up criterion
    crit = criterion.LanguageModelCriterion()
    rl_crit = criterion.ReinforceCriterion(opt, dataset)

    # set up model
    model = models.setup(opt)
    model.cuda()
    disc_opt = copy.copy(opt)
    disc_opt.model = 'RewardModel'
    disc = models.setup(disc_opt)
    if os.path.exists(os.path.join(logger.log_dir, 'disc-model.pth')):
        logging.info("loading pretrained RewardModel")
        disc.load_state_dict(
            torch.load(os.path.join(logger.log_dir, 'disc-model.pth')))
    disc.cuda()

    # set up optimizer
    optimizer = setup_optimizer(opt, model)
    disc_optimizer = setup_optimizer(opt, disc)

    dataset.train()
    model.train()
    disc.train()
    ############################## training ##################################
    for epoch in range(logger.epoch_start, opt.max_epochs):
        # Assign the scheduled sampling prob

        start = time.time()
        for iter, batch in enumerate(train_loader):
            logger.iteration += 1
            torch.cuda.synchronize()

            feature_fc = Variable(batch['feature_fc']).cuda()
            target = Variable(batch['split_story']).cuda()
            index = batch['index']

            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            if flag.flag == "Disc":
                model.eval()
                disc.train()
                if opt.decoding_method_DISC == 'sample':
                    seq, seq_log_probs, baseline = model.sample(
                        feature_fc,
                        sample_max=False,
                        rl_training=True,
                        pad=True)
                elif opt.decoding_method_DISC == 'greedy':
                    seq, seq_log_probs, baseline = model.sample(
                        feature_fc,
                        sample_max=True,
                        rl_training=True,
                        pad=True)
            else:
                model.train()
                disc.eval()
                seq, seq_log_probs, baseline = model.sample(feature_fc,
                                                            sample_max=False,
                                                            rl_training=True,
                                                            pad=True)

            seq = Variable(seq).cuda()
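            # Build a token mask shifted right by one step: prepend a column of
            # ones and drop the last column so the first generated token of each
            # sentence is always counted in the length normalization below.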
            mask = (seq > 0).float()
            mask = to_contiguous(
                torch.cat([
                    Variable(
                        mask.data.new(mask.size(0), mask.size(1), 1).fill_(1)),
                    mask[:, :, :-1]
                ], 2))
            normed_seq_log_probs = (seq_log_probs *
                                    mask).sum(-1) / mask.sum(-1)

            gen_score = disc(seq.view(-1, seq.size(2)),
                             feature_fc.view(-1, feature_fc.size(2)))

            if flag.flag == "Disc":
                gt_score = disc(target.view(-1, target.size(2)),
                                feature_fc.view(-1, feature_fc.size(2)))
                loss = -torch.sum(gt_score) + torch.sum(gen_score)

                avg_pos_score = torch.mean(gt_score)
                avg_neg_score = torch.mean(gen_score)

                if logger.iteration % 5 == 0:
                    logging.info("pos reward {} neg reward {}".format(
                        avg_pos_score.data[0], avg_neg_score.data[0]))
                    print(
                        "PREDICTION: ",
                        utils.decode_story(dataset.get_vocab(),
                                           seq[:1].data)[0])
                    print(
                        "GROUND TRUTH: ",
                        utils.decode_story(dataset.get_vocab(),
                                           target[:1].data)[0])
            else:
                rewards = Variable(gen_score.data -
                                   0.001 * normed_seq_log_probs.data)
                #with open("/tmp/reward.txt", "a") as f:
                #    print(" ".join(map(str, rewards.data.cpu().numpy())), file=f)
                loss, avg_score = rl_crit(seq.data, seq_log_probs, baseline,
                                          index, rewards)
                # if logger.iteration % opt.losses_log_every == 0:
                avg_pos_score = torch.mean(gen_score)
                logging.info("average reward: {} average IRL score: {}".format(
                    avg_score.data[0], avg_pos_score.data[0]))

            if flag.flag == "Disc":
                loss.backward()
                nn.utils.clip_grad_norm(disc.parameters(),
                                        opt.grad_clip,
                                        norm_type=2)
                disc_optimizer.step()
            else:
                tf_loss = crit(model(feature_fc, target), target)
                print("rl_loss / tf_loss = ", loss.data[0] / tf_loss.data[0])
                loss = opt.rl_weight * loss + (1 - opt.rl_weight) * tf_loss
                loss.backward()
                nn.utils.clip_grad_norm(model.parameters(),
                                        opt.grad_clip,
                                        norm_type=2)
                optimizer.step()

            train_loss = loss.data[0]
            torch.cuda.synchronize()

            # Write the training loss summary
            if logger.iteration % opt.losses_log_every == 0:
                logger.log_training(epoch, iter, train_loss, opt.learning_rate,
                                    model.ss_prob)
                logging.info(
                    "Epoch {} Train {} - Iter {} / {}, loss = {:.5f}, time used = {:.3f}s"
                    .format(epoch, flag.flag, iter, len(train_loader),
                            train_loss,
                            time.time() - start))
                start = time.time()

            if logger.iteration % opt.save_checkpoint_every == 0:
                if opt.always is None:
                    # Evaluate on validation dataset and save model for every epoch
                    val_loss, predictions, metrics = evaluator.eval_story(
                        model, crit, dataset, val_loader, opt)
                    if opt.metric == 'XE':
                        score = -val_loss
                    else:
                        score = metrics[opt.metric]
                    logger.log_checkpoint(epoch, val_loss, metrics,
                                          predictions, opt, model, dataset,
                                          optimizer)
                    # halve the learning rate if not improving for a long time
                    if logger.best_val_score > score:
                        bad_valid += 1
                        if bad_valid >= 10:
                            opt.learning_rate = opt.learning_rate / 2.0
                            logging.info("halve learning rate to {}".format(
                                opt.learning_rate))
                            checkpoint_path = os.path.join(
                                logger.log_dir, 'model-best.pth')
                            model.load_state_dict(torch.load(checkpoint_path))
                            utils.set_lr(
                                optimizer,
                                opt.learning_rate)  # set the decayed rate
                            bad_valid = 0
                            logging.info("bad valid : {}".format(bad_valid))
                    else:
                        logging.info("achieving best {} score: {}".format(
                            opt.metric, score))
                        bad_valid = 0
                else:
                    torch.save(disc.state_dict(),
                               os.path.join(logger.log_dir, 'disc-model.pth'))
            flag.inc()
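
The loop above alternates between a discriminator phase (flag.flag == "Disc") and a generator phase through a Flag helper that is never defined in these snippets. Below is a minimal sketch of what such a helper could look like, inferred purely from how flag.flag and flag.inc() are used; the actual Flag class in the AREL project may differ, and the "Gen" label for the generator phase is an assumption.

class Flag:
    """Hedged sketch of the phase switcher, inferred from usage above: run
    D_iters discriminator iterations, then G_iters generator iterations,
    unless `always` pins training to a single phase."""

    def __init__(self, D_iters, G_iters, always=None):
        self.D_iters = D_iters
        self.G_iters = G_iters
        self.always = always
        self.flag = "Disc"  # current phase; the generator phase label is assumed
        self._count = 0

    def inc(self):
        # Called once per training iteration; flip the phase once its budget is spent.
        if self.always is not None:
            return
        self._count += 1
        if self.flag == "Disc" and self._count >= self.D_iters:
            self.flag, self._count = "Gen", 0
        elif self.flag != "Disc" and self._count >= self.G_iters:
            self.flag, self._count = "Disc", 0
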
Code Example #3
File: train_AREL.py  Project: vivid-k/AREL
def train(opt):
    logger = Logger(opt)  # set up the logger
    flag = Flag(D_iters=opt.D_iter, G_iters=opt.G_iter,
                always=opt.always)  # initialize the training-phase flag

    dataset = VISTDataset(opt)  # load the dataset
    opt.vocab_size = dataset.get_vocab_size()
    opt.seq_length = dataset.get_story_length()
    dataset.set_option(data_type={
        'whole_story': False,
        'split_story': True,
        'caption': False
    })
    dataset.train()
    train_loader = DataLoader(dataset,
                              batch_size=opt.batch_size,
                              shuffle=opt.shuffle)
    dataset.val()
    val_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False)
    bad_valid = 0

    evaluator = Evaluator(opt, 'val')
    crit = criterion.LanguageModelCriterion()
    rl_crit = criterion.ReinforceCriterion(opt, dataset)  # criterion for reinforcement learning

    # set up model
    model = models.setup(opt)
    model.cuda()
    disc_opt = copy.copy(opt)
    disc_opt.model = 'RewardModel'  # switch the model type to the reward model
    disc = models.setup(disc_opt)  # discriminator; decides which model class to instantiate
    if os.path.exists(os.path.join('./data/save/',
                                   'disc-model.pth')):  # load pretrained parameters if present
        logging.info("loading pretrained RewardModel")
        disc.load_state_dict(
            torch.load(os.path.join(logger.log_dir, 'disc-model.pth')))
    disc.cuda()
    # two optimizers for two completely independent models
    optimizer = setup_optimizer(opt, model)
    disc_optimizer = setup_optimizer(disc_opt, disc)  # fix

    dataset.train()
    model.train()
    disc.train()
    ############################## training ##################################
    for epoch in range(logger.epoch_start, opt.max_epochs):  # maximum number of epochs is 50
        start = time.time()
        for iter, batch in enumerate(train_loader):  # iterate over batches
            logger.iteration += 1  # track the iteration count
            torch.cuda.synchronize()
            # fetch the batch data
            feature_fc = Variable(batch['feature_fc']).cuda()
            target = Variable(batch['split_story']).cuda()
            index = batch['index']

            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            if flag.flag == "Disc":
                model.eval()  # keep the policy model's parameters fixed
                disc.train()  # update the discriminator's parameters
                if opt.decoding_method_DISC == 'sample':  # sample a sequence from the output distribution
                    seq, seq_log_probs, baseline = model.sample(
                        feature_fc,
                        sample_max=False,
                        rl_training=True,
                        pad=True)
                elif opt.decoding_method_DISC == 'greedy':
                    seq, seq_log_probs, baseline = model.sample(
                        feature_fc,
                        sample_max=True,
                        rl_training=True,
                        pad=True)
            else:
                model.train()  # update the generator
                disc.eval()  # keep the discriminator fixed
                seq, seq_log_probs, baseline = model.sample(feature_fc,
                                                            sample_max=False,
                                                            rl_training=True,
                                                            pad=True)

            seq = Variable(seq).cuda()
            mask = (seq > 0).float()  # 64,5,30
            mask = to_contiguous(
                torch.cat([
                    Variable(
                        mask.data.new(mask.size(0), mask.size(1), 1).fill_(1)),
                    mask[:, :, :-1]
                ], 2))
            normed_seq_log_probs = (seq_log_probs * mask).sum(-1) / mask.sum(
                -1)  # (64, 5): length-normalized log-probability of each sequence
            gen_score = disc(
                seq.view(-1, seq.size(2)),
                feature_fc.view(-1, feature_fc.size(2)))  # reward scores for the sampled sequences

            if flag.flag == "Disc":  # 先训练判别器,生成器已经预训练好。训练该判别器参数,使其能对标签和生成数据进行打分。
                gt_score = disc(target.view(-1, target.size(2)),
                                feature_fc.view(
                                    -1, feature_fc.size(2)))  # 计算真实序列的reward
                loss = -torch.sum(gt_score) + torch.sum(
                    gen_score)  # 计算损失,loss为负很正常
                # 计算平均 reward,训练判别器希望能尽可能pos高
                avg_pos_score = torch.mean(gt_score)
                avg_neg_score = torch.mean(gen_score)

                if logger.iteration % 5 == 0:
                    logging.info("pos reward {} neg reward {}".format(
                        avg_pos_score.item(), avg_neg_score.item()))
                    # print("PREDICTION: ", utils.decode_story(dataset.get_vocab(), seq[:1].data)[0])
                    # print("GROUND TRUTH: ", utils.decode_story(dataset.get_vocab(), target[:1].data)[0])
            else:
                rewards = Variable(gen_score.data -
                                   0 * normed_seq_log_probs.view(-1).data)
                #with open("/tmp/reward.txt", "a") as f:
                #    print(" ".join(map(str, rewards.data.cpu().numpy())), file=f)
                loss, avg_score = rl_crit(seq.data, seq_log_probs, baseline,
                                          index, rewards.view(-1, seq.size(1)))
                # if logger.iteration % opt.losses_log_every == 0:
                avg_pos_score = torch.mean(gen_score)
                # logging.info("average reward: {} average IRL score: {}".format(avg_score.item(), avg_pos_score.item()))

            if flag.flag == "Disc":
                loss.backward()
                nn.utils.clip_grad_norm(disc.parameters(),
                                        opt.grad_clip,
                                        norm_type=2)
                disc_optimizer.step()
            else:
                tf_loss = crit(model(feature_fc, target), target)
                # print("rl_loss / tf_loss = ", loss.item() / tf_loss.item())
                loss = opt.rl_weight * loss + (1 - opt.rl_weight) * tf_loss
                loss.backward()
                nn.utils.clip_grad_norm(model.parameters(),
                                        opt.grad_clip,
                                        norm_type=2)
                optimizer.step()

            train_loss = loss.item()
            torch.cuda.synchronize()

            # Write the training loss summary
            if logger.iteration % opt.losses_log_every == 0:
                logger.log_training(epoch, iter, train_loss, opt.learning_rate,
                                    model.ss_prob)
                logging.info(
                    "Epoch {} Train {} - Iter {} / {}, loss = {:.5f}, time used = {:.3f}s"
                    .format(epoch, flag.flag, iter, len(train_loader),
                            train_loss,
                            time.time() - start))
                start = time.time()

            if logger.iteration % opt.save_checkpoint_every == 0:
                if opt.always is None:
                    # Evaluate on validation dataset and save model for every epoch
                    val_loss, predictions, metrics = evaluator.eval_story(
                        model, crit, dataset, val_loader, opt)
                    if opt.metric == 'XE':
                        score = -val_loss
                    else:
                        score = metrics[opt.metric]
                    logger.log_checkpoint(epoch, val_loss, metrics,
                                          predictions, opt, model, dataset,
                                          optimizer)
                    # halve the learning rate if not improving for a long time
                    if logger.best_val_score > score:
                        bad_valid += 1
                        if bad_valid >= 10:
                            opt.learning_rate = opt.learning_rate / 2.0
                            logging.info("halve learning rate to {}".format(
                                opt.learning_rate))
                            checkpoint_path = os.path.join(
                                logger.log_dir, 'model-best.pth')
                            model.load_state_dict(torch.load(checkpoint_path))
                            utils.set_lr(
                                optimizer,
                                opt.learning_rate)  # set the decayed rate
                            bad_valid = 0
                            logging.info("bad valid : {}".format(bad_valid))
                    else:
                        logging.info("achieving best {} score: {}".format(
                            opt.metric, score))
                        bad_valid = 0
                else:
                    torch.save(disc.state_dict(),
                               os.path.join(logger.log_dir, 'disc-model.pth'))
            flag.inc()
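
Every example calls setup_optimizer() without showing it. Below is a plausible sketch under the assumption that it simply wraps torch.optim with the settings carried on opt; the weight_decay attribute and the choice of Adam are assumptions, not the project's confirmed implementation.

import torch.optim as optim

def setup_optimizer(opt, model):
    # Hedged sketch: collect trainable parameters and hand them to Adam.
    # `weight_decay` is an assumed optional attribute on opt.
    params = [p for p in model.parameters() if p.requires_grad]
    return optim.Adam(params,
                      lr=opt.learning_rate,
                      weight_decay=getattr(opt, 'weight_decay', 0.0))
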
Code Example #4
def train(opt):
    """
    模型训练函数
    """
    # 自定义的类,日志记录
    logger = Logger(opt)

    # load the data
    dataset = VISTDataset(opt)
    opt.vocab_size = dataset.get_vocab_size()
    opt.seq_length = dataset.get_story_length()
    # print(dataset.get_word2id()['the'])
    dataset.set_option(data_type={
        'whole_story': False,
        'split_story': True,
        'caption': True
    })  # set 'caption' to False if caption data is not used
    dataset.train()
    train_loader = DataLoader(dataset,
                              batch_size=opt.batch_size,
                              shuffle=opt.shuffle)
    dataset.test()  # note: this should use the validation split
    val_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False)
    # m = dataset.word2id

    # count how many times the validation score has failed to improve
    bad_valid = 0

    # create the evaluator
    evaluator = Evaluator(opt, 'val')
    # loss function
    crit = criterion.LanguageModelCriterion()
    # whether to use reinforcement learning; defaults to -1 (disabled)
    if opt.start_rl >= 0:
        rl_crit = criterion.ReinforceCriterion(opt, dataset)

    # set up the model (defined in models/__init__.py); loads existing parameters if a checkpoint is available
    model = models.setup(opt)
    model.cuda()
    optimizer = setup_optimizer(opt, model)
    dataset.train()
    model.train()
    for epoch in range(logger.epoch_start, opt.max_epochs):  # 0 to 20 by default
        # scheduled_sampling_start is the epoch at which the probability of feeding
        # ground truth starts to decay (up to a cap of 0.25); it stays 0 for the first 5 epochs
        if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
            frac = (
                epoch - opt.scheduled_sampling_start
            ) // opt.scheduled_sampling_increase_every  # increase_every defaults to 5; // is floor division
            opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                              opt.scheduled_sampling_max_prob)  # step 0.05, cap 0.25
            model.ss_prob = opt.ss_prob
        # iterate over the data batch by batch
        for iter, batch in enumerate(train_loader):
            start = time.time()
            logger.iteration += 1
            torch.cuda.synchronize()

            # fetch image features, captions, and targets from the batch
            features = Variable(batch['feature_fc']).cuda()  # 64*5*2048
            caption = None
            if opt.caption:
                caption = Variable(batch['caption']).cuda()  # 64*5*20
            target = Variable(batch['split_story']).cuda()  # 64*5*30
            index = batch['index']

            optimizer.zero_grad()

            # the forward pass returns a distribution over tokens; compute the cross-entropy loss
            output = model(features, target, caption)
            loss = crit(output, target)

            if opt.start_rl >= 0 and epoch >= opt.start_rl:  # reinforcement learning
                # draw sampled sequences and the baseline
                seq, seq_log_probs, baseline = model.sample(features,
                                                            caption=caption,
                                                            sample_max=False,
                                                            rl_training=True)
                rl_loss, avg_score = rl_crit(seq, seq_log_probs, baseline,
                                             index)
                print(rl_loss.item() / loss.item())
                loss = opt.rl_weight * rl_loss + (1 - opt.rl_weight) * loss
                logging.info("average {} score: {}".format(
                    opt.reward_type, avg_score))
            # backpropagation
            loss.backward()
            train_loss = loss.item()
            # gradient clipping; the second argument is the maximum allowed gradient norm
            nn.utils.clip_grad_norm(model.parameters(),
                                    opt.grad_clip,
                                    norm_type=2)
            optimizer.step()
            torch.cuda.synchronize()
            # log the elapsed time and the loss
            logging.info(
                "Epoch {} - Iter {} / {}, loss = {:.5f}, time used = {:.3f}s".
                format(epoch, iter, len(train_loader), train_loss,
                       time.time() - start))
            # Write the training loss summary to TensorBoard
            if logger.iteration % opt.losses_log_every == 0:
                logger.log_training(epoch, iter, train_loss, opt.learning_rate,
                                    model.ss_prob)
            # run validation every save_checkpoint_every iterations
            if logger.iteration % opt.save_checkpoint_every == 0:
                val_loss, predictions, metrics = evaluator.eval_story(
                    model, crit, dataset, val_loader, opt)
                if opt.metric == 'XE':
                    score = -val_loss
                else:
                    score = metrics[opt.metric]
                logger.log_checkpoint(epoch, val_loss, metrics, predictions,
                                      opt, model, dataset, optimizer)
                # halve the learning rate if not improving for a long time
                if logger.best_val_score > score:
                    bad_valid += 1
                    if bad_valid >= 4:
                        opt.learning_rate = opt.learning_rate / 2.0
                        logging.info("halve learning rate to {}".format(
                            opt.learning_rate))
                        checkpoint_path = os.path.join(logger.log_dir,
                                                       'model-best.pth')
                        model.load_state_dict(torch.load(checkpoint_path))
                        utils.set_lr(optimizer,
                                     opt.learning_rate)  # set the decayed rate
                        bad_valid = 0
                        logging.info("bad valid : {}".format(bad_valid))
                else:
                    logging.info("achieving best {} score: {}".format(
                        opt.metric, score))
                    bad_valid = 0
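
The scheduled sampling block at the top of the epoch loop above raises model.ss_prob in steps. With the values mentioned in its comments (start at epoch 0, step 0.05 every 5 epochs, cap at 0.25; illustrative rather than confirmed defaults), the schedule evaluates as in this small sketch.

# Illustrative arithmetic only; start/every/step/cap mirror the commented values above.
start, every, step, cap = 0, 5, 0.05, 0.25
for epoch in (1, 5, 10, 25, 40):
    frac = (epoch - start) // every
    print(epoch, min(step * frac, cap))  # 0.0, 0.05, 0.10, 0.25, 0.25
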
Code Example #5
def train(opt):
    # utils.setup_seed()
    logger = Logger(opt, save_code=opt.save_code)

    ################### set up dataset and dataloader ########################
    dataset = VISTDataset(opt)
    opt.vocab_size = dataset.get_vocab_size()
    opt.seq_length = dataset.get_story_length()

    dataset.set_option(data_type={'whole_story': False, 'split_story': True, 'caption': False, 'prefix_story': True})

    dataset.train()
    train_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=opt.shuffle, num_workers=opt.workers)
    dataset.val()
    val_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers)

    ##################### set up model, criterion and optimizer ######
    bad_valid = 0

    # set up evaluator
    evaluator = Evaluator(opt, 'val')

    # set up criterion
    crit = criterion.LanguageModelCriterion()

    # set up model
    model = models.setup(opt)
    model.cuda()

    # set up optimizer
    optimizer = setup_optimizer(opt, model)

    dataset.train()
    model.train()
    initial_lr = opt.learning_rate
    logging.info(model)
    ############################## training ##################################
    for epoch in range(logger.epoch_start, opt.max_epochs):
        # Assign the scheduled sampling prob
        if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
            frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
            opt.ss_prob = min(opt.scheduled_sampling_increase_prob *
                              frac, opt.scheduled_sampling_max_prob)
            model.ss_prob = opt.ss_prob

        for iter, batch in enumerate(train_loader):
            start = time.time()
            logger.iteration += 1
            torch.cuda.synchronize()

            feature_fc = batch['feature_fc'].cuda()
            if opt.use_obj:
                feature_obj = batch['feature_obj'].cuda()
                if opt.use_spatial:
                    feature_obj_spatial = batch['feature_obj_spatial'].cuda()
                else:
                    feature_obj_spatial = None
                if opt.use_classes:
                    feature_obj_classes = batch['feature_obj_classes'].cuda()
                else:
                    feature_obj_classes = None
                if opt.use_attrs:
                    feature_obj_attrs = batch['feature_obj_attrs'].cuda()
                else:
                    feature_obj_attrs = None
            else:
                # object features disabled; define the variables used in the forward call below
                feature_obj = feature_obj_spatial = feature_obj_classes = feature_obj_attrs = None
            target = batch['split_story'].cuda()
            prefix = batch['prefix_story'].cuda()
            history_count = batch['history_counter'].cuda()
            index = batch['index']

            optimizer.zero_grad()

            # cross entropy loss
            output = model(feature_fc, feature_obj, target, history_count, spatial=feature_obj_spatial,
                               clss=feature_obj_classes, attrs=feature_obj_attrs)
            loss = crit(output, target)

            loss.backward()
            train_loss = loss.item()

            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip, norm_type=2)
            optimizer.step()
            torch.cuda.synchronize()

            if iter % opt.log_step == 0:
                logging.info("Epoch {} - Iter {} / {}, loss = {:.5f}, time used = {:.3f}s".format(epoch, iter,
                                                                                                  len(train_loader),
                                                                                                  train_loss,
                                                                                                  time.time() - start))
            # Write the training loss summary
            if logger.iteration % opt.losses_log_every == 0:
                logger.log_training(epoch, iter, train_loss, opt.learning_rate, model.ss_prob)

            if logger.iteration % opt.save_checkpoint_every == 0:
                # Evaluate on validation dataset and save model for every epoch
                val_loss, predictions, metrics = evaluator.eval_story(model, crit, dataset, val_loader, opt)
                if opt.metric == 'XE':
                    score = -val_loss
                else:
                    score = metrics[opt.metric]
                logger.log_checkpoint(epoch, val_loss, metrics, predictions, opt, model, dataset, optimizer)
                # halve the learning rate if not improving for a long time
                if logger.best_val_score > score:
                    bad_valid += 1
                    if bad_valid >= opt.bad_valid_threshold:
                        opt.learning_rate = opt.learning_rate * opt.learning_rate_decay_rate
                        logging.info("halve learning rate to {}".format(opt.learning_rate))
                        checkpoint_path = os.path.join(logger.log_dir, 'model-best.pth')
                        model.load_state_dict(torch.load(checkpoint_path))
                        utils.set_lr(optimizer, opt.learning_rate)  # set the decayed rate
                        bad_valid = 0
                        logging.info("bad valid : {}".format(bad_valid))
                else:
                    opt.learning_rate = initial_lr
                    logging.info("achieving best {} score: {}".format(opt.metric, score))
                    bad_valid = 0
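
When validation stops improving, the loop above reloads the best checkpoint and applies the halved learning rate through utils.set_lr(). Below is a minimal sketch of what such a utility typically does with a torch.optim optimizer; this is an assumption about the helper, not the project's confirmed code.

def set_lr(optimizer, lr):
    # Hedged sketch: write the new learning rate into every parameter group.
    for group in optimizer.param_groups:
        group['lr'] = lr
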
Code Example #6
def train(opt):
    setup_seed()
    logger = Logger(opt)

    ################### set up dataset and dataloader ########################
    dataset = VISTDataset(opt)
    opt.vocab_size = dataset.get_vocab_size()
    opt.seq_length = dataset.get_story_length()

    dataset.set_option(data_type={'whole_story': False, 'split_story': True, 'caption': False})

    dataset.train()
    train_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=opt.shuffle, num_workers=opt.workers)
    dataset.val()
    val_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers)

    ##################### set up model, criterion and optimizer ######
    bad_valid = 0

    # set up evaluator
    evaluator = Evaluator(opt, 'val')

    # set up criterion
    crit = criterion.LanguageModelCriterion()
    if opt.start_rl >= 0:
        rl_crit = criterion.ReinforceCriterion(opt, dataset)

    # set up model
    model = models.setup(opt)
    model.cuda()

    # set up optimizer
    optimizer = setup_optimizer(opt, model)

    dataset.train()
    model.train()
    ############################## training ##################################
    for epoch in range(logger.epoch_start, opt.max_epochs):
        # Assign the scheduled sampling prob
        if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
            frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
            opt.ss_prob = min(opt.scheduled_sampling_increase_prob *
                              frac, opt.scheduled_sampling_max_prob)
            model.ss_prob = opt.ss_prob

        for iter, batch in enumerate(train_loader):
            start = time.time()
            logger.iteration += 1
            torch.cuda.synchronize()

            feature_fc = Variable(batch['feature_fc']).cuda()
            target = Variable(batch['split_story']).cuda()
            index = batch['index']
            semantic = batch['semantic']

            optimizer.zero_grad()

            # cross entropy loss
            output = model(feature_fc, target, semantic)
            loss = crit(output, target)

            if opt.start_rl >= 0 and epoch >= opt.start_rl:  # reinforcement learning
                seq, seq_log_probs, baseline = model.sample(feature_fc, sample_max=False, rl_training=True)
                rl_loss, avg_score = rl_crit(seq, seq_log_probs, baseline, index)
                print(rl_loss.data[0] / loss.data[0])
                loss = opt.rl_weight * rl_loss + (1 - opt.rl_weight) * loss
                logging.info("average {} score: {}".format(opt.reward_type, avg_score))

            loss.backward()
            train_loss = loss.data[0]

            nn.utils.clip_grad_norm(model.parameters(), opt.grad_clip, norm_type=2)
            optimizer.step()
            torch.cuda.synchronize()

            logging.info("Epoch {} - Iter {} / {}, loss = {:.5f}, time used = {:.3f}s".format(epoch, iter,
                                                                                              len(train_loader),
                                                                                              train_loss,
                                                                                              time.time() - start))
            # Write the training loss summary
            if logger.iteration % opt.losses_log_every == 0:
                logger.log_training(epoch, iter, train_loss, opt.learning_rate, model.ss_prob)

            if logger.iteration % opt.save_checkpoint_every == 0:
                # Evaluate on validation dataset and save model for every epoch
                val_loss, predictions, metrics = evaluator.eval_story(model, crit, dataset, val_loader, opt)
                if opt.metric == 'XE':
                    score = -val_loss
                else:
                    score = metrics[opt.metric]
                logger.log_checkpoint(epoch, val_loss, metrics, predictions, opt, model, dataset, optimizer)
                # halve the learning rate if not improving for a long time
                if logger.best_val_score > score:
                    bad_valid += 1
                    if bad_valid >= 4:
                        opt.learning_rate = opt.learning_rate / 2.0
                        logging.info("halve learning rate to {}".format(opt.learning_rate))
                        checkpoint_path = os.path.join(logger.log_dir, 'model-best.pth')
                        model.load_state_dict(torch.load(checkpoint_path))
                        utils.set_lr(optimizer, opt.learning_rate)  # set the decayed rate
                        bad_valid = 0
                        logging.info("bad valid : {}".format(bad_valid))
                else:
                    logging.info("achieving best {} score: {}".format(opt.metric, score))
                    bad_valid = 0
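
All of the examples compute their supervised loss through criterion.LanguageModelCriterion(). The sketch below shows a masked cross-entropy criterion of that general kind, assuming the model output holds per-token log-probabilities and that padding tokens are 0; the project's actual class may differ in shape handling and details.

import torch.nn as nn

class LanguageModelCriterion(nn.Module):
    """Hedged sketch: average negative log-likelihood over non-padding tokens.
    Assumes `output` has shape (batch, steps, vocab) with log-probabilities and
    `target` has token ids of shape (batch, steps), 0 meaning padding."""

    def forward(self, output, target):
        target = target[:, :output.size(1)]                       # align lengths
        mask = (target > 0).float()                               # ignore padding
        nll = -output.gather(2, target.unsqueeze(2)).squeeze(2)   # per-token NLL
        return (nll * mask).sum() / mask.sum().clamp(min=1.0)     # mean over real tokens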