Beispiel #1
0
    def train(self, train_db, val_db, test_db):
        """Run the main training loop.

        For each epoch: train one pass over ``train_db``, validate on
        ``val_db``, write one tabular log row, and save a checkpoint.

        Args:
            train_db: training dataset, consumed by ``self.train_epoch``.
            val_db: validation dataset, consumed by ``self.validate_epoch``.
            test_db: unused here; kept for interface consistency with the
                other trainers in this project.
        """
        ##################################################################
        ## LOG
        ##################################################################
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        ##################################################################
        ## Main loop
        ##################################################################
        start = time()
        # starting from self.epoch allows resuming from a checkpoint
        for epoch in range(self.epoch, self.cfg.n_epochs):
            ##################################################################
            ## Training
            ##################################################################
            torch.cuda.empty_cache()
            train_loss, train_accu = self.train_epoch(train_db, epoch)

            ##################################################################
            ## Validation
            ##################################################################
            torch.cuda.empty_cache()
            val_loss, val_accu = self.validate_epoch(val_db, epoch)

            ##################################################################
            ## Logging
            ##################################################################
            # loss columns: [total, pred, embed, attn]; accu columns:
            # [obj, coord, scale, ratio] -- presumably per train_epoch's
            # stacked per-batch outputs (TODO confirm against train_epoch)
            logz.log_tabular("Time", time() - start)
            logz.log_tabular("Iteration", epoch)
            logz.log_tabular("AverageLoss",         np.mean(train_loss[:, 0]))
            logz.log_tabular("AveragePredLoss",     np.mean(train_loss[:, 1]))
            logz.log_tabular("AverageEmbedLoss",    np.mean(train_loss[:, 2]))
            logz.log_tabular("AverageAttnLoss",     np.mean(train_loss[:, 3]))
            logz.log_tabular("AverageObjAccu",      np.mean(train_accu[:, 0]))
            logz.log_tabular("AverageCoordAccu",    np.mean(train_accu[:, 1]))
            logz.log_tabular("AverageScaleAccu",    np.mean(train_accu[:, 2]))
            logz.log_tabular("AverageRatioAccu",    np.mean(train_accu[:, 3]))

            logz.log_tabular("ValAverageLoss",      np.mean(val_loss[:, 0]))
            logz.log_tabular("ValAveragePredLoss",  np.mean(val_loss[:, 1]))
            logz.log_tabular("ValAverageEmbedLoss", np.mean(val_loss[:, 2]))
            logz.log_tabular("ValAverageAttnLoss",  np.mean(val_loss[:, 3]))
            logz.log_tabular("ValAverageObjAccu",   np.mean(val_accu[:, 0]))
            logz.log_tabular("ValAverageCoordAccu", np.mean(val_accu[:, 1]))
            logz.log_tabular("ValAverageScaleAccu", np.mean(val_accu[:, 2]))
            logz.log_tabular("ValAverageRatioAccu", np.mean(val_accu[:, 3]))
            logz.dump_tabular()

            ##################################################################
            ## Checkpoint
            ##################################################################
            # checkpoint unconditionally every epoch (no best-model tracking)
            self.save_checkpoint(epoch)
Beispiel #2
0
def train_PG(opt):
    """Train a policy with vanilla policy gradients (REINFORCE).

    Per iteration: roll out trajectories with the current policy until
    ``opt.batch_size`` timesteps are collected, compute (discounted)
    returns, optionally subtract a learned neural-network baseline,
    normalize advantages, and take one Adam step on the policy.
    Diagnostics are written via ``logz``.

    Args:
        opt: configuration namespace; fields read here include logdir,
            seed, cuda, env_name, max_path_length, n_layers, size,
            learning_rate, nn_baseline, n_iter, render, batch_size,
            discount, reward_to_go, dont_normalize_advantages.
    """
    start = time.time()

    # Configure output directory for logging
    logz.configure_output_dir(opt.logdir)

    # Log experimental parameters
    logz.save_config(opt)

    # Set random seeds for reproducibility
    np.random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.seed)

    # Make the gym environment
    env = gym.make(opt.env_name)

    # Is this env continuous, or discrete?
    discrete = isinstance(env.action_space, gym.spaces.Discrete)

    # Maximum length for episodes
    max_path_length = opt.max_path_length or env.spec.max_episode_steps

    # Observation and action sizes
    ob_dim = env.observation_space.shape[0]
    ac_dim = env.action_space.n if discrete else env.action_space.shape[0]

    # Policy net (an MLP): outputs action probabilities when discrete,
    # a (mean, sigma) pair when continuous
    policy_net = PolicyNet(ob_dim, ac_dim, opt.n_layers, opt.size, discrete)
    optimizer = optim.Adam(policy_net.parameters(), lr=opt.learning_rate)

    # Neural network baseline: predicts the expected return of a state,
    # used to reduce the variance of the gradient
    if opt.nn_baseline:
        baseline_net = PolicyNet(ob_dim, 1, opt.n_layers, opt.size, False)
        baseline_optimizer = optim.Adam(baseline_net.parameters(),
                                        lr=opt.learning_rate)

    total_timesteps = 0
    for itr in range(opt.n_iter):
        print("********** Iteration %i ************" % itr)

        # ------------------------------------------------------------------
        # Sampling: collect paths until we have enough timesteps
        # ------------------------------------------------------------------
        policy_net.eval()

        # the number of timesteps per iteration varies: we stop collecting
        # only at episode boundaries
        timesteps_this_batch = 0
        paths = []
        while True:
            ob = env.reset()
            obs, acs, rewards = [], [], []
            animate_this_episode = (len(paths) == 0 and (itr % 10 == 0)
                                    and opt.render)
            steps = 0
            while True:
                if animate_this_episode:
                    env.render()
                    time.sleep(0.05)

                # collect observations
                obs.append(ob)

                # run the policy net to pick an action
                obs_th = torch.from_numpy(ob[None, :]).float()
                if opt.cuda:
                    obs_th = obs_th.cuda()

                if discrete:
                    probs, _ = policy_net(Variable(obs_th))
                    # multinomial sampling from the categorical distribution
                    m = torch.distributions.Categorical(probs)
                    # .cpu() so .numpy() also works when running on GPU
                    ac = m.sample().data.cpu().numpy()
                else:
                    mean, sigma = policy_net(Variable(obs_th))
                    # Gaussian sample: a = mu + sigma * eps.
                    # (Bug fix: the previous code computed
                    # sigma * (mean + eps), which also scales the mean.)
                    noise = Variable(torch.randn(mean.size()))
                    if opt.cuda:
                        noise = noise.cuda()
                    ac = mean + sigma * noise
                    ac = ac.data.cpu().numpy()
                ac = ac[0]
                acs.append(ac)

                # step the environment and collect the reward
                ob, rew, done, _ = env.step(ac)
                rewards.append(rew)

                steps += 1
                if done or steps > max_path_length:
                    break

            path = {
                "observation": np.array(obs),
                "reward": np.array(rewards),
                "action": np.array(acs)
            }
            paths.append(path)

            timesteps_this_batch += pathlength(path)

            if timesteps_this_batch > opt.batch_size:
                break
        total_timesteps += timesteps_this_batch

        # Build arrays for observation, action for the policy gradient update
        # by concatenating across paths
        ob_no = np.concatenate([path["observation"] for path in paths])
        ac_na = np.concatenate([path["action"] for path in paths])

        # ------------------------------------------------------------------
        # Return computation
        # ------------------------------------------------------------------
        gamma = opt.discount
        q_n = []
        if opt.reward_to_go:
            # Q(s_t, a_t) estimated by the discounted reward-to-go from t;
            # accumulate backwards, then reverse into forward order
            for path in paths:
                n_samples = len(path["reward"])
                tr = 0.0
                pr = []
                for k in range(n_samples - 1, -1, -1):
                    cr = path["reward"][k]
                    tr = gamma * tr + cr
                    pr.append(tr)
                q_n.extend(pr[::-1])
        else:
            # every timestep of a path gets the full discounted trajectory
            # return (discounted from t = 0)
            for path in paths:
                n_samples = len(path["reward"])
                tr = 0.0
                for k in range(n_samples - 1, -1, -1):
                    cr = path["reward"][k]
                    tr = gamma * tr + cr
                pr = np.ones((n_samples, )) * tr
                q_n.extend(pr.tolist())
        q_n = np.array(q_n)

        # ------------------------------------------------------------------
        # Baseline subtraction: the predicted mean return is subtracted
        # from the Q values
        # ------------------------------------------------------------------
        if opt.nn_baseline:
            baseline_net.eval()
            obs_th = torch.from_numpy(ob_no).float()
            if opt.cuda:
                obs_th = obs_th.cuda()
            b_n, _ = baseline_net(Variable(obs_th))
            b_n = torch.squeeze(b_n, dim=-1)
            b_n = b_n.data.cpu().numpy()
            # Bug fix: the baseline is trained on normalized returns (see
            # below), so rescale its predictions to the statistics of the
            # current q_n before subtracting.
            b_n = b_n * (np.std(q_n) + 1e-8) + np.mean(q_n)
            adv_n = q_n - b_n
        else:
            adv_n = q_n.copy()

        # Normalize the advantages: an empirical variance-reduction trick,
        # may or may not help, just an option to try
        if not (opt.dont_normalize_advantages):
            mu = np.mean(adv_n)
            sigma = np.std(adv_n)
            # epsilon guards against a zero std (constant advantages)
            adv_n = (adv_n - mu) / (sigma + 1e-8)

        # Pytorch tensors
        ob_no_th = torch.from_numpy(ob_no).float()
        ac_na_th = torch.from_numpy(ac_na).float()
        adv_n_th = torch.from_numpy(adv_n).float()

        if opt.cuda:
            ob_no_th = ob_no_th.cuda()
            ac_na_th = ac_na_th.cuda()
            adv_n_th = adv_n_th.cuda()

        if opt.nn_baseline:
            # train the baseline network on normalized returns
            q_mu = np.mean(q_n)
            q_sigma = np.std(q_n)
            q_n_th = torch.from_numpy((q_n - q_mu) / (q_sigma + 1e-8)).float()
            if opt.cuda:
                q_n_th = q_n_th.cuda()

            baseline_net.train()
            baseline_criterion = nn.MSELoss()

            baseline_net.zero_grad()
            pred, _ = baseline_net(Variable(ob_no_th))
            pred = torch.squeeze(pred, dim=-1)
            baseline_loss = baseline_criterion(pred, Variable(q_n_th))
            baseline_loss.backward()
            baseline_optimizer.step()
            baseline_error = baseline_loss.data[0]

        # ------------------------------------------------------------------
        # Policy gradient step
        # ------------------------------------------------------------------
        policy_net.train()
        policy_net.zero_grad()
        if discrete:
            probs, _ = policy_net(Variable(ob_no_th))
            m = torch.distributions.Categorical(probs)
            # REINFORCE loss: -log pi(a|s) * advantage
            loss = -m.log_prob(Variable(ac_na_th)) * Variable(adv_n_th)
            loss = loss.mean()
        else:
            means, sigma = policy_net(Variable(ob_no_th))
            # Squared normalized residual ((a - mu) / sigma)^2, summed over
            # action dims.  (Bug fix: the previous code divided only the
            # action by sigma because of missing parentheses.)
            diff = ((means - Variable(ac_na_th, requires_grad=False)) /
                    (sigma + 1e-8))**2
            loss = torch.mean(torch.sum(diff, -1) * Variable(adv_n_th))

        loss.backward()
        optimizer.step()
        # NOTE(review): .data[0] is the pre-0.4 PyTorch idiom used throughout
        # this file; on >= 0.4 this should be loss.item()
        error = loss.data[0]

        # Log diagnostics
        returns = [path["reward"].sum() for path in paths]
        ep_lengths = [pathlength(path) for path in paths]
        logz.log_tabular("Time", time.time() - start)
        logz.log_tabular("Iteration", itr)
        logz.log_tabular("AverageReturn", np.mean(returns))
        logz.log_tabular("StdReturn", np.std(returns))
        logz.log_tabular("MaxReturn", np.max(returns))
        logz.log_tabular("MinReturn", np.min(returns))
        logz.log_tabular("EpLenMean", np.mean(ep_lengths))
        logz.log_tabular("EpLenStd", np.std(ep_lengths))
        logz.log_tabular("TimestepsThisBatch", timesteps_this_batch)
        logz.log_tabular("TimestepsSoFar", total_timesteps)
        logz.log_tabular("Loss", error)
        if opt.nn_baseline:
            logz.log_tabular("BaselineLoss", baseline_error)
        logz.dump_tabular()
Beispiel #3
0
    def train(self, train_db, val_db, test_db):
        """Full training driver: builds the optimizer and LR scheduler,
        then per epoch runs training, validation, sampling, tabular
        logging, and checkpointing."""
        ##################################################################
        ## Optimizer
        ##################################################################
        trainable_img_params = filter(
            lambda p: p.requires_grad,
            self.net.image_encoder.parameters())
        # embedding and image encoder are fine-tuned with a smaller lr
        raw_optimizer = optim.Adam(
            [
                {'params': self.net.text_encoder.embedding.parameters(),
                 'lr': self.cfg.finetune_lr},
                {'params': trainable_img_params, 'lr': self.cfg.finetune_lr},
                {'params': self.net.text_encoder.rnn.parameters()},
                {'params': self.net.what_decoder.parameters()},
                {'params': self.net.where_decoder.parameters()},
            ],
            lr=self.cfg.lr)
        optimizer = Optimizer(raw_optimizer,
                              max_grad_norm=self.cfg.grad_norm_clipping)
        scheduler = optim.lr_scheduler.StepLR(optimizer.optimizer,
                                              step_size=3, gamma=0.8)
        optimizer.set_scheduler(scheduler)

        ##################################################################
        ## LOG
        ##################################################################
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        ##################################################################
        ## Main loop
        ##################################################################
        start = time()
        for epoch in range(self.cfg.n_epochs):
            ##################################################################
            ## Training
            ##################################################################
            torch.cuda.empty_cache()
            train_pred_loss, train_attn_loss, train_eos_loss, train_accu = \
                self.train_epoch(train_db, optimizer, epoch)

            ##################################################################
            ## Validation
            ##################################################################
            torch.cuda.empty_cache()
            val_loss, val_accu, val_infos = self.validate_epoch(val_db)

            ##################################################################
            ## Sample
            ##################################################################
            torch.cuda.empty_cache()
            self.sample(epoch, test_db, self.cfg.n_samples)
            torch.cuda.empty_cache()

            ##################################################################
            ## Logging
            ##################################################################
            # update optim scheduler
            optimizer.update(np.mean(val_loss), epoch)

            logz.log_tabular("Time", time() - start)
            logz.log_tabular("Iteration", epoch)

            # each quantity is reported as Average/Std/Max/Min, in this order
            stat_fns = (("Average", np.mean), ("Std", np.std),
                        ("Max", np.max), ("Min", np.min))

            for prefix, suffix, values in (
                    ("Train", "Error", train_pred_loss),
                    ("Train", "Accu", train_accu),
                    ("Val", "Error", val_loss),
                    ("Val", "Accu", val_accu)):
                for stat_name, fn in stat_fns:
                    logz.log_tabular(prefix + stat_name + suffix, fn(values))

            # per-attribute accuracy columns of val_accu
            for idx, attr in enumerate(
                    ("Obj", "Pose", "Expr", "Coord", "Scale", "Flip")):
                column = val_accu[:, idx]
                for stat_name, fn in stat_fns:
                    logz.log_tabular("Val" + stat_name + attr + "Accu",
                                     fn(column))

            # scene-level evaluation metrics from the validation infos
            for key, value in (
                    ("ValUnigramF3", np.mean(val_infos.unigram_F3())),
                    ("ValBigramF3", np.mean(val_infos.bigram_F3())),
                    ("ValUnigramP", np.mean(val_infos.unigram_P())),
                    ("ValUnigramR", np.mean(val_infos.unigram_R())),
                    ("ValBigramP", val_infos.mean_bigram_P()),
                    ("ValBigramR", val_infos.mean_bigram_R()),
                    ("ValUnigramPose", np.mean(val_infos.pose())),
                    ("ValUnigramExpr", np.mean(val_infos.expr())),
                    ("ValUnigramScale", np.mean(val_infos.scale())),
                    ("ValUnigramFlip", np.mean(val_infos.flip())),
                    ("ValUnigramSim", np.mean(val_infos.unigram_coord())),
                    ("ValBigramSim", val_infos.mean_bigram_coord())):
                logz.log_tabular(key, value)

            logz.dump_tabular()

            ##################################################################
            ## Checkpoint
            ##################################################################
            log_info = [np.mean(val_loss), np.mean(val_accu)]
            self.save_checkpoint(epoch, log_info)
            torch.cuda.empty_cache()
Beispiel #4
0
    def test(self, test_db):
        """Evaluate the network on ``test_db``.

        Runs a no-grad forward pass over the whole test set, accumulates
        per-batch losses and image/text features, dumps the image features
        to disk, evaluates retrieval metrics via ``self.net.evaluate``,
        writes metrics/caches to the model dir, and renders a figure.

        Returns:
            (losses, metrics, caches_results): per-batch loss array, the
            metrics dict from ``self.net.evaluate``, and its cached results.
        """
        ##################################################################
        ## LOG
        ##################################################################
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)
        start = time()
        test_loaddb = region_loader(test_db)
        test_loader = DataLoader(test_loaddb,
                                 batch_size=self.cfg.batch_size,
                                 shuffle=False,
                                 num_workers=self.cfg.num_workers,
                                 collate_fn=region_collate_fn)

        # RL fine-tuning forces greedy sampling (mode 0); otherwise use the
        # configured exploration mode
        sample_mode = 0 if self.cfg.rl_finetune > 0 else self.cfg.explore_mode
        all_txt_feats, all_img_feats, all_img_masks, losses = [], [], [], []
        self.net.eval()
        for cnt, batched in enumerate(test_loader):
            ##################################################################
            ## Batched data
            ##################################################################
            scene_inds, sent_inds, sent_msks, region_feats, region_masks, region_clses = self.batch_data(
                batched)

            ##################################################################
            ## Inference one step
            ##################################################################
            with torch.no_grad():
                img_feats, masked_feats, txt_feats, subspace_masks, sample_logits, sample_indices = \
                    self.net(scene_inds, sent_inds, sent_msks, None, None, None, region_feats, region_clses, region_masks, sample_mode=sample_mode)
                txt_masks = txt_feats.new_ones(txt_feats.size(0),
                                               txt_feats.size(1))
                batch_losses = self.net.final_loss(img_feats, masked_feats,
                                                   region_masks, txt_feats,
                                                   txt_masks, sample_logits,
                                                   sample_indices)
                # mean over the last dim, then sum into a scalar batch loss
                loss = torch.sum(torch.mean(batch_losses, -1))
            losses.append(loss.cpu().data.item())
            all_txt_feats.append(txt_feats)
            all_img_masks.append(region_masks)
            # with subspace alignment the masked features are the ones that
            # get compared against text; otherwise use the raw image features
            if self.cfg.subspace_alignment_mode > 0:
                all_img_feats.append(masked_feats)
            else:
                all_img_feats.append(img_feats)
            ##################################################################
            ## Print info
            ##################################################################
            if cnt % self.cfg.log_per_steps == 0:
                print('Iter %07d:' % (cnt))
                tmp_losses = np.stack(losses, 0)
                print('mean loss: ', np.mean(tmp_losses))
                print('-------------------------')

        torch.cuda.empty_cache()
        losses = np.array(losses)
        # concatenate per-batch features along the batch dimension
        all_img_feats = torch.cat(all_img_feats, 0)
        all_img_masks = torch.cat(all_img_masks, 0)
        all_txt_feats = torch.cat(all_txt_feats, 0)
        # text masks are all-ones here (computed but unused below)
        all_txt_masks = all_txt_feats.new_ones(all_txt_feats.size(0),
                                               all_txt_feats.size(1))

        # dump the image features (and masks) for later reuse
        all_img_feats_np = all_img_feats.cpu().data.numpy()
        all_img_masks_np = all_img_masks.cpu().data.numpy()
        with open(
                osp.join(self.cfg.model_dir,
                         'img_features_%d.pkl' % self.cfg.n_feature_dim),
                'wb') as fid:
            pickle.dump({
                'feats': all_img_feats_np,
                'masks': all_img_masks_np
            }, fid, pickle.HIGHEST_PROTOCOL)

        ##################################################################
        ## Evaluation
        ##################################################################
        print('Evaluating the per-turn performance, may take a while.')
        metrics, caches_results = self.net.evaluate(all_img_feats,
                                                    all_img_masks,
                                                    all_txt_feats)

        with open(osp.join(self.cfg.model_dir, 'test_metrics.json'),
                  'w') as fp:
            json.dump(metrics, fp, indent=4, sort_keys=True)
        with open(osp.join(self.cfg.model_dir, 'test_caches.pkl'),
                  'wb') as fid:
            pickle.dump(caches_results, fid, pickle.HIGHEST_PROTOCOL)

        visualize(self.cfg.exp_name, metrics,
                  osp.join(self.cfg.model_dir, 'evaluation.png'))

        return losses, metrics, caches_results
Beispiel #5
0
    def train(self, train_db, val_db, test_db):
        """Train the retrieval model.

        Per epoch: run one training pass, validate (computing t2i recall
        metrics), log one tabular row, and save the best checkpoint — by
        minimum validation loss in the plain-supervised setting, by maximum
        mean recall otherwise.

        Args:
            train_db: training dataset wrapped by ``region_loader``.
            val_db: validation dataset wrapped by ``region_loader``.
            test_db: unused here; kept for interface consistency.
        """
        ##################################################################
        ## LOG
        ##################################################################
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        ##################################################################
        ## Main loop
        ##################################################################
        start = time()
        min_val_loss = 1000.0
        max_val_recall = -1.0
        train_loaddb = region_loader(train_db)
        val_loaddb = region_loader(val_db)
        train_loader = DataLoader(train_loaddb,
                                  batch_size=self.cfg.batch_size,
                                  shuffle=True,
                                  num_workers=self.cfg.num_workers,
                                  collate_fn=region_collate_fn)
        val_loader = DataLoader(val_loaddb,
                                batch_size=self.cfg.batch_size,
                                shuffle=False,
                                num_workers=self.cfg.num_workers,
                                collate_fn=region_collate_fn)

        # starting from self.epoch allows resuming from a checkpoint
        for epoch in range(self.epoch, self.cfg.n_epochs):
            ##################################################################
            ## Training
            ##################################################################
            # coco_mode >= 0: pick a random dialogue turn to train on
            if self.cfg.coco_mode >= 0:
                self.cfg.coco_mode = np.random.randint(0, self.cfg.max_turns)
            torch.cuda.empty_cache()
            train_losses = self.train_epoch(train_loaddb, train_loader, epoch)

            ##################################################################
            ## Validation
            ##################################################################
            # always validate on the first turn for comparability
            if self.cfg.coco_mode >= 0:
                self.cfg.coco_mode = 0
            torch.cuda.empty_cache()
            val_losses, val_metrics, caches_results = self.validate_epoch(
                val_loaddb, val_loader, epoch)

            #################################################################
            # Logging
            #################################################################

            # update optim scheduler
            current_val_loss = np.mean(val_losses)
            self.optimizer.update(current_val_loss, epoch)
            logz.log_tabular("Time", time() - start)
            logz.log_tabular("Iteration", epoch)
            logz.log_tabular("TrainAverageLoss", np.mean(train_losses))
            logz.log_tabular("ValAverageLoss", current_val_loss)

            # average the per-key metric vectors:
            # [R@1, R@5, R@10, median rank, mean rank]
            mmm = np.zeros((5, ), dtype=np.float64)
            for k, v in val_metrics.items():
                mmm = mmm + np.array(v)
            mmm /= len(val_metrics)
            logz.log_tabular("t2i_R1", mmm[0])
            logz.log_tabular("t2i_R5", mmm[1])
            logz.log_tabular("t2i_R10", mmm[2])
            logz.log_tabular("t2i_medr", mmm[3])
            logz.log_tabular("t2i_meanr", mmm[4])
            logz.dump_tabular()
            current_val_recall = np.mean(mmm[:3])

            ##################################################################
            ## Checkpoint
            ##################################################################
            def save_best_artifacts():
                # persist the checkpoint plus this epoch's validation
                # metrics and top-5 retrieval caches
                self.save_checkpoint(epoch)
                with open(
                        osp.join(self.cfg.model_dir,
                                 'val_metrics_%d.json' % epoch),
                        'w') as fp:
                    json.dump(val_metrics, fp, indent=4, sort_keys=True)
                with open(
                        osp.join(self.cfg.model_dir,
                                 'val_top5_inds_%d.pkl' % epoch),
                        'wb') as fid:
                    pickle.dump(caches_results, fid,
                                pickle.HIGHEST_PROTOCOL)

            # plain supervised training keeps the lowest-loss model;
            # RL fine-tuning / coco mode keeps the highest-recall model
            if self.cfg.rl_finetune == 0 and self.cfg.coco_mode < 0:
                if min_val_loss > current_val_loss:
                    min_val_loss = current_val_loss
                    save_best_artifacts()
            else:
                if max_val_recall < current_val_recall:
                    max_val_recall = current_val_recall
                    save_best_artifacts()
Beispiel #6
0
    def test(self, test_db):
        """Evaluate the caption-retrieval network on ``test_db``.

        Computes per-batch losses and image/text features over the test
        set (no gradients), dumps the image features to disk, runs the
        retrieval evaluation, writes metrics/caches, and renders a figure.

        Returns:
            (losses, metrics, caches_results).
        """
        # logging setup
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)
        start = time()

        loaddb = caption_loader(test_db)
        loader = DataLoader(loaddb,
                            batch_size=self.cfg.batch_size,
                            shuffle=False,
                            num_workers=self.cfg.num_workers,
                            collate_fn=caption_collate_fn)

        txt_feat_list, img_feat_list, losses = [], [], []
        self.net.eval()
        for step, batch in enumerate(loader):
            # unpack one batch
            images, sent_inds, sent_msks = self.batch_data(batch)

            # forward pass without gradients
            with torch.no_grad():
                img_feats, txt_feats = self.net(sent_inds, sent_msks, None,
                                                images)
                loss = self.net.forward_loss(img_feats, txt_feats)
            losses.append(loss.cpu().data.item())
            txt_feat_list.append(txt_feats)
            img_feat_list.append(img_feats)

            # periodic progress report
            if step % self.cfg.log_per_steps == 0:
                print('Iter %07d:' % (step))
                running = np.stack(losses, 0)
                print('mean loss: ', np.mean(running))
                print('-------------------------')

        torch.cuda.empty_cache()
        losses = np.array(losses)
        all_img_feats = torch.cat(img_feat_list, 0)
        all_txt_feats = torch.cat(txt_feat_list, 0)

        # dump the image features for later reuse
        feats_path = osp.join(self.cfg.model_dir,
                              'img_features_%d.pkl' % self.cfg.n_feature_dim)
        with open(feats_path, 'wb') as fid:
            pickle.dump(all_img_feats.cpu().data.numpy(), fid,
                        pickle.HIGHEST_PROTOCOL)

        # retrieval evaluation over the full feature matrices
        print('Evaluating the per-turn performance, may take a while.')
        metrics, caches_results = self.net.evaluate(all_img_feats,
                                                    all_txt_feats)

        metrics_path = osp.join(self.cfg.model_dir, 'test_metrics.json')
        with open(metrics_path, 'w') as fp:
            json.dump(metrics, fp, indent=4, sort_keys=True)
        caches_path = osp.join(self.cfg.model_dir, 'test_caches.pkl')
        with open(caches_path, 'wb') as fid:
            pickle.dump(caches_results, fid, pickle.HIGHEST_PROTOCOL)

        visualize(self.cfg.exp_name, metrics,
                  osp.join(self.cfg.model_dir, 'evaluation.png'))

        return losses, metrics, caches_results
Beispiel #7
0
    def train(self, train_db, val_db, test_db):
        """Run the full optimization loop: train, validate, log, checkpoint.

        Builds a gradient-clipped Adam optimizer over the model's sub-modules
        (the text-embedding layer gets its own fine-tuning learning rate),
        then iterates over epochs, keeping only the checkpoint with the
        lowest mean validation loss.  ``test_db`` is accepted for interface
        compatibility but is not used here.
        """
        # Unwrap DataParallel so the parameter groups reference the real
        # sub-modules.
        model = self.net.module if (self.cfg.cuda and self.cfg.parallel) else self.net

        # Only optimize image-encoder weights that are not frozen.
        encoder_params = (p for p in model.image_encoder.parameters()
                          if p.requires_grad)
        # Embedding layer uses cfg.finetune_lr; all other groups fall back
        # to the base cfg.lr.
        param_groups = [
            {'params': encoder_params},
            {'params': model.text_encoder.embedding.parameters(),
             'lr': self.cfg.finetune_lr},
            {'params': model.text_encoder.rnn.parameters()},
            {'params': model.what_decoder.parameters()},
            {'params': model.where_decoder.parameters()},
            {'params': model.shape_encoder.parameters()},
        ]
        base_optimizer = optim.Adam(param_groups, lr=self.cfg.lr)
        # Wrapper adds gradient-norm clipping on top of raw Adam.
        optimizer = Optimizer(base_optimizer,
                              max_grad_norm=self.cfg.grad_norm_clipping)

        # Logging setup.
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        t_start = time()
        best_val_loss = 100000000
        for epoch in range(self.cfg.n_epochs):
            # Training pass.
            torch.cuda.empty_cache()
            train_loss = self.train_epoch(train_db, optimizer, epoch)

            # Validation pass.
            torch.cuda.empty_cache()
            val_loss = self.validate_epoch(val_db, epoch)

            # Column 0 is the total loss; columns 1/2 are the embed/attn
            # terms (as produced by train_epoch/validate_epoch).
            current_val_loss = np.mean(val_loss[:, 0])
            logz.log_tabular("Time", time() - t_start)
            logz.log_tabular("Iteration", epoch)
            for prefix, stats in (("Average", train_loss),
                                  ("ValAverage", val_loss)):
                logz.log_tabular(prefix + "Loss", np.mean(stats[:, 0]))
                logz.log_tabular(prefix + "EmbedLoss", np.mean(stats[:, 1]))
                logz.log_tabular(prefix + "AttnLoss", np.mean(stats[:, 2]))
            logz.dump_tabular()

            # Keep only the best-so-far checkpoint.
            if best_val_loss > current_val_loss:
                best_val_loss = current_val_loss
                self.save_best_checkpoint()
                torch.cuda.empty_cache()
Beispiel #8
0
    def train(self, train_db, val_db, test_db):
        """Train the model end-to-end.

        Sets up two optimizers -- a gradient-clipped Adam with a StepLR
        schedule for the vision/decoder modules, and AdamW with a linear
        warmup schedule for the text encoder -- plus length-bucketed
        samplers, then runs the epoch loop (train / validate / optional
        sampling / logging) and writes a checkpoint every epoch.

        Args:
            train_db: training dataset (sized; used to derive the number of
                optimizer steps and by BucketSampler).
            val_db: validation dataset.
            test_db: dataset sampled qualitatively when ``cfg.if_sample``
                is set.
        """
        ##################################################################
        ## Optimizer
        ##################################################################
        # Unwrap DataParallel so the optimizers see the underlying modules.
        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net
        # Only optimize image-encoder weights that are not frozen.
        image_encoder_trainable_paras = \
            filter(lambda p: p.requires_grad, net.image_encoder.parameters())
        # 'initial_lr' is required because the scheduler below is created
        # with last_epoch set (supports resuming mid-training).
        raw_optimizer = optim.Adam([{
            'params': image_encoder_trainable_paras,
            'initial_lr': self.cfg.lr
        }, {
            'params': net.what_decoder.parameters(),
            'initial_lr': self.cfg.lr
        }, {
            'params': net.where_decoder.parameters(),
            'initial_lr': self.cfg.lr
        }],
                                   lr=self.cfg.lr)
        self.optimizer = Optimizer(raw_optimizer,
                                   max_grad_norm=self.cfg.grad_norm_clipping)
        # Decay the LR by 0.8 every 3 epochs, resuming from start_epoch.
        scheduler = optim.lr_scheduler.StepLR(self.optimizer.optimizer,
                                              step_size=3,
                                              gamma=0.8,
                                              last_epoch=self.start_epoch - 1)
        self.optimizer.set_scheduler(scheduler)

        # Separate AdamW + linear-warmup schedule for the text encoder;
        # one optimizer step happens every cfg.accumulation_steps batches.
        num_train_steps = int(
            len(train_db) / self.cfg.accumulation_steps * self.cfg.n_epochs)
        num_warmup_steps = int(num_train_steps * self.cfg.warmup)
        self.bert_optimizer = AdamW([{
            'params': net.text_encoder.parameters(),
            'initial_lr': self.cfg.finetune_lr
        }],
                                    lr=self.cfg.finetune_lr)
        self.bert_scheduler = get_linear_schedule_with_warmup(
            self.bert_optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_train_steps,
            last_epoch=self.start_epoch - 1)

        # Bucket sequences of similar length together to reduce padding.
        bucket_boundaries = [4, 8, 12, 16, 22]
        print('preparing training bucket sampler')
        self.train_bucket_sampler = BucketSampler(
            train_db, bucket_boundaries, batch_size=self.cfg.batch_size)
        print('preparing validation bucket sampler')
        self.val_bucket_sampler = BucketSampler(val_db,
                                                bucket_boundaries,
                                                batch_size=4)

        ##################################################################
        ## LOG
        ##################################################################
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        ##################################################################
        ## Main loop
        ##################################################################
        start = time()

        for epoch in range(self.start_epoch, self.cfg.n_epochs):
            ##################################################################
            ## Training
            ##################################################################
            print('Training...')
            torch.cuda.empty_cache()
            train_pred_loss, train_attn_loss, train_eos_loss, train_accu, train_mse = \
                self.train_epoch(train_db, self.optimizer, epoch)

            ##################################################################
            ## Validation
            ##################################################################
            print('Validation...')
            val_loss, val_accu, val_mse, val_infos = self.validate_epoch(
                val_db)

            ##################################################################
            ## Sample
            ##################################################################
            if self.cfg.if_sample:
                print('Sample...')
                torch.cuda.empty_cache()
                self.sample(epoch, test_db, self.cfg.n_samples)
                torch.cuda.empty_cache()
            ##################################################################
            ## Logging
            ##################################################################

            # Step the (plateau-style) optimizer wrapper on validation loss.
            print('Logging...')
            self.optimizer.update(np.mean(val_loss), epoch)

            logz.log_tabular("Time", time() - start)
            logz.log_tabular("Iteration", epoch)

            # Accuracy columns: 0=object, 1=coord, 2=scale, 3=ratio;
            # MSE columns: 0=x, 1=y, 2=w, 3=h.
            logz.log_tabular("TrainAverageError", np.mean(train_pred_loss))
            logz.log_tabular("TrainAverageAccu", np.mean(train_accu))
            logz.log_tabular("TrainAverageMse", np.mean(train_mse))
            logz.log_tabular("ValAverageError", np.mean(val_loss))
            logz.log_tabular("ValAverageAccu", np.mean(val_accu))
            logz.log_tabular("ValAverageObjAccu", np.mean(val_accu[:, 0]))
            logz.log_tabular("ValAverageCoordAccu", np.mean(val_accu[:, 1]))
            logz.log_tabular("ValAverageScaleAccu", np.mean(val_accu[:, 2]))
            logz.log_tabular("ValAverageRatioAccu", np.mean(val_accu[:, 3]))
            logz.log_tabular("ValAverageMse", np.mean(val_mse))
            logz.log_tabular("ValAverageXMse", np.mean(val_mse[:, 0]))
            logz.log_tabular("ValAverageYMse", np.mean(val_mse[:, 1]))
            logz.log_tabular("ValAverageWMse", np.mean(val_mse[:, 2]))
            logz.log_tabular("ValAverageHMse", np.mean(val_mse[:, 3]))
            logz.log_tabular("ValUnigramF3", np.mean(val_infos.unigram_F3()))
            logz.log_tabular("ValBigramF3", np.mean(val_infos.bigram_F3()))
            logz.log_tabular("ValUnigramP", np.mean(val_infos.unigram_P()))
            logz.log_tabular("ValUnigramR", np.mean(val_infos.unigram_R()))
            logz.log_tabular("ValBigramP", val_infos.mean_bigram_P())
            logz.log_tabular("ValBigramR", val_infos.mean_bigram_R())
            logz.log_tabular("ValUnigramScale", np.mean(val_infos.scale()))
            logz.log_tabular("ValUnigramRatio", np.mean(val_infos.ratio()))
            logz.log_tabular("ValUnigramSim",
                             np.mean(val_infos.unigram_coord()))
            logz.log_tabular("ValBigramSim", val_infos.mean_bigram_coord())

            logz.dump_tabular()

            ##################################################################
            ## Checkpoint
            ##################################################################
            print('Saving checkpoint...')
            log_info = [np.mean(val_loss), np.mean(val_accu)]
            self.save_checkpoint(epoch, log_info)
            torch.cuda.empty_cache()
Beispiel #9
0
    def train(self, train_db, val_db):
        """Fine-tune the classification network on ``train_db``.

        Builds weight-decay parameter groups, an AdamW optimizer with a
        linear warmup schedule, optional fp16 (apex), then trains with
        gradient accumulation.  Validation accuracy from ``update_log``
        gates best-checkpoint saving.

        Args:
            train_db: training dataset fed to a shuffling DataLoader.
            val_db: validation dataset, passed through to ``update_log``.
        """
        # TensorBoard writer: explicit dir if configured, otherwise a
        # task/model-specific default under the log dir.
        if self.cfg.tensorboard_logdir is not None:
            summary_writer = SummaryWriter(self.cfg.tensorboard_logdir)
        else:
            summary_writer = SummaryWriter(
                osp.join(self.cfg.log_dir, self.cfg.task, 'tensorboard',
                         self.cfg.model_name))

        log_dir = osp.join(self.cfg.log_dir, self.cfg.task,
                           self.cfg.model_name)
        if not osp.exists(log_dir):
            os.makedirs(log_dir)

        # Snapshot the driving scripts next to the logs for reproducibility.
        code_dir = osp.join(log_dir, 'code')
        if not osp.exists(code_dir):
            os.makedirs(code_dir)

        shutil.copy('./train.py', osp.join(code_dir, 'train.py'))
        shutil.copy('./commonsense_dataset.py',
                    osp.join(code_dir, 'commonsense_dataset.py'))

        logz.configure_output_dir(log_dir)
        logz.save_config(self.cfg)

        train_loader = DataLoader(train_db,
                                  batch_size=self.cfg.batch_size,
                                  shuffle=True,
                                  num_workers=self.cfg.num_workers)

        # One optimizer step happens every cfg.accumulation_steps batches.
        num_train_steps = int(
            len(train_loader) / self.cfg.accumulation_steps * self.cfg.epochs)
        num_warmup_steps = int(num_train_steps * self.cfg.warmup)

        # Standard BERT recipe: no weight decay on biases/LayerNorm weights.
        no_decay = ['bias', 'LayerNorm.weight']
        not_optim = []

        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.net.named_parameters()
                if (not any(nd in n for nd in no_decay)) and (not any(
                    nd in n for nd in not_optim))
            ],
            'weight_decay':
            self.cfg.weight_decay
        }, {
            'params': [
                p for n, p in self.net.named_parameters()
                if (any(nd in n
                        for nd in no_decay)) and (not any(nd in n
                                                          for nd in not_optim))
            ],
            'weight_decay':
            0.0
        }]

        if self.cfg.fix_emb:
            # Freeze the input embeddings.
            for p in self.net.embedding.embeddings.parameters():
                p.requires_grad = False

        if self.cfg.ft_last_layer:
            # Fine-tune only the top of the encoder: freeze the embeddings
            # and the first 10 transformer layers.
            for p in self.net.embedding.embeddings.parameters():
                p.requires_grad = False
            for i in range(10):
                # BUGFIX: an nn.Module is not iterable -- the original
                # `for p in ...layer[i]:` raised TypeError.  Iterate the
                # layer's parameters instead.
                for p in self.net.embedding.encoder.layer[i].parameters():
                    p.requires_grad = False

        # NOTE(review): eval() on a config string -- acceptable only for
        # trusted config files.
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=self.cfg.lr,
                               eps=self.cfg.adam_eps,
                               betas=eval(self.cfg.adam_betas))

        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=num_warmup_steps,
                                              t_total=num_train_steps)
        loss_func = nn.CrossEntropyLoss()

        if self.cfg.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            self.net, self.optimizer = amp.initialize(
                self.net, self.optimizer, opt_level=self.cfg.fp16_opt_level)

        torch.cuda.synchronize()
        self.start = time.time()
        self.net.zero_grad()
        self.batch_loss, self.batch_acc = [], []
        self.global_step = 0
        # Robustness: make sure `val_acc` is bound even if the train loader
        # yields no batches (the end-of-epoch checkpoint reads it).
        val_acc = self.best
        for epoch in range(self.start_epoch, self.cfg.epochs):

            print('Training...')
            torch.cuda.empty_cache()
            self.batch_loss, self.batch_acc = [], []
            for cnt, batch in tqdm(enumerate(train_loader)):
                self.net.train()

                input_ids, input_mask, segment_ids, features, fea_mask, labels, input_indexs = self.batch_data(
                    batch)
                logits, probabilities, preds = self.net(
                    input_ids, input_mask, segment_ids, features, fea_mask)
                loss = loss_func(logits, labels).mean()

                # Scale the loss so accumulated gradients average correctly.
                if self.cfg.accumulation_steps > 1:
                    loss = loss / self.cfg.accumulation_steps

                if self.cfg.fp16:
                    # Backward through amp's loss scaler; clip the master
                    # (fp32) gradients.
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                    if self.cfg.max_grad_norm > 0.0:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            self.cfg.max_grad_norm)
                else:
                    loss.backward()
                    if self.cfg.max_grad_norm > 0.0:
                        nn.utils.clip_grad_norm_(self.net.parameters(),
                                                 self.cfg.max_grad_norm)

                acc, _, _, _ = self.evaluate(preds, labels, input_indexs)

                self.batch_loss.append(loss.cpu().data.item() / len(input_ids))
                self.batch_acc.append(acc)

                # Log the untrained model once at the very start.
                if self.global_step == 0 and cnt == 0:
                    _ = self.update_log(summary_writer, epoch, val_db)

                # Optimizer step once per accumulation window.
                if ((cnt + 1) % self.cfg.accumulation_steps) == 0:
                    self.optimizer.step()
                    self.scheduler.step()
                    self.net.zero_grad()
                    self.global_step += 1

                    if self.global_step % self.cfg.log_per_steps == 0:
                        val_acc = self.update_log(summary_writer, epoch,
                                                  val_db)
                        self.batch_loss, self.batch_acc = [], []

                        # Save only improving checkpoints, and only after a
                        # quarter of the epochs have elapsed.
                        if self.cfg.save_ckpt:
                            if epoch >= (self.cfg.epochs / 4):
                                if self.best < val_acc:
                                    print('Saving checkpoint...')
                                    self.save_checkpoint(epoch, acc=val_acc)
                                    self.best = val_acc

            ##################################################################
            ## Checkpoint
            ##################################################################
            # Flush any leftover partial-window statistics.
            if len(self.batch_loss) > 0:
                val_acc = self.update_log(summary_writer, epoch, val_db)
                self.best = max(self.best, val_acc)
                self.batch_loss, self.batch_acc = [], []

            if self.cfg.save_ckpt:
                if epoch >= (self.cfg.epochs / 4):
                    print('Saving checkpoint...')
                    self.save_checkpoint(epoch, True, acc=val_acc)
            torch.cuda.empty_cache()

        summary_writer.close()
    def train(self, train_db, val_db, test_db):
        """Epoch loop with optional hard-negative mining.

        When ``cfg.use_hard_mining`` is set, nearest-neighbour tables over
        all categories are built up front and refreshed every three epochs
        from freshly dumped shape vectors.  A checkpoint is written every
        epoch; ``test_db`` is accepted for interface compatibility but is
        not used here.
        """
        # Logging setup.
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        # NN tables used to mine hard negatives.
        if self.cfg.use_hard_mining:
            self.train_tables = AllCategoriesTables(train_db)
            self.val_tables = AllCategoriesTables(val_db)
            self.train_tables.build_nntables_for_all_categories(True)
            self.val_tables.build_nntables_for_all_categories(True)

        t_start = time()
        lowest_val_loss = 100000000  # tracked for parity; not consulted here
        for epoch in range(self.epoch, self.cfg.n_epochs):
            # Training pass.
            torch.cuda.empty_cache()
            train_loss, train_accu = self.train_epoch(train_db, epoch)

            # Validation pass.
            torch.cuda.empty_cache()
            val_loss, val_accu = self.validate_epoch(val_db, epoch)

            # Mean of the total-loss column (scheduler hook is disabled).
            mean_val_total = np.mean(val_loss[:, 0])

            logz.log_tabular("Time", time() - t_start)
            logz.log_tabular("Iteration", epoch)
            # Loss columns 0-3 and accuracy columns 0-3, train then val.
            loss_tags = ("Loss", "PredLoss", "EmbedLoss", "AttnLoss")
            accu_tags = ("ObjAccu", "CoordAccu", "ScaleAccu", "RatioAccu")
            for prefix, losses, accus in (("Average", train_loss, train_accu),
                                          ("ValAverage", val_loss, val_accu)):
                for col, tag in enumerate(loss_tags):
                    logz.log_tabular(prefix + tag, np.mean(losses[:, col]))
                for col, tag in enumerate(accu_tags):
                    logz.log_tabular(prefix + tag, np.mean(accus[:, col]))
            logz.dump_tabular()

            # Every three epochs, re-dump shape vectors and rebuild the NN
            # tables so mined negatives track the current model.
            if self.cfg.use_hard_mining:
                if (epoch + 1) % 3 == 0:
                    torch.cuda.empty_cache()
                    t0 = time()
                    self.dump_shape_vectors(train_db)
                    torch.cuda.empty_cache()
                    self.dump_shape_vectors(val_db)
                    print("Dump shape vectors completes (time %.2fs)" %
                          (time() - t0))
                    torch.cuda.empty_cache()
                    t0 = time()
                    self.train_tables.build_nntables_for_all_categories(False)
                    self.val_tables.build_nntables_for_all_categories(False)
                    print("NN completes (time %.2fs)" % (time() - t0))
            self.save_checkpoint(epoch)
Beispiel #11
0
    def train(self, train_db, val_db, test_db):
        """Epoch loop: train, validate, render samples, log, checkpoint.

        A checkpoint is written every epoch; the running minimum of the
        validation loss is tracked but does not gate saving.
        """
        # Logging setup.
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        t_start = time()
        best_val = 100000000

        for epoch in range(self.epoch, self.cfg.n_epochs):
            # Training pass.
            torch.cuda.empty_cache()
            train_loss = self.train_epoch(train_db, epoch)

            # Validation pass.
            torch.cuda.empty_cache()
            val_loss = self.validate_epoch(val_db, epoch)

            # Render a few qualitative samples from the test split.
            torch.cuda.empty_cache()
            self.sample_for_vis(epoch, test_db, self.cfg.n_samples)
            torch.cuda.empty_cache()

            # Overall mean of every validation-loss column.
            epoch_val_loss = np.mean(val_loss)

            logz.log_tabular("Time", time() - t_start)
            logz.log_tabular("Iteration", epoch)
            # Columns: 0 total, 1 pred, 2 image, 3-7 per-feature errors.
            tags = ["TotalError", "PredError", "ImageError"] + \
                   ["Feat%dError" % k for k in range(5)]
            for prefix, losses in (("Average", train_loss),
                                   ("ValAverage", val_loss)):
                for col, tag in enumerate(tags):
                    logz.log_tabular(prefix + tag, np.mean(losses[:, col]))
            logz.dump_tabular()

            # Track (but do not act on) the best validation loss seen.
            if best_val > epoch_val_loss:
                best_val = epoch_val_loss
            self.save_checkpoint(epoch)
            torch.cuda.empty_cache()
Beispiel #12
0
    def train(self, train_db, val_db):
        ##################################################################
        ## Optimizer
        ##################################################################
        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net
        optimizer = optim.Adam([{
            'params': net.encoder.parameters()
        }, {
            'params': net.decoder.parameters()
        }],
                               lr=self.cfg.lr)

        ##################################################################
        ## LOG
        ##################################################################
        logz.configure_output_dir(self.cfg.model_dir)
        logz.save_config(self.cfg)

        ##################################################################
        ## Main loop
        ##################################################################
        start = time()
        min_val_loss = 100000000

        for epoch in range(self.cfg.n_epochs):
            ##################################################################
            ## Training
            ##################################################################
            torch.cuda.empty_cache()
            train_loss = self.train_epoch(train_db, optimizer, epoch)

            ##################################################################
            ## Validation
            ##################################################################
            torch.cuda.empty_cache()
            val_loss = self.validate_epoch(val_db, epoch)
            # val_loss = train_loss

            ##################################################################
            ## Sample
            ##################################################################
            torch.cuda.empty_cache()
            self.sample(epoch, val_db, self.cfg.n_samples)
            torch.cuda.empty_cache()
            ##################################################################
            ## Logging
            ##################################################################

            # update optim scheduler
            current_val_loss = np.mean(val_loss)

            logz.log_tabular("Time", time() - start)
            logz.log_tabular("Iteration", epoch)
            logz.log_tabular("AverageTotalError", np.mean(train_loss[:, 0]))
            logz.log_tabular("AveragePredError", np.mean(train_loss[:, 1]))
            logz.log_tabular("AverageImageError", np.mean(train_loss[:, 2]))
            logz.log_tabular("AverageFeat0Error", np.mean(train_loss[:, 3]))
            logz.log_tabular("AverageFeat1Error", np.mean(train_loss[:, 4]))
            logz.log_tabular("AverageFeat2Error", np.mean(train_loss[:, 5]))
            logz.log_tabular("AverageFeat3Error", np.mean(train_loss[:, 6]))
            logz.log_tabular("AverageFeat4Error", np.mean(train_loss[:, 7]))
            logz.log_tabular("ValAverageTotalError", np.mean(val_loss[:, 0]))
            logz.log_tabular("ValAveragePredError", np.mean(val_loss[:, 1]))
            logz.log_tabular("ValAverageImageError", np.mean(val_loss[:, 2]))
            logz.log_tabular("ValAverageFeat0Error", np.mean(val_loss[:, 3]))
            logz.log_tabular("ValAverageFeat1Error", np.mean(val_loss[:, 4]))
            logz.log_tabular("ValAverageFeat2Error", np.mean(val_loss[:, 5]))
            logz.log_tabular("ValAverageFeat3Error", np.mean(val_loss[:, 6]))
            logz.log_tabular("ValAverageFeat4Error", np.mean(val_loss[:, 7]))
            logz.dump_tabular()

            ##################################################################
            ## Checkpoint
            ##################################################################
            if min_val_loss > current_val_loss:
                min_val_loss = current_val_loss
                self.save_checkpoint()
                torch.cuda.empty_cache()