class PoemImageEmbedTrainer:
    def __init__(self, train_data, test_data, sentiment_model, batchsize, load_model, device):
        self.device = device
        self.train_data = train_data
        self.test_data = test_data
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])

        self.test_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

        img_dir = 'data/image'
        self.train_set = PoemImageEmbedDataset(self.train_data, img_dir,
                                               tokenizer=self.tokenizer, max_seq_len=100,
                                               transform=self.train_transform)
        self.train_loader = DataLoader(self.train_set, batch_size=batchsize, shuffle=True, num_workers=4)

        self.test_set = PoemImageEmbedDataset(self.test_data, img_dir,
                                              tokenizer=self.tokenizer, max_seq_len=100,
                                              transform=self.test_transform)
        self.test_loader = DataLoader(self.test_set, batch_size=batchsize, num_workers=4)

        self.model = PoemImageEmbedModel(device)

        self.model = DataParallel(self.model)
        load_dataparallel(self.model.module.img_embedder.sentiment_feature, sentiment_model)
        if load_model:
            logger.info('load model from '+ load_model)
            self.model.load_state_dict(torch.load(load_model))
        self.model.to(device)
        self.optimizer = optim.Adam(list(self.model.module.poem_embedder.linear.parameters()) + \
                                    list(self.model.module.img_embedder.linear.parameters()), lr=1e-4)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[2, 4, 6], gamma=0.33)

    def train_epoch(self, epoch, log_interval, save_interval, ckpt_file):
        self.model.train()
        running_ls = 0
        acc_ls = 0
        start = time.time()
        num_batches = len(self.train_loader)
        for i, batch in enumerate(self.train_loader):
            img1, ids1, mask1, img2, ids2, mask2 = [t.to(self.device) for t in batch]
            self.model.zero_grad()
            # under DataParallel the model returns one loss per GPU, so
            # backward() needs a matching vector of gradient seeds
            loss = self.model(img1, ids1, mask1, img2, ids2, mask2)
            loss.backward(torch.ones_like(loss))
            running_ls += loss.mean().item()
            acc_ls += loss.mean().item()
            self.optimizer.step()

            if (i + 1) % log_interval == 0:
                elapsed_time = time.time() - start
                iters_per_sec = (i + 1) / elapsed_time
                remaining = (num_batches - i - 1) / iters_per_sec
                remaining_time = time.strftime("%H:%M:%S", time.gmtime(remaining))

                print('[{:>2}, {:>4}/{}] running loss:{:.4} acc loss:{:.4} {:.3}iters/s {} left'.format(
                    epoch, (i + 1), num_batches, running_ls / log_interval, acc_ls /(i+1),
                    iters_per_sec, remaining_time))
                running_ls = 0

            if (i + 1) % save_interval == 0:
                self.save_model(ckpt_file)

    def save_model(self, file):
        torch.save(self.model.state_dict(), file)
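# A minimal sketch of a driver for this trainer (not part of the original
# source); the sentiment checkpoint path, epoch count, and checkpoint file
# names below are assumptions for illustration.
def run_poem_image_training(train_data, test_data, num_epochs=8):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trainer = PoemImageEmbedTrainer(train_data, test_data,
                                    sentiment_model='saved_model/sentiment.pth',
                                    batchsize=32, load_model=None, device=device)
    for epoch in range(num_epochs):
        trainer.train_epoch(epoch, log_interval=100, save_interval=1000,
                            ckpt_file='saved_model/embedder_ckpt.pth')
        trainer.scheduler.step()
    trainer.save_model('saved_model/embedder.pth')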
Example #2
class Trainer:
    def _init_dataset(self):
        if self.cfg['dataset_type'] == 'pose_feats':
            train_set = RelPoseFeatsDataset(self.cfg['dataset_root'],
                                            self.cfg['dataset_extract_name'],
                                            self.cfg['dataset_match_name'],
                                            self.cfg['train_pair_info_fn'],
                                            self.cfg['epipolar_inlier_thresh'],
                                            self.cfg['use_eig'],
                                            self.cfg['dataset_eig_name'],
                                            self.cfg['use_feats'],
                                            is_train=True)
            val_set = RelPoseFeatsDataset(self.cfg['dataset_root'],
                                          self.cfg['dataset_extract_name'],
                                          self.cfg['dataset_match_name'],
                                          self.cfg['val_pair_info_fn'],
                                          self.cfg['epipolar_inlier_thresh'],
                                          self.cfg['use_eig'],
                                          self.cfg['dataset_eig_name'],
                                          self.cfg['use_feats'],
                                          is_train=False)
        elif self.cfg['dataset_type'] == 'detrac':
            root_dir = 'data/detrac_train_cache' if 'root_dir' not in self.cfg else self.cfg[
                'root_dir']
            train_set = DETRACTrainDataset(self.cfg['train_pair_info_fn'],
                                           root_dir,
                                           self.cfg['dataset_extract_name'],
                                           self.cfg['dataset_match_name'],
                                           self.cfg['use_eig'],
                                           self.cfg['eig_name'], True, True)
            val_set = DETRACTrainDataset(self.cfg['val_pair_info_fn'],
                                         root_dir,
                                         self.cfg['dataset_extract_name'],
                                         self.cfg['dataset_match_name'],
                                         self.cfg['use_eig'],
                                         self.cfg['eig_name'], False)
        else:
            raise NotImplementedError

        self.train_set = DataLoader(train_set,
                                    self.cfg['batch_size'],
                                    shuffle=True,
                                    num_workers=16,
                                    pin_memory=False,
                                    collate_fn=collate_fn)
        self.val_set = DataLoader(val_set,
                                  self.cfg['batch_size'],
                                  shuffle=False,
                                  num_workers=4,
                                  collate_fn=collate_fn)
        # note: these now hold DataLoaders, so the lengths below are batch counts
        print(f'train set len {len(self.train_set)}')
        print(f'val set len {len(self.val_set)}')

    def _init_network(self):
        self.network = LMCNet(self.cfg).cuda()
        self.optimizer = Adam(self.network.parameters(), lr=1e-3)

        self.val_losses = []
        for loss_name in self.cfg['loss']:
            self.val_losses.append(name2loss[loss_name](self.cfg))
        self.val_metrics = []

        for metric_name in self.cfg['val_metric']:
            if metric_name in name2metric:
                self.val_metrics.append(name2metric[metric_name](self.cfg))
            else:
                self.val_metrics.append(name2loss[metric_name](self.cfg))

        if self.cfg['multi_gpus']:
            # make multi gpu network
            self.train_network = DataParallel(
                MultiGPUWrapper(self.network, self.val_losses))
            self.train_losses = [DummyLoss(self.val_losses)]
        else:
            self.train_network = self.network
            self.train_losses = self.val_losses

        if 'finetune' in self.cfg and self.cfg['finetune']:
            checkpoint = torch.load(self.cfg['finetune_path'])
            self.network.load_state_dict(checkpoint['network_state_dict'])
            print(f'==> resuming from step {self.cfg["finetune_path"]}')
        self.val_evaluator = ValidationEvaluator(self.cfg)

    def __init__(self, cfg):
        self.cfg = cfg
        self.model_dir = os.path.join('data/model', cfg['name'])
        os.makedirs(self.model_dir, exist_ok=True)  # mkdir alone fails when data/model does not exist yet
        self.pth_fn = os.path.join(self.model_dir, 'model.pth')
        self.best_pth_fn = os.path.join(self.model_dir, 'model_best.pth')

    def run(self):
        self._init_dataset()
        self._init_network()
        self._init_logger()

        best_para, start_step = self._load_model()
        train_iter = iter(self.train_set)

        pbar = tqdm(total=self.cfg['total_step'], bar_format='{r_bar}')
        pbar.update(start_step)
        for step in range(start_step, self.cfg['total_step']):
            try:
                train_data = next(train_iter)
            except StopIteration:
                train_iter = iter(self.train_set)
                train_data = next(train_iter)
            if not self.cfg['multi_gpus']:
                train_data = to_cuda(train_data)
            train_data['step'] = step

            self.train_network.train()
            self.network.train()
            reset_learning_rate(self.optimizer, self._get_lr(step))
            self.optimizer.zero_grad()
            self.train_network.zero_grad()

            log_info = {}
            outputs = self.train_network(train_data)
            for loss in self.train_losses:
                loss_results = loss(outputs, train_data, step)
                for k, v in loss_results.items():
                    log_info[k] = v

            loss = 0
            for k, v in log_info.items():
                if k.startswith('loss'):
                    loss = loss + torch.mean(v)

            loss.backward()
            self.optimizer.step()
            if ((step + 1) % self.cfg['train_log_step']) == 0:
                self._log_data(log_info, step + 1, 'train')

            if (step + 1) % self.cfg['val_interval'] == 0:
                val_results, val_para = self.val_evaluator(
                    self.network, self.val_losses + self.val_metrics,
                    self.val_set)
                if val_para > best_para:
                    print(
                        f'New best model {self.cfg["key_metric_name"]}: {val_para:.5f} previous {best_para:.5f}'
                    )
                    best_para = val_para
                    # if self.cfg['save_inter_model'] and (step+1)%self.cfg['save_inter_interval']==0:
                    #     self._save_model(step + 1, best_para, os.path.join(self.model_dir,f'{step+1}.pth'))
                    self._save_model(step + 1, best_para, self.best_pth_fn)
                self._log_data(val_results, step + 1, 'val')

            if (step + 1) % self.cfg['save_interval'] == 0:
                # if self.cfg['save_inter_model'] and (step+1)%10000==0:
                #     self._save_model(step+1,best_para,f'{self.model_dir}/{step}.pth')
                self._save_model(step + 1, best_para)

            pbar.set_postfix(loss=float(loss.detach().cpu().numpy()))
            pbar.update(1)

        pbar.close()

    def _load_model(self):
        best_para, start_step = 0, 0
        if os.path.exists(self.pth_fn):
            checkpoint = torch.load(self.pth_fn)
            best_para = checkpoint['best_para']
            start_step = checkpoint['step']
            self.network.load_state_dict(checkpoint['network_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(f'==> resuming from step {start_step} best para {best_para}')

        return best_para, start_step

    def _save_model(self, step, best_para, save_fn=None):
        save_fn = self.pth_fn if save_fn is None else save_fn
        torch.save(
            {
                'step': step,
                'best_para': best_para,
                'network_state_dict': self.network.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
            }, save_fn)

    def _init_logger(self):
        self.logger = Logger(self.model_dir)

    def _log_data(self, results, step, prefix='train', verbose=False):
        log_results = {}
        for k, v in results.items():
            if isinstance(v, float) or np.isscalar(v):
                log_results[k] = v
            elif isinstance(v, np.ndarray):
                log_results[k] = np.mean(v)
            else:
                log_results[k] = np.mean(v.detach().cpu().numpy())
        self.logger.log(log_results, prefix, step, verbose)

    def _get_lr(self, step):
        if 'lr_type' not in self.cfg or self.cfg['lr_type'] == 'default':
            if step <= self.cfg['lr_mid_epoch']:
                return self.cfg['lr_start']
            else:
                decay_rate = self.cfg['lr_decay_rate']
                decay_step = self.cfg['lr_decay_step']
                decay_num = (step - self.cfg['lr_mid_epoch']) // decay_step
                return max(self.cfg['lr_start'] * decay_rate**decay_num,
                           self.cfg['lr_min'])
        elif self.cfg['lr_type'] == 'warm_up':
            if step <= self.cfg['lr_warm_up_step']:
                return self.cfg['lr_warm_up']
            else:
                decay_rate = self.cfg['lr_decay_rate']
                decay_step = self.cfg['lr_decay_step']
                decay_num = (step - self.cfg['lr_warm_up_step']) // decay_step
                return max(self.cfg['lr_start'] * decay_rate**decay_num,
                           self.cfg['lr_min'])
        else:
            raise NotImplementedError(self.cfg['lr_type'])  # avoid silently returning None
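# The default schedule above holds lr_start flat until lr_mid_epoch, then
# applies stepwise exponential decay with a floor. The same rule as a
# standalone sketch (config values here are assumed for illustration):
def step_decay_lr(step, lr_start=1e-3, lr_min=1e-5,
                  mid_step=10000, decay_step=5000, decay_rate=0.5):
    if step <= mid_step:
        return lr_start
    decay_num = (step - mid_step) // decay_step
    return max(lr_start * decay_rate ** decay_num, lr_min)

# e.g. step_decay_lr(22000) -> max(1e-3 * 0.5**2, 1e-5) = 2.5e-4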
Example #3
for epoch in range(EPOCH):
    D_losses = []
    G_losses = []
    epoch_start_time = time()
    for step, (x, y) in enumerate(train_loader):
        if step == len(train_loader.dataset) // BATCH_SIZE:  # drop the final partial batch
            break

        z = torch.randn((BATCH_SIZE, 100)).view(-1, 100, 1, 1)
        if GPU_MODE:
            x, z = Variable(x.cuda()), Variable(z.cuda())
        else:
            x, z = Variable(x), Variable(z)

        ###### train D #####
        D.zero_grad()

        D_real = D(x).squeeze()  # squeeze discriminator output to shape (N,)
        D_real_loss = BCE_loss(D_real, y_real)

        G_ = G(z)
        D_fake = D(G_).squeeze()  # squeeze to shape (N,)
        D_fake_loss = BCE_loss(D_fake, y_fake)

        # - (log(D(x)) + log(1 - D(G(z))))
        D_train_loss = D_real_loss + D_fake_loss

        D_train_loss.backward()
        D_optimizer.step()
        D_losses.append(D_train_loss.item())  # .data[0] indexing is pre-0.4 PyTorch and now raises an error
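        # The excerpt stops after the discriminator update; a sketch of the
        # generator step that usually follows (G_optimizer and y_real are
        # assumed from the surrounding script, which is not shown here):
        ###### train G #####
        G.zero_grad()
        z = torch.randn((BATCH_SIZE, 100)).view(-1, 100, 1, 1)
        z = Variable(z.cuda()) if GPU_MODE else Variable(z)
        D_fake = D(G(z)).squeeze()
        G_train_loss = BCE_loss(D_fake, y_real)  # non-saturating: -log(D(G(z)))
        G_train_loss.backward()
        G_optimizer.step()
        G_losses.append(G_train_loss.item())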
Example #4
class RecurrentGAN:
    def __init__(self, cfg):
        """A recurrent GAN model, each time step a generated image
        (x'_{t-1}) and the current question q_{t} are fed to the RNN
        to produce the conditioning vector for the GAN.
        The following equations describe this model:

            - c_{t} = RNN(h_{t-1}, q_{t}, x^{~}_{t-1})
            - x^{~}_{t} = G(z | c_{t})
        """
        super(RecurrentGAN, self).__init__()

        self.generator = DataParallel(
            GeneratorFactory.create_instance(cfg)).cuda()
        self.generator_optimizer = OPTIM[cfg.generator_optimizer](
            self.generator.parameters(), cfg.generator_lr, cfg.generator_beta1,
            cfg.generator_beta2, cfg.generator_weight_decay)

        # discriminator
        self.discriminator = DataParallel(
            DiscriminatorFactory.create_instance(cfg)).cuda()
        self.discriminator_optimizer = OPTIM[cfg.discriminator_optimizer](
            self.discriminator.parameters(), cfg.discriminator_lr,
            cfg.discriminator_beta1, cfg.discriminator_beta2,
            cfg.discriminator_weight_decay)

        # word-level instruction encoder
        self.sentence_encoder = nn.DataParallel(SentenceEncoder(cfg)).cuda()
        self.sentence_encoder_optimizer = OPTIM[cfg.gru_optimizer](
            self.sentence_encoder.parameters(), cfg.gru_lr)

        # instruction-level encoder, implemented by GRU.
        self.use_history = cfg.use_history
        if self.use_history:
            self.rnn = nn.DataParallel(nn.GRU(cfg.input_dim,
                                              cfg.hidden_dim,
                                              batch_first=False),
                                       dim=1).cuda()
            self.rnn_optimizer = OPTIM[cfg.rnn_optimizer](
                self.rnn.parameters(), cfg.rnn_lr)

        # layer norm for output of rnn
        self.layer_norm = nn.DataParallel(nn.LayerNorm(cfg.hidden_dim)).cuda()

        # fusion condition, as input of rnn (actually we only use the sentence embedding)
        self.condition_encoder = DataParallel(ConditionEncoder(cfg)).cuda()
        feature_encoding_params = list(self.condition_encoder.parameters())

        # image encoder
        self.use_image_encoder = cfg.use_fg
        if self.use_image_encoder:
            self.image_encoder = DataParallel(
                ImageEncoder(cfg)).cuda()  # image encoder used by the generator
            feature_encoding_params += list(self.image_encoder.parameters())

        self.feature_encoders_optimizer = OPTIM['adam'](
            feature_encoding_params, cfg.feature_encoder_lr)

        # Criterion
        self.criterion = LOSSES[cfg.criterion]()
        self.aux_criterion = DataParallel(
            torch.nn.BCELoss(reduction='none')).cuda()

        self.cfg = cfg
        self.logger = Logger(cfg.log_path, cfg.exp_name)

    def train_batch(self, batch, epoch, iteration, visualizer, logger):
        """
        The training scheme follows the following:
            - Discriminator and Generator is updated every time step.
            - RNN, SentenceEncoder and ImageEncoder parameters are
            updated every sequence
        @args:
            batch:
                image: (N, max_seq_len, C, H, W)
                turn_lengths: (N, max_seq_len)
                dialog_length: (N, )
                turn_word_embedding: (N, max_seq_len, max_sent_len, embed_dim)

        max_seq_len: the length of longest dialog in this batch

        """
        batch_size = len(batch['image'])
        max_seq_len = batch['image'].size(1)

        prev_image = torch.FloatTensor(batch['background'])
        prev_image = prev_image.unsqueeze(0) \
            .repeat(batch_size, 1, 1, 1)
        disc_prev_image = prev_image  # (N, C, H, W)

        # Initial inputs for the RNN set to zeros
        hidden = torch.zeros(1, batch_size,
                             self.cfg.hidden_dim)  # (1, N, hidden_dim)
        prev_objects = torch.zeros(batch_size,
                                   self.cfg.num_objects)  # (N, num_objects)

        teller_images = []
        drawer_images = []
        added_entities = []

        for t in range(max_seq_len):
            image = batch['image'][:, t]  # (N, C, H, W)
            turns_word_embedding = batch[
                'turn_word_embedding'][:, t]  # (N, max_sent_len, embed_dim)
            turns_lengths = batch['turn_lengths'][:, t]  # (batch_size, )
            objects = batch['objects'][:, t]  # (batch_size, )
            seq_ended = t > (batch['dialog_length'] - 1)  # (batch_size, )

            image_feature_map, image_vec = self.image_encoder(prev_image)
            turn_embedding, _ = self.sentence_encoder(turns_word_embedding,
                                                      turns_lengths)
            rnn_condition = self.condition_encoder(turn_embedding, image_vec)

            if self.use_history:
                rnn_condition = rnn_condition.unsqueeze(
                    0
                )  # input vector for condition rnn, (1, batch_size, condition_dim)
                output, hidden = self.rnn(rnn_condition, hidden)

                output = output.squeeze(
                    0)  # (batch_size, condition_output_dim)
                output = self.layer_norm(output)
            else:
                output = rnn_condition
                output = self.layer_norm(output)

            fake_image, mu, logvar, sigma = self._forward_generator(
                batch_size,
                # the instruction encoder is only optimized through the discriminator
                output.detach(),
                image_feature_map)

            visualizer.track_sigma(sigma)

            hamming = objects - prev_objects
            hamming = torch.clamp(hamming, min=0)

            mask = (~seq_ended).to(torch.float32).cuda()  # (1 - bool_tensor) is rejected by newer PyTorch

            d_loss, d_real, d_fake, aux_loss, discriminator_gradient = \
                self._optimize_discriminator(image,
                                             fake_image.detach(),
                                             disc_prev_image,
                                             output,
                                             mask,
                                             hamming,
                                             self.cfg.gp_reg,
                                             self.cfg.aux_reg)

            g_loss, generator_gradient = \
                self._optimize_generator(fake_image,
                                         disc_prev_image.detach(),
                                         output.detach(),
                                         objects,
                                         self.cfg.aux_reg,
                                         mask,
                                         mu,
                                         logvar)

            if self.cfg.teacher_forcing:
                prev_image = image
            else:
                prev_image = fake_image

            disc_prev_image = image
            prev_objects = objects

            if (t + 1) % 2 == 0:
                prev_image = prev_image.detach()

            rnn_grads = []
            gru_grads = []
            condition_encoder_grads = []
            img_encoder_grads = []

            if t == max_seq_len - 1:
                rnn_gradient, gru_gradient, condition_gradient,\
                    img_encoder_gradient = self._optimize_rnn()

                gru_grads.append(gru_gradient.data.cpu().numpy())
                condition_encoder_grads.append(
                    condition_gradient.data.cpu().numpy())

                if self.use_image_encoder:
                    img_encoder_grads.append(
                        img_encoder_gradient.data.cpu().numpy())
                if self.use_history:
                    rnn_grads.append(rnn_gradient.data.cpu().numpy())

                visualizer.track(d_real, d_fake)

            hamming = hamming.data.cpu().numpy()[0]
            teller_images.extend(image[:4].data.cpu().numpy())  # .cpu() before .numpy(), as for fake_image below
            drawer_images.extend(fake_image[:4].data.cpu().numpy())
            entities = str.join(',', list(batch['entities'][hamming > 0]))
            added_entities.append(entities)

        if iteration % self.cfg.vis_rate == 0:
            visualizer.histogram()
            self._plot_losses(visualizer, g_loss, d_loss, aux_loss, iteration)
            rnn_gradient = np.array(rnn_grads).mean()
            gru_gradient = np.array(gru_grads).mean()
            condition_gradient = np.array(condition_encoder_grads).mean()
            img_encoder_gradient = np.array(img_encoder_grads).mean()
            rnn_grads, gru_grads = [], []
            condition_encoder_grads, img_encoder_grads = [], []
            self._plot_gradients(visualizer, rnn_gradient, generator_gradient,
                                 discriminator_gradient, gru_gradient,
                                 condition_gradient, img_encoder_gradient,
                                 iteration)
            self._draw_images(visualizer, teller_images, drawer_images, nrow=4)
            self.logger.write(epoch, iteration, d_real, d_fake, d_loss, g_loss)

            if isinstance(batch['turn'], list):
                batch['turn'] = np.array(batch['turn']).transpose()

            visualizer.write(batch['turn'][0])
            visualizer.write(added_entities, var_name='entities')
            teller_images = []
            drawer_images = []

        if iteration % self.cfg.save_rate == 0:
            path = os.path.join(self.cfg.log_path, self.cfg.exp_name)

            self._save(fake_image[:4], path, epoch, iteration)
            if not self.cfg.debug:
                self.save_model(path, epoch, iteration)

    def _forward_generator(self, batch_size, condition, image_feature_maps):
        # noise = torch.FloatTensor(batch_size,
        #                           self.cfg.noise_dim).normal_(0, 1).cuda()
        noise = torch.FloatTensor(batch_size,
                                  self.cfg.noise_dim).zero_().cuda()
        fake_images, mu, logvar, sigma = self.generator(
            noise, condition, image_feature_maps)

        return fake_images, mu, logvar, sigma

    def _optimize_discriminator(self,
                                real_images,
                                fake_images,
                                prev_image,
                                condition,
                                mask,
                                objects,
                                gp_reg=0,
                                aux_reg=0):
        """Discriminator is updated every step independent of batch_size
        RNN and the generator
        """

        self.discriminator.zero_grad()
        real_images.requires_grad_()

        d_real, aux_real, _ = self.discriminator(real_images, condition,
                                                 prev_image)
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        if self.cfg.wrong_fake_ratio == 0:
            d_wrong = None
        else:
            wrong_images = torch.cat((real_images[1:], real_images[0:1]),
                                     dim=0)
            wrong_prev = torch.cat((prev_image[1:], prev_image[0:1]), dim=0)
            d_wrong, _, _ = self.discriminator(wrong_images, condition,
                                               wrong_prev)

        d_loss, aux_loss = self._discriminator_masked_loss(
            d_real, d_fake, d_wrong, aux_real, aux_fake, objects, aux_reg,
            mask)

        d_loss.backward(retain_graph=True)
        if gp_reg:
            reg = gp_reg * self._masked_gradient_penalty(
                d_real, real_images, mask)
            reg.backward(retain_graph=True)

        grad_norm = _recurrent_gan.get_grad_norm(
            self.discriminator.parameters())
        self.discriminator_optimizer.step()

        d_loss_scalar = d_loss.item()
        d_real_np = d_real.cpu().data.numpy()
        d_fake_np = d_fake.cpu().data.numpy()
        aux_loss_scalar = aux_loss.item() if isinstance(
            aux_loss, torch.Tensor) else aux_loss
        grad_norm_scalar = grad_norm.item()
        del d_loss
        del d_real
        del d_fake
        del aux_loss
        del grad_norm
        gc.collect()

        return d_loss_scalar, d_real_np, d_fake_np, aux_loss_scalar, grad_norm_scalar

    def _optimize_generator(self, fake_images, prev_image, condition, objects,
                            aux_reg, mask, mu, logvar):
        self.generator.zero_grad()
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        g_loss = self._generator_masked_loss(d_fake, aux_fake, objects,
                                             aux_reg, mu, logvar, mask)

        g_loss.backward(retain_graph=True)
        gen_grad_norm = _recurrent_gan.get_grad_norm(
            self.generator.parameters())

        self.generator_optimizer.step()

        g_loss_scalar = g_loss.item()
        gen_grad_norm_scalar = gen_grad_norm.item()

        del g_loss
        del gen_grad_norm
        gc.collect()

        return g_loss_scalar, gen_grad_norm_scalar

    def _optimize_rnn(self):
        if self.use_history:
            torch.nn.utils.clip_grad_norm_(self.rnn.parameters(),
                                           self.cfg.grad_clip)
            rnn_grad_norm = _recurrent_gan.get_grad_norm(self.rnn.parameters())
            self.rnn_optimizer.step()
            self.rnn.zero_grad()
        else:
            rnn_grad_norm = None

        gru_grad_norm = None
        torch.nn.utils.clip_grad_norm_(self.sentence_encoder.parameters(),
                                       self.cfg.grad_clip)
        gru_grad_norm = _recurrent_gan.get_grad_norm(
            self.sentence_encoder.parameters())
        self.sentence_encoder_optimizer.step()
        self.sentence_encoder.zero_grad()

        ce_grad_norm = _recurrent_gan.get_grad_norm(
            self.condition_encoder.parameters())
        self.feature_encoders_optimizer.step()
        self.condition_encoder.zero_grad()

        if self.use_image_encoder:
            ie_grad_norm = _recurrent_gan.get_grad_norm(
                self.image_encoder.parameters())
            self.image_encoder.zero_grad()
        else:
            ie_grad_norm = None
        return rnn_grad_norm, gru_grad_norm, ce_grad_norm, ie_grad_norm

    def _discriminator_masked_loss(self, d_real, d_fake, d_wrong, aux_real,
                                   aux_fake, objects, aux_reg, mask):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding"""
        aux_loss = 0
        sample_loss = self.criterion.discriminator(d_real, d_fake, d_wrong,
                                                   self.cfg.wrong_fake_ratio,
                                                   mask)
        if aux_reg > 0:
            aux_loss = (
                self.aux_criterion(aux_real, objects) +
                self.aux_criterion(aux_fake, objects)) * mask.unsqueeze(1)
            aux_loss = aux_reg * aux_loss.mean()
        d_loss = sample_loss + aux_loss
        return d_loss, aux_loss

    def _generator_masked_loss(self, d_fake, aux_fake, objects, aux_reg, mu,
                               logvar, mask):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding"""
        sample_loss = self.criterion.generator(d_fake * mask)
        if aux_reg > 0:
            aux_loss = aux_reg * (self.aux_criterion(aux_fake, objects) *
                                  mask.unsqueeze(1)).mean()
        else:
            aux_loss = 0
        if mu is not None:
            kl_loss = self.cfg.cond_kl_reg * kl_penalty(mu, logvar, mask)
        else:
            kl_loss = 0

        g_loss = sample_loss + aux_loss + kl_loss
        return g_loss

    def _masked_gradient_penalty(self, d_real, real_images, mask):
        gp_reg = gradient_penalty(d_real, real_images).mean()
        return gp_reg

    # region Helpers
    def _plot_losses(self, visualizer, g_loss, d_loss, aux_loss, iteration):
        _recurrent_gan._plot_losses(self, visualizer, g_loss, d_loss, aux_loss,
                                    iteration)

    def _plot_gradients(self, visualizer, rnn, gen, disc, gru, ce, ie,
                        iteration):
        _recurrent_gan._plot_gradients(self, visualizer, rnn, gen, disc, gru,
                                       ce, ie, iteration)

    def _draw_images(self, visualizer, real, fake, nrow):
        _recurrent_gan.draw_images(self, visualizer, real, fake, nrow)

    def _save(self, fake, path, epoch, iteration):
        _recurrent_gan._save(self, fake, path, epoch, iteration)

    def save_model(self, path, epoch, iteration):
        _recurrent_gan.save_model(self, path, epoch, iteration)

    def load_model(self, snapshot_path):
        _recurrent_gan.load_model(self, snapshot_path)
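# kl_penalty used in _generator_masked_loss above is imported elsewhere in
# the original file. A plausible masked formulation, shown only as a sketch:
def kl_penalty_sketch(mu, logvar, mask):
    # KL(N(mu, sigma^2) || N(0, I)) per sample, zeroed out for ended sequences
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return (kl * mask).mean()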
Example #5
            error_d_real.backward()

            ## push the discriminator to classify fake images as fake
            noises.data.copy_(
                t.randn(CONFIG["BATCH_SIZE"], CONFIG["NOISE_DIM"], 1, 1))
            fake_img = netG(noises).detach()  # generate fakes from noise; detach so G receives no gradient
            output = netD(fake_img)
            error_d_fake = criterion(output, fake_labels)
            error_d_fake.backward()
            optimizer_discriminator.step()

            error_d = error_d_fake + error_d_real

        if ii % 1 == 0:  # always true as written; presumably a configurable update interval
            # train the generator
            netG.zero_grad()
            noises.data.copy_(
                t.randn(CONFIG["BATCH_SIZE"], CONFIG["NOISE_DIM"], 1, 1))
            fake_img = netG(noises)
            output = netD(fake_img)
            error_g = criterion(output, true_labels)
            error_g.backward()
            optimizer_generator.step()

        proBar.show(epoch, error_d.item(), error_g.item())

    # save the model and sample images
    fix_fake_imgs = netG(fix_noises)
    tv.utils.save_image(fix_fake_imgs.data[:64],
                        'outputs/Pytorch_AnimateFace_%03d.png' % epoch,
                        normalize=True)
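    # The excerpt above begins mid-script; the tensors it references are
    # created before the training loop. A sketch of the usual setup (CONFIG
    # keys, netG/netD, and the optimizers are assumed, not shown here):
    #   true_labels = t.ones(CONFIG["BATCH_SIZE"])   # targets for real images
    #   fake_labels = t.zeros(CONFIG["BATCH_SIZE"])  # targets for fake images
    #   noises = t.randn(CONFIG["BATCH_SIZE"], CONFIG["NOISE_DIM"], 1, 1)
    #   fix_noises = t.randn(CONFIG["BATCH_SIZE"], CONFIG["NOISE_DIM"], 1, 1)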
Example #6
class NERTrainer:
    """
    NERTrainer manages training and testing of NER models on different datasets.

    Args:
        config (TrainerConfig): Trainer configuration.
        datasets (Tuple[list, list, list]): Train/Dev/Test set.
    """

    def __init__(self, config: TrainerConfig, datasets: Tuple[list, list, list]):
        super(NERTrainer, self).__init__()
        writer_folder = os.path.join(config.output_folder, "summary")
        if not os.path.isdir(writer_folder):
            os.makedirs(writer_folder)
        self._config = config
        self._tokenizer = __TOKENIZER_MAP__[config.tokenizer](config.tokenizer_folder)
        self._tokenizer.save(config.output_folder)
        self._outadapter = OutAdapter(config.dataset_folder)
        self._outadapter.save(config.output_folder)
        self._trainset = __DATASET_MAP__[config.dataset](datasets[0], config.max_seq_len, self._outadapter, self._tokenizer)
        self._devset = __DATASET_MAP__[config.dataset](datasets[1], config.max_seq_len, self._outadapter, self._tokenizer)
        self._testset = __DATASET_MAP__[config.dataset](datasets[2], config.max_seq_len, self._outadapter, self._tokenizer)
        self._collate_fn = collate_fn(self._tokenizer.pad_id, self._outadapter.pad_id, config.device)
        self._trainloader = DataLoader(self._trainset, config.batch_size, collate_fn=self._collate_fn)
        self._devloader = DataLoader(self._devset, config.batch_size, collate_fn=self._collate_fn)
        self._testloader = DataLoader(self._testset, config.batch_size, collate_fn=self._collate_fn)
        config.hyperparameters["n_tags"] = len(self._outadapter)
        config.hyperparameters["empty_id"] = self._tokenizer.empty_id
        config.hyperparameters.update(self._tokenizer.configs())
        self._model = __MODEL_MAP__[config.model](
            **config.hyperparameters, token_embeddings=self._tokenizer.token_embeddings()
        ).to(config.device)
        if len(config.gpu) > 1:
            self._model = DataParallel(self._model, device_ids=config.gpu)
        self._loss_fn = nn.CrossEntropyLoss()
        self._optimizer = optim.Adam(self._model.parameters(), lr=config.learning_rate)
        self._writer = SummaryWriter(writer_folder)
        formatter = logging.Formatter("%(asctime)s %(message)s", "%Y-%m-%d %H:%M:%S")
        self._logger = logging.getLogger(__name__)
        self._logger.handlers.clear()
        fh = logging.FileHandler(os.path.join(config.output_folder, "log.txt"))
        fh.setFormatter(formatter)
        sh = logging.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        self._logger.addHandler(fh)
        self._logger.addHandler(sh)
        self._logger.setLevel(logging.DEBUG)

    def load_checkpoints(self) -> None:
        """Load trained checkpoints from local disk."""
        checkpoints_path = os.path.join(self._config.output_folder, "model.checkpoints")
        if os.path.isfile(checkpoints_path):
            checkpoints = torch.load(checkpoints_path, map_location=torch.device("cpu"))
            if isinstance(self._model, DataParallel):
                self._model.module.load_state_dict(checkpoints)
            elif isinstance(self._model, nn.Module):
                self._model.load_state_dict(checkpoints)

    def save_checkpoints(self) -> None:
        """Save trained checkpoints into local disk."""
        checkpoints_path = os.path.join(self._config.output_folder, "model.checkpoints")
        if isinstance(self._model, DataParallel):
            checkpoints = self._model.module.state_dict()
        else:
            checkpoints = self._model.state_dict()
        torch.save(checkpoints, checkpoints_path)

    def log(self, content: str, y_value: float = None, x_value: float = None) -> None:
        """
        Record status via logging and TensorBoard. If x_value or y_value is None,
        content is treated as plain text and recorded directly. Otherwise, content
        is treated as a record category under which the (x, y) values are merged.
        Commonly used record categories are F1 score, precision score, etc.

        Args:
            content (str): Logging content.
            y_value (float): Y input.
            x_value (float): X input.
        """
        if y_value is not None and x_value is not None:
            self._writer.add_scalar("{0}/{1}".format(self._config.identity, content), y_value, x_value)
        else:
            self._writer.add_text("{0}/log".format(self._config.identity), content)
            content = "[{0}-{1}] ".format(self._config.model, self._config.dataset) + content
            self._logger.debug(content)

    def test(self, loader: DataLoader = None) -> Dict[str, Any]:
        """
        Return model's performance.

        Args:
            loader (DataLoader): Dataloader to be tested.
        """
        if loader is None:
            loader = self._testloader
        self._model.eval()
        batch_gold_labels, batch_pred_labels = [], []
        for _, batch in enumerate(loader):
            input_ids, output_ids, masks = batch
            preds_ = self._model(input_ids, masks)
            pred_labels = [[self._outadapter[label_id] for label_id in label_ids] for label_ids in preds_.argmax(dim=-1).tolist()]
            batch_pred_labels.extend(pred_labels)
            gold_labels = [[self._outadapter[label_id] for label_id in label_ids] for label_ids in output_ids.tolist()]
            batch_gold_labels.extend(gold_labels)
            del batch, input_ids, preds_

        return evalner(batch_gold_labels, batch_pred_labels)

    def train(self) -> Dict[str, Any]:
        """Start to train model on a given dataset."""
        no_improvement, max_f1_score = 0, -math.inf
        for epoch in range(self._config.epoch):
            self._model.train()
            total_loss = 0.
            for _, batch in enumerate(self._trainloader):
                input_ids, output_ids, masks = batch
                self._model.zero_grad()
                preds_ = self._model(input_ids, masks)
                loss = self._loss_fn(preds_.view(-1, len(self._outadapter)), output_ids.view(-1))
                loss.backward()
                self._optimizer.step()
                total_loss += loss.item()
                del batch, loss, input_ids, output_ids
            train_loss = total_loss / len(self._trainloader)
            evaluations = self.test(self._devloader)
            self.log("f1", evaluations["entity"]["f1"], epoch)
            self.log("p", evaluations["entity"]["p"], epoch)
            self.log("r", evaluations["entity"]["r"], epoch)
            self.log("train_loss", train_loss, epoch)
            self.log("epoch {0} dev-f1: {1}, dev-p: {2}, dev-r: {3}, train-loss: {4}".format(
                epoch, evaluations["entity"]["f1"], evaluations["entity"]["p"], evaluations["entity"]["r"], train_loss
            ))
            if evaluations["entity"]["f1"] > max_f1_score:
                max_f1_score = evaluations["entity"]["f1"]
                no_improvement = 0
                self.save_checkpoints()
            else:
                no_improvement += 1
            if train_loss < self._config.early_stop_loss or no_improvement > self._config.stop_if_no_improvement:
                break
        self.load_checkpoints()
        evaluations = self.test(self._testloader)
        self.log("test-f1: {0}, test-p: {1}, test-r: {2}".format(evaluations["entity"]["f1"], evaluations["entity"]["p"], evaluations["entity"]["r"]))

        return evaluations

    def create_discriminated_examples(self, trainset: List[dict]) -> List[dict]:
        """
        Create new counterfactual examples from observational seed examples using a discriminator.

        Args:
            trainset (List[dict]): A list of observational seed examples.
        """
        self.load_checkpoints()
        self._model.eval()
        all_cfexamples = create_counterfactual_examples(trainset)
        reasonable_cfexamples, unreasonable_cfexamples = [], []
        dataset = __DATASET_MAP__[self._config.dataset](all_cfexamples, self._config.max_seq_len, self._outadapter, self._tokenizer)
        dataloader = DataLoader(dataset, self._config.batch_size, collate_fn=self._collate_fn)
        for i, batch in enumerate(dataloader):
            input_ids, output_ids, masks = batch
            preds_ = self._model(input_ids, masks)
            pred_labels = [[self._outadapter[label_id] for label_id in label_ids] for label_ids in preds_.argmax(dim=-1).tolist()]
            for j, labels in enumerate(pred_labels):
                text = all_cfexamples[i*self._config.batch_size + j]["text"]
                replaced_spans = all_cfexamples[i*self._config.batch_size + j]["replaced"]
                predicted_spans = ["[{0}]({1}, {2})".format(span["label"], span["start"], span["end"]) for span in to_entities(text, labels)]
                if len(set(replaced_spans).intersection(set(predicted_spans))) == len(replaced_spans):
                    reasonable_cfexamples.append(all_cfexamples[i*self._config.batch_size + j])
                else:
                    unreasonable_cfexamples.append(all_cfexamples[i*self._config.batch_size + j])

        return (reasonable_cfexamples, unreasonable_cfexamples)
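# NERTrainer.train() above combines a best-dev-F1 checkpoint with
# patience-based early stopping. The same decision rule in isolation
# (a sketch; the names here are illustrative, not the project's API):
def early_stop_update(dev_f1, best_f1, no_improvement, patience):
    if dev_f1 > best_f1:
        return dev_f1, 0, False  # new best: reset patience, keep training
    no_improvement += 1
    return best_f1, no_improvement, no_improvement > patience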
Example #7
class Trainer:
    def __init__(self, cfg_file):
        self._train_set_ready = False
        self._network_ready = False
        self._init_config(cfg_file)

    def _init_config(self, cfg_file):

        with open(cfg_file, 'r') as f:
            overwrite_cfg = yaml.load(f, Loader=yaml.FullLoader)

        if 'default_config' in overwrite_cfg:
            with open(os.path.join(overwrite_cfg['default_config']), 'r') as f:
                default_train_config = yaml.load(f, Loader=yaml.FullLoader)
                self.config = overwrite_configs(default_train_config,
                                                overwrite_cfg)
        else:
            self.config = overwrite_cfg

        self.name = self.config['name']
        self.recorder = Recorder(
            os.path.join('data', 'record', self.name),
            os.path.join('data', 'record', self.name + '.log'))
        self.model_dir = os.path.join('data', 'model', self.name)
        self.hem_thresh = self.config['hem_thresh_begin']
        self.hem_hit_count = 0

    def _init_train_set(self):
        if self._train_set_ready:
            print('training set is ready')
            return

        print('begin preparing training set...')
        database = CorrespondenceDatabase()
        self.database = database

        train_set = []
        for name in self.config['train_set']:
            train_set += getattr(database, name + "_set")
        self.train_set = CorrespondenceDataset(self.config, train_set)
        self.train_loader = DataLoader(self.train_set,
                                       self.config['batch_size'],
                                       shuffle=True,
                                       num_workers=self.config['worker_num'],
                                       worker_init_fn=worker_init_fn)
        self._train_set_ready = True
        print('training set is ready')

    def _init_network(self):
        if self._network_ready:
            return

        print('begin preparing network...')
        self.network = TrainWrapper(self.config)
        self.extractor = self.network.extractor_wrapper
        self.embedder = self.network.embedder_wrapper
        self.network = DataParallel(self.network).cuda()

        paras = []
        if self.config['train_extractor']: paras += self.extractor.parameters()
        if self.config['train_embedder']: paras += self.embedder.parameters()
        self.optim = Adam(paras, lr=1e-3)

        if self.config['pretrain']:
            self._load_model(self.config['pretrain_model_path'],
                             self.config['pretrain_step'],
                             self.config['pretrain_extractor'],
                             self.config['pretrain_embedder'], False)

        self.step = 0
        self._load_model(self.model_dir, -1, True, True, True)

        self._network_ready = True
        self.transformer = TransformerCV(self.config)
        print('network is ready')

    def _adjust_hem_thresh(self):
        decay_num = (self.step + 1) // self.config['hem_thresh_decay_step']
        old_hem_thresh = self.hem_thresh
        self.hem_thresh = (self.config['hem_thresh_begin'] -
                           decay_num * self.config['hem_thresh_decay_rate'])
        self.hem_thresh = max(self.hem_thresh, self.config['hem_thresh_end'])
        if self.hem_thresh != old_hem_thresh:
            print('hem_thresh adjust from {} to {}'.format(
                old_hem_thresh, self.hem_thresh))

    def _get_warm_up_lr(self):
        if self.step <= 2500:
            lr = 1e-4 * (self.step // 250 + 1)
        elif self.step <= 5000:
            lr = 1e-3
        else:
            # decay linearly from 1e-3 toward the 1e-4 floor
            lr = 1e-3 - (1e-3 - 1e-4) / (15000 // 250) * (
                (self.step - 5000) // 250)
            lr = max(lr, 1e-4)
        return lr

    def _get_finetune_lr(self):
        # decay linearly from 5e-5 toward the 1e-6 floor
        lr = 5e-5 - (5e-5 - 1e-6) / (10000 // 250) * (self.step // 250)
        lr = max(lr, 1e-6)
        return lr

    def train(self):
        self._init_network()
        self._init_train_set()

        batch_begin = time.time()
        for data in self.train_loader:
            lr = getattr(self, '_get_{}_lr'.format(self.config['lr_type']))()
            reset_learning_rate(self.optim, lr)
            self._adjust_hem_thresh()

            loss_info = OrderedDict()
            img_list0, pts_list0, pts0, grid_list0, img_list1, pts_list1, pts1, grid_list1, scale_offset, rotate_offset, H = data
            data_time = time.time() - batch_begin

            results = self.network(img_list0, pts_list0, pts0, grid_list0,
                                   img_list1, pts_list1, pts1, grid_list1,
                                   scale_offset, rotate_offset,
                                   self.hem_thresh, self.config['loss_type'])

            loss = 0.0
            for k, v in results.items():
                v = torch.mean(v)
                if k.endswith('loss'): loss = loss + v
                loss_info[k] = v.cpu().detach().numpy()

            self.network.zero_grad()
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            batch_time = time.time() - batch_begin
            # record
            loss_info['data_time'] = data_time
            loss_info['batch_time'] = batch_time
            loss_info['lr'] = lr
            loss_info['hem_thresh'] = self.hem_thresh
            total_step = self.step
            self.recorder.rec_loss(loss_info, total_step, total_step, 'train',
                                   ((total_step + 1) %
                                    self.config['info_step']) == 0)
            print("step is %d, loss is %f" % (self.step, loss))
            # save model
            if (total_step + 1) % self.config['save_step'] == 0:
                self._save_model()
                print('model saved!')

            batch_begin = time.time()
            self.step += 1
            if self.step > self.config['train_step']: break

        self._save_model()
        print('model saved!')

    def _save_model(self):
        os.makedirs(self.model_dir, exist_ok=True)
        state_dict = {
            'extractor': self.extractor.state_dict(),
            'optim': self.optim.state_dict(),
            'step': self.step
        }
        if self.embedder is not None:
            state_dict['embedder'] = self.embedder.state_dict()
        torch.save(state_dict,
                   os.path.join(self.model_dir, '{}.pth'.format(self.step)))

    def _load_model(self,
                    model_dir,
                    step=-1,
                    load_extractor=True,
                    load_embedder=False,
                    load_optimizer=True):
        if not os.path.exists(model_dir):
            return 0

        pths = [int(pth.split('.')[0]) for pth in os.listdir(model_dir)]
        if len(pths) == 0:
            return 0
        if step == -1:
            pth = max(pths)
        else:
            pth = step

        pretrained_model = torch.load(
            os.path.join(model_dir, '{}.pth'.format(pth)))
        if load_extractor and self.extractor is not None:
            state_dict = pretrained_model['extractor']
            self.extractor.load_state_dict(state_dict)
        if load_embedder and self.embedder is not None and 'embedder' in pretrained_model:
            self.embedder.load_state_dict(pretrained_model['embedder'])
        if load_optimizer:
            self.optim.load_state_dict(pretrained_model['optim'])
        print('load {} step {}'.format(model_dir, pretrained_model['step']))
        self.step = pretrained_model['step'] + 1
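# reset_learning_rate is an external helper used by the trainers above; the
# conventional implementation iterates over the optimizer's param_groups
# (a sketch, not necessarily the project's actual code):
def reset_learning_rate(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr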
Example #8
class VisualSentimentTrainer:
    def __init__(self, train_data, test_data, img_dir, batchsize, load_model,
                 device):
        self.device = device
        self.train_data = train_data
        self.test_data = test_data
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])

        self.test_transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])

        self.train_set = VisualSentimentDataset(self.train_data,
                                                img_dir,
                                                transform=self.train_transform)
        self.train_loader = DataLoader(self.train_set,
                                       batch_size=batchsize,
                                       shuffle=True,
                                       num_workers=4)

        self.test_set = VisualSentimentDataset(self.test_data,
                                               img_dir,
                                               transform=self.test_transform)
        self.test_loader = DataLoader(self.test_set,
                                      batch_size=batchsize,
                                      num_workers=4)

        self.model = Res50_sentiment()
        self.model = DataParallel(self.model)
        if load_model:
            logger.info('load model from ' + load_model)
            self.model.load_state_dict(torch.load(load_model))
        self.model.to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=5e-5)
        self.criterion = nn.CrossEntropyLoss()
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                        milestones=[2, 4],
                                                        gamma=0.5)

    def train_epoch(self, epoch, log_interval, save_interval, ckpt_file):
        self.model.train()
        running_ls = 0
        acc_ls = 0
        start = time.time()
        num_batches = len(self.train_loader)
        for i, batch in enumerate(self.train_loader):
            img, label = [t.to(self.device) for t in batch]
            self.model.zero_grad()
            pred = self.model(img)
            loss = self.criterion(pred, label)
            loss.backward()  # CrossEntropyLoss returns a scalar, so no gradient seed is needed
            running_ls += loss.mean().item()
            acc_ls += loss.mean().item()
            self.optimizer.step()

            if (i + 1) % log_interval == 0:
                elapsed_time = time.time() - start
                iters_per_sec = (i + 1) / elapsed_time
                remaining = (num_batches - i - 1) / iters_per_sec
                remaining_fmt = time.strftime("%H:%M:%S",
                                              time.gmtime(remaining))
                elapsed_fmt = time.strftime("%H:%M:%S",
                                            time.gmtime(elapsed_time))

                print(
                    '[{:>2}, {:>4}/{}] running loss:{:.4} acc loss:{:.4} {:.3}iters/s {}<{}'
                    .format(epoch, (i + 1), num_batches,
                            running_ls / log_interval, acc_ls / (i + 1),
                            iters_per_sec, elapsed_fmt, remaining_fmt))
                running_ls = 0

            if (i + 1) % save_interval == 0:
                self.save_model(ckpt_file)

    def test(self):
        self.model.eval()
        batches_count = 0
        data_count = 0
        num_correct = 0
        with torch.no_grad():
            for i, batch in enumerate(tqdm(self.test_loader)):
                batches_count += 1
                img, label = tuple(t.to(self.device) for t in batch)
                data_count += img.shape[0]
                logits = self.model(img).cpu().numpy()
                label = label.cpu().numpy()
                num_correct += np.sum(np.argmax(logits, axis=1) == label)

        accuracy = num_correct / data_count
        print('accuracy: {:.4}%'.format(accuracy * 100))

    def save_model(self, file):
        torch.save(self.model.state_dict(), file)
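# A minimal sketch of an epoch loop for this trainer (the epoch count and
# checkpoint path are assumptions for illustration):
def run_sentiment_training(trainer, num_epochs=6,
                           ckpt='saved_model/sentiment_ckpt.pth'):
    for epoch in range(num_epochs):
        trainer.train_epoch(epoch, log_interval=100, save_interval=1000,
                            ckpt_file=ckpt)
        trainer.test()
        trainer.scheduler.step()  # MultiStepLR decays the lr at epochs 2 and 4
    trainer.save_model(ckpt)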
Example #9
def train_model(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device_ids = [0, 1, 2, 3]
    batch_size = args.batch_size
    input_channels = 1
    out_channels = [args.out_channels1, args.out_channels2]
    kernel_size_cnn = [[args.kernel_size_cnn1, args.kernel_size_cnn2],
                       [args.kernel_size_cnn2, args.kernel_size_cnn1]]
    stride_size_cnn = [[args.stride_size_cnn1, args.stride_size_cnn2],
                       [args.stride_size_cnn2, args.stride_size_cnn1]]
    kernel_size_pool = [[args.kernel_size_pool1, args.kernel_size_pool2],
                        [args.kernel_size_pool2, args.kernel_size_pool1]]
    stride_size_pool = [[args.stride_size_pool1, args.stride_size_pool2],
                        [args.stride_size_pool2, args.stride_size_pool1]]
    hidden_dim = 200
    num_layers = 2
    dropout = 0
    num_labels = 4
    hidden_dim_lstm = 200
    epoch_num = 50
    num_layers_lstm = 2
    nfft = [512, 1024]
    weight = args.weight
    #model = MultiSpectrogramModel(input_channels,out_channels, kernel_size_cnn, stride_size_cnn, kernel_size_pool,
    #stride_size_pool, hidden_dim,num_layers,dropout,num_labels, batch_size,
    #hidden_dim_lstm,num_layers_lstm,device, nfft, weight, False)
    model = resnet18()
    print(
        "============================ Number of parameters ===================================="
    )
    print(str(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    path = "batch_size:{};out_channels:{};kernel_size_cnn:{};stride_size_cnn:{};kernel_size_pool:{};stride_size_pool:{}; weight:{}".format(
        args.batch_size, out_channels, kernel_size_cnn, stride_size_cnn,
        kernel_size_pool, stride_size_pool, weight)
    with open("/scratch/speech/models/classification/resnet_stats.txt",
              "a+") as f:
        f.write("\n" + "============ model starts ===========")
        f.write(
            "\n" + "model_parameters: " +
            str(sum(p.numel()
                    for p in model.parameters() if p.requires_grad)) + "\n" +
            path + "\n")
    model.cuda()
    model = DataParallel(model, device_ids=device_ids)
    model.train()

    # Use Adam as the optimizer with learning rate 1e-3
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    optimizer2 = optim.SGD(model.parameters(), lr=0.1)
    scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                  factor=0.5,
                                  patience=2,
                                  threshold=1e-3)
    #scheduler2=ReduceLROnPlateau(optimizer=optimizer2, factor=0.5, patience=2, threshold=1e-3)
    #scheduler2 =CosineAnnealingLR(optimizer2, T_max=300, eta_min=0.0001)
    scheduler3 = MultiStepLR(optimizer, [5, 10, 15], gamma=0.1)

    # Load the training data
    training_data = IEMOCAP(name='mel', nfft=nfft, train=True)
    train_loader = DataLoader(dataset=training_data,
                              batch_size=batch_size,
                              shuffle=True,
                              collate_fn=my_collate,
                              num_workers=0,
                              drop_last=True)
    testing_data = IEMOCAP(name='mel', nfft=nfft, train=False)
    test_loader = DataLoader(dataset=testing_data,
                             batch_size=batch_size,
                             shuffle=True,
                             collate_fn=my_collate,
                             num_workers=0,
                             drop_last=True)

    #print("=================")
    #print(len(training_data))
    #print("===================")

    test_acc = []
    train_acc = []
    test_loss = []
    train_loss = []
    for epoch in range(epoch_num):
        print("===================================" + str(epoch + 1) +
              "==============================================")
        losses = 0
        correct = 0
        model.train()
        for j, (input_lstm, input1, input2, target,
                seq_length) in enumerate(train_loader):
            if (j + 1) % 20 == 0:
                print("=================================Train Batch" +
                      str(j + 1) +
                      "===================================================")
            model.zero_grad()
            x = model(input1)
            target = target.to(device)
            target_index = torch.argmax(target, dim=1).to(device)
            correct_batch = torch.sum(target_index == torch.argmax(x, dim=1))
            losses_batch = F.cross_entropy(x, torch.max(target, 1)[1])
            correct_batch = torch.unsqueeze(correct_batch, dim=0)
            losses_batch = torch.unsqueeze(losses_batch, dim=0)
            loss = torch.mean(losses_batch, dim=0)
            #print(loss)
            correct_batch = torch.sum(correct_batch, dim=0)
            losses += loss.item() * batch_size
            loss.backward()
            #weight=model.module.state_dict()["weight"]
            #weight=torch.exp(10*weight)/(1+torch.exp(10*weight)).item()
            optimizer.step()
            correct += correct_batch.item()
        accuracy = correct * 1.0 / ((j + 1) * batch_size)
        losses = losses / ((j + 1) * batch_size)
        #scheduler3.step()
        losses_test = 0
        correct_test = 0
        #torch.save(model.module.state_dict(), "/scratch/speech/models/classification/spec_full_joint_checkpoint_epoch_{}.pt".format(epoch+1))
        model.eval()
        with torch.no_grad():
            for j, (input_lstm, input1, input2, target,
                    seq_length) in enumerate(test_loader):
                if (j + 1) % 10 == 0:
                    print(
                        "=================================Test Batch" +
                        str(j + 1) +
                        "===================================================")
                x = model(input1)
                target = target.to(device)
                target_index = torch.argmax(target, dim=1).to(device)
                correct_batch = torch.sum(
                    target_index == torch.argmax(x, dim=1))
                losses_batch = F.cross_entropy(x, torch.max(target, 1)[1])
                correct_batch = torch.unsqueeze(correct_batch, dim=0)
                losses_batch = torch.unsqueeze(losses_batch, dim=0)
                loss = torch.mean(losses_batch, dim=0)
                correct_batch = torch.sum(correct_batch, dim=0)
                losses_test += loss.item() * batch_size
                correct_test += correct_batch.item()

        #print("how many correct:", correct_test)
        accuracy_test = correct_test * 1.0 / ((j + 1) * batch_size)
        losses_test = losses_test / ((j + 1) * batch_size)

        # data gathering
        test_acc.append(accuracy_test)
        train_acc.append(accuracy)
        test_loss.append(losses_test)
        train_loss.append(losses)
        print(
            "Epoch: {}-----------Training Loss: {} -------- Testing Loss: {} -------- Training Acc: {} -------- Testing Acc: {}"
            .format(epoch + 1, losses, losses_test, accuracy, accuracy_test) +
            "\n")
        with open("/scratch/speech/models/classification/resnet_stats.txt",
                  "a+") as f:
            #f.write("Epoch: {}-----------Training Loss: {} -------- Testing Loss: {} -------- Training Acc: {} -------- Testing Acc: {}".format(epoch+1,losses,losses_test, accuracy, accuracy_test)+"\n")
            if epoch == epoch_num - 1:
                f.write("Best Accuracy:{:06.5f}".format(max(test_acc)) + "\n")
                f.write("Average Top 10 Accuracy:{:06.5f}".format(
                    np.mean(np.sort(np.array(test_acc))[-10:])) + "\n")
                f.write("=============== model ends ===================" +
                        "\n")
    print("success:{}, Best Accuracy:{}".format(path, max(test_acc)))
Example #10
            log_likelihood = loss * trg_len  # B x 1
        lls.append(log_likelihood)

    lls = torch.cat(lls, 1).squeeze()
    ppl = torch.exp(lls.sum(1) / lengths)
    return dict(perplexity=ppl.tolist())
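
# --- Hedged sketch (assumption, not the original code) ---
# The fragment above appears to be the tail of a perplexity routine. A
# self-contained version that scores whole texts with GPT-2 (mean NLL per
# token, exponentiated) might look like this; names are hypothetical:
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast


def gpt2_perplexity_sketch(texts, device=torch.device('cpu')):
    tok = GPT2TokenizerFast.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2').to(device).eval()
    ppls = []
    with torch.no_grad():
        for text in texts:
            ids = tok(text, return_tensors='pt').input_ids.to(device)
            # passing labels=ids makes the model return the mean
            # negative log-likelihood per predicted token
            nll = model(ids, labels=ids).loss
            ppls.append(torch.exp(nll).item())
    return dict(perplexity=ppls)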


def perplexity_without_device(x):
    return perplexity(x, device)

with torch.no_grad():
    print(perplexity_without_device(dset_sbert[0:32]))


model.zero_grad(set_to_none=True)
with torch.no_grad():
    dset_sbert = dset_sbert.map(perplexity_without_device, batched=True, batch_size=32)

dset_sbert.save_to_disk("/home/ahemf/processed_datasets/dsets_448_sbert")

#####################################################################################################
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from torch.nn import CrossEntropyLoss, functional as F, DataParallel
from torch import nn
from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict
import torch
import os
import numpy as np
os.environ['TOKENIZERS_PARALLELISM'] = "true"
device = torch.device("cuda:0")
Example #11
class RecurrentGAN_Mingyang():
    def __init__(self, cfg):
        """A recurrent GAN model, each time step a generated image
        (x'_{t-1}) and the current question q_{t} are fed to the RNN
        to produce the conditioning vector for the GAN.
        The following equations describe this model:

            - c_{t} = RNN(h_{t-1}, q_{t}, x^{~}_{t-1})
            - x^{~}_{t} = G(z | c_{t})
        """
        super(RecurrentGAN_Mingyang, self).__init__()

        # region Models-Instantiation

        ###############################Original DataParallel###################
        self.generator = DataParallel(
            GeneratorFactory.create_instance(cfg)).cuda()

        self.discriminator = DataParallel(
            DiscriminatorFactory.create_instance(cfg)).cuda()

        self.rnn = nn.DataParallel(nn.GRU(cfg.input_dim,
                                          cfg.hidden_dim,
                                          batch_first=False),
                                   dim=1).cuda()
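        # dim=1 because the GRU uses batch_first=False: inputs are
        # (seq_len, batch, features), so DataParallel must scatter along dim 1.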
        # self.rnn = DistributedDataParallel(nn.GRU(cfg.input_dim,
        #                                           cfg.hidden_dim,
        # batch_first=False), dim=1).cuda()

        self.layer_norm = nn.DataParallel(nn.LayerNorm(cfg.hidden_dim)).cuda()

        self.image_encoder = DataParallel(ImageEncoder(cfg)).cuda()

        self.condition_encoder = DataParallel(ConditionEncoder(cfg)).cuda()

        self.sentence_encoder = nn.DataParallel(SentenceEncoder(cfg)).cuda()
        #######################################################################
        # self.generator = GeneratorFactory.create_instance(cfg).cuda()

        # self.discriminator = DiscriminatorFactory.create_instance(cfg).cuda()

        # self.rnn = nn.GRU(cfg.input_dim,cfg.hidden_dim,batch_first=False).cuda()
        # # self.rnn = DistributedDataParallel(nn.GRU(cfg.input_dim,
        # #                                           cfg.hidden_dim,
        # # batch_first=False), dim=1).cuda()

        # self.layer_norm = nn.LayerNorm(cfg.hidden_dim).cuda()

        # self.image_encoder = ImageEncoder(cfg).cuda()

        # self.condition_encoder = ConditionEncoder(cfg).cuda()

        # self.sentence_encoder = SentenceEncoder(cfg).cuda()

        # endregion

        # region Optimizers

        self.generator_optimizer = OPTIM[cfg.generator_optimizer](
            self.generator.parameters(), cfg.generator_lr, cfg.generator_beta1,
            cfg.generator_beta2, cfg.generator_weight_decay)

        self.discriminator_optimizer = OPTIM[cfg.discriminator_optimizer](
            self.discriminator.parameters(), cfg.discriminator_lr,
            cfg.discriminator_beta1, cfg.discriminator_beta2,
            cfg.discriminator_weight_decay)

        self.rnn_optimizer = OPTIM[cfg.rnn_optimizer](self.rnn.parameters(),
                                                      cfg.rnn_lr)

        self.sentence_encoder_optimizer = OPTIM[cfg.gru_optimizer](
            self.sentence_encoder.parameters(), cfg.gru_lr)

        self.use_image_encoder = cfg.use_fg
        feature_encoding_params = list(self.condition_encoder.parameters())
        if self.use_image_encoder:
            feature_encoding_params += list(self.image_encoder.parameters())

        self.feature_encoders_optimizer = OPTIM['adam'](
            feature_encoding_params, cfg.feature_encoder_lr)

        # endregion

        # region Criterion

        self.criterion = LOSSES[cfg.criterion]()
        self.aux_criterion = DataParallel(torch.nn.BCELoss()).cuda()
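        # Wrapping the criterion in DataParallel scatters its inputs across
        # GPUs and gathers the per-replica losses back to the output device.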

        #Added by Mingyang for segmentation loss
        if cfg.balanced_seg:
            label_weights = np.array([
                3.02674201e-01, 1.91545454e-03, 2.90009221e-04, 7.50949673e-04,
                1.08670452e-03, 1.11353785e-01, 4.00971053e-04, 1.06240113e-02,
                1.59590824e-01, 5.38960105e-02, 3.36431602e-02, 3.99029734e-02,
                1.88888847e-02, 2.06441476e-03, 6.33775290e-02, 5.81920411e-03,
                3.79528817e-03, 7.87975754e-02, 2.73547355e-03, 1.08308135e-01,
                0.00000000e+00, 8.44408475e-05
            ])
            # Invert the frequencies so that rare classes get larger weights
            label_weights = 1 / label_weights
            label_weights[20] = 0  # class 20 has zero frequency; keep its weight at 0
            label_weights = label_weights / np.min(label_weights[:20])
            #convert numpy to tensor
            label_weights = torch.from_numpy(label_weights)
            label_weights = label_weights.type(torch.FloatTensor)
            self.seg_criterion = DataParallel(
                torch.nn.CrossEntropyLoss(weight=label_weights)).cuda()
        else:
            self.seg_criterion = DataParallel(
                torch.nn.CrossEntropyLoss()).cuda()

        # endregion

        self.cfg = cfg
        self.logger = Logger(cfg.log_path, cfg.exp_name)

        # define unorm
        self.unorm = UnNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

        # define the label distribution

    def train_batch(self,
                    batch,
                    epoch,
                    iteration,
                    visualizer,
                    logger,
                    total_iters=0,
                    current_batch_t=0):
        """
        The training scheme follows the following:
            - Discriminator and Generator is updated every time step.
            - RNN, SentenceEncoder and ImageEncoder parameters are
            updated every sequence
        """
        batch_size = len(batch['image'])
        max_seq_len = batch['image'].size(1)

        prev_image = torch.FloatTensor(
            batch['background']).repeat(batch_size, 1, 1, 1)
        disc_prev_image = prev_image
        # print("disc_prev_image size is: {}".format(disc_prev_image.shape))

        # Initial inputs for the RNN set to zeros
        hidden = torch.zeros(1, batch_size, self.cfg.hidden_dim)
        prev_objects = torch.zeros(batch_size, self.cfg.num_objects)

        teller_images = []
        drawer_images = []
        added_entities = []

        #print("max sequence length of current batch: {}".format(max_seq_len))
        for t in range(max_seq_len):
            image = batch['image'][:, t]
            turns_word_embedding = batch['turn_word_embedding'][:, t]
            turns_lengths = batch['turn_lengths'][:, t]
            objects = batch['objects'][:, t]
            seq_ended = t > (batch['dialog_length'] - 1)

            image_feature_map, image_vec, object_detections = \
                self.image_encoder(prev_image)
            _, current_image_feat, _ = self.image_encoder(image)

            # print("[image_encoder] image_feature_map shape is: {}".format(image_feature_map.shape))
            # print("[image_encoder] image_vec shape is: {}".format(image_vec.shape))

            turn_embedding = self.sentence_encoder(turns_word_embedding,
                                                   turns_lengths)
            rnn_condition, current_image_feat = \
                self.condition_encoder(turn_embedding,
                                       image_vec,
                                       current_image_feat)

            rnn_condition = rnn_condition.unsqueeze(0)
            # self.rnn.flatten_parameters()  # Added by Mingyang to Resolve the
            # Warning
            self.rnn.module.flatten_parameters()
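            # flatten_parameters() lives on the wrapped GRU, so it is reached
            # through .module; DataParallel does not forward the method call.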
            output, hidden = self.rnn(rnn_condition, hidden)

            output = output.squeeze(0)
            output = self.layer_norm(output)

            fake_image, mu, logvar, sigma = self._forward_generator(
                batch_size, output.detach(), image_feature_map)

            #print("[image_generator] fake_image size is: {}".format(fake_image.shape))
            #print("[image_generator] fake_image_one_pixel is: {}".format(fake_image[0,:,0,0]))

            visualizer.track_sigma(sigma)

            hamming = objects - prev_objects
            hamming = torch.clamp(hamming, min=0)

            # print(image.shape)
            # print(disc_prev_image.shape)
            d_loss, d_real, d_fake, aux_loss, discriminator_gradient = \
                self._optimize_discriminator(image,
                                             fake_image.detach(),
                                             disc_prev_image,
                                             output,
                                             seq_ended,
                                             hamming,
                                             self.cfg.gp_reg,
                                             self.cfg.aux_reg)

            # Append the segmentation loss when the GAN type requires it
            seg_fake = None
            seg_gt = None
            if re.search(r"seg", self.cfg.gan_type):
                assert self.cfg.seg_reg > 0, "seg_reg must be larger than 0"
                if self.cfg.gan_type == "recurrent_gan_mingyang_img64_seg":
                    # Reshape seg_fake to (Batch, N, C)
                    seg_fake = fake_image.view(fake_image.size(0),
                                               fake_image.size(1),
                                               -1).permute(0, 2, 1)
                    # The segmentation ground truth is derived from the image
                    seg_gt = torch.argmax(image, dim=1).view(image.size(0), -1)
            else:
                assert self.cfg.seg_reg == 0, "seg_reg must be equal to 0"


            g_loss, generator_gradient = \
                self._optimize_generator(fake_image,
                                         disc_prev_image.detach(),
                                         output.detach(),
                                         objects,
                                         self.cfg.aux_reg,
                                         seq_ended,
                                         mu,
                                         logvar,
                                         self.cfg.seg_reg,
                                         seg_fake,
                                         seg_gt)
            #return

            if self.cfg.teacher_forcing:
                prev_image = image
            else:
                prev_image = fake_image

            disc_prev_image = image
            prev_objects = objects

            if (t + 1) % 2 == 0:
                prev_image = prev_image.detach()

            rnn_grads = []
            gru_grads = []
            condition_encoder_grads = []
            img_encoder_grads = []

            if t == max_seq_len - 1:
                rnn_gradient, gru_gradient, condition_gradient,\
                    img_encoder_gradient = self._optimize_rnn()

                rnn_grads.append(rnn_gradient.data.cpu().numpy())
                gru_grads.append(gru_gradient.data.cpu().numpy())
                condition_encoder_grads.append(
                    condition_gradient.data.cpu().numpy())

                if self.use_image_encoder:
                    img_encoder_grads.append(
                        img_encoder_gradient.data.cpu().numpy())

                visualizer.track(d_real, d_fake)

            hamming = hamming.data.cpu().numpy()[0]
            # teller_images.extend(image[:4].data.cpu().numpy())
            # drawer_images.extend(fake_image[:4].data.cpu().numpy())
            new_teller_images = []
            for x in image[:4].data.cpu():
                # print(x.shape)
                # new_x = self.unorm(x)
                # new_x = transforms.ToPILImage()(new_x).convert('RGB')
                # # new_x = np.array(new_x)[..., ::-1]
                # new_x = np.moveaxis(np.array(new_x), -1, 0)

                if self.cfg.image_gen_mode == "real":
                    new_x = self.unormalize(x)
                elif self.cfg.image_gen_mode == "segmentation":
                    new_x = self.unormalize_segmentation(x.data.numpy())
                elif self.cfg.image_gen_mode == "segmentation_onehot":
                    #TODO: Implement the functino to convert new_x to colored_image
                    new_x = self.unormalize_segmentation_onehot(
                        x.data.cpu().numpy())
                    #print(new_x.shape)
                    #return

                # print(new_x.shape)
                new_teller_images.append(new_x)
            teller_images.extend(new_teller_images)

            new_drawer_images = []
            for x in fake_image[:4].data.cpu():
                # print(x.shape)
                # new_x = self.unorm(x)
                # new_x = transforms.ToPILImage()(new_x).convert('RGB')
                # # new_x = np.array(new_x)[..., ::-1]
                # new_x = np.moveaxis(np.array(new_x), -1, 0)

                if self.cfg.image_gen_mode == "real":
                    new_x = self.unormalize(x)
                elif self.cfg.image_gen_mode == "segmentation":
                    new_x = self.unormalize_segmentation(x.data.cpu().numpy())
                elif self.cfg.image_gen_mode == "segmentation_onehot":
                    #TODO: Implement the functino to convert new_x to colored_image
                    new_x = self.unormalize_segmentation_onehot(
                        x.data.cpu().numpy())

                # print(new_x.shape)
                new_drawer_images.append(new_x)
            drawer_images.extend(new_drawer_images)
            # drawer_images.extend(fake_image[:4].data.cpu().numpy())
            # print(drawer_images[0].shape)

            # entities = str.join(',', list(batch['entities'][hamming > 0]))
            # added_entities.append(entities)
        # print(iteration)

        if iteration % self.cfg.vis_rate == 0:
            visualizer.histogram()
            self._plot_losses(visualizer, g_loss, d_loss, aux_loss, iteration)
            rnn_gradient = np.array(rnn_grads).mean()
            gru_gradient = np.array(gru_grads).mean()
            condition_gradient = np.array(condition_encoder_grads).mean()
            img_encoder_gradient = np.array(img_encoder_grads).mean()
            rnn_grads, gru_grads = [], []
            condition_encoder_grads, img_encoder_grads = [], []
            self._plot_gradients(visualizer, rnn_gradient, generator_gradient,
                                 discriminator_gradient, gru_gradient,
                                 condition_gradient, img_encoder_gradient,
                                 iteration)

            self._draw_images(visualizer, teller_images, drawer_images, nrow=4)
            # self.logger.write(epoch, "{}/{}".format(iteration,total_iters),
            # d_real, d_fake, d_loss, g_loss)
            remaining_time = str(
                datetime.timedelta(seconds=current_batch_t *
                                   (total_iters - iteration)))
            self.logger.write(epoch,
                              "{}/{}".format(iteration, total_iters),
                              d_real,
                              d_fake,
                              d_loss,
                              g_loss,
                              expected_finish_time=remaining_time)
            if isinstance(batch['turn'], list):
                batch['turn'] = np.array(batch['turn']).transpose()

            visualizer.write(batch['turn'][0])
            # visualizer.write(added_entities, var_name='entities')
            teller_images = []
            drawer_images = []

        if iteration % self.cfg.save_rate == 0:
            path = os.path.join(self.cfg.log_path, self.cfg.exp_name)

            # self._save(fake_image[:4], path, epoch,
            #            iteration)
            if not self.cfg.debug:
                self.save_model(path, epoch, iteration)

    def _forward_generator(self, batch_size, condition, image_feature_maps):
        noise = torch.FloatTensor(batch_size,
                                  self.cfg.noise_dim).normal_(0, 1).cuda()

        fake_images, mu, logvar, sigma = self.generator(
            noise, condition, image_feature_maps)

        return fake_images, mu, logvar, sigma

    def _optimize_discriminator(self,
                                real_images,
                                fake_images,
                                prev_image,
                                condition,
                                mask,
                                objects,
                                gp_reg=0,
                                aux_reg=0):
        """Discriminator is updated every step independent of batch_size
        RNN and the generator
        """
        wrong_images = torch.cat((real_images[1:], real_images[0:1]), dim=0)
        wrong_prev = torch.cat((prev_image[1:], prev_image[0:1]), dim=0)

        self.discriminator.zero_grad()
        real_images.requires_grad_()

        d_real, aux_real, _ = self.discriminator(real_images, condition,
                                                 prev_image)
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        d_wrong, _, _ = self.discriminator(wrong_images, condition, wrong_prev)

        d_loss, aux_loss = self._discriminator_masked_loss(
            d_real, d_fake, d_wrong, aux_real, aux_fake, objects, aux_reg,
            mask)

        d_loss.backward(retain_graph=True)
        if gp_reg:
            reg = gp_reg * self._masked_gradient_penalty(
                d_real, real_images, mask)
            reg.backward(retain_graph=True)

        grad_norm = _recurrent_gan.get_grad_norm(
            self.discriminator.parameters())
        self.discriminator_optimizer.step()

        d_loss_scalar = d_loss.item()
        d_real_np = d_real.cpu().data.numpy()
        d_fake_np = d_fake.cpu().data.numpy()
        aux_loss_scalar = aux_loss.item() if isinstance(
            aux_loss, torch.Tensor) else aux_loss
        grad_norm_scalar = grad_norm.item()
        del d_loss
        del d_real
        del d_fake
        del aux_loss
        del grad_norm
        gc.collect()

        return d_loss_scalar, d_real_np, d_fake_np, aux_loss_scalar, grad_norm_scalar

    def _optimize_generator(self,
                            fake_images,
                            prev_image,
                            condition,
                            objects,
                            aux_reg,
                            mask,
                            mu,
                            logvar,
                            seg_reg=0,
                            seg_fake=None,
                            seg_gt=None):
        self.generator.zero_grad()
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        g_loss = self._generator_masked_loss(d_fake, aux_fake, objects,
                                             aux_reg, mu, logvar, mask,
                                             seg_reg, seg_fake, seg_gt)

        g_loss.backward(retain_graph=True)
        gen_grad_norm = _recurrent_gan.get_grad_norm(
            self.generator.parameters())

        self.generator_optimizer.step()

        g_loss_scalar = g_loss.item()
        gen_grad_norm_scalar = gen_grad_norm.item()

        del g_loss
        del gen_grad_norm
        gc.collect()

        return g_loss_scalar, gen_grad_norm_scalar

    def _optimize_rnn(self):
        torch.nn.utils.clip_grad_norm_(self.rnn.parameters(),
                                       self.cfg.grad_clip)
        rnn_grad_norm = _recurrent_gan.get_grad_norm(self.rnn.parameters())
        self.rnn_optimizer.step()
        self.rnn.zero_grad()

        gru_grad_norm = None
        torch.nn.utils.clip_grad_norm_(self.sentence_encoder.parameters(),
                                       self.cfg.grad_clip)
        gru_grad_norm = _recurrent_gan.get_grad_norm(
            self.sentence_encoder.parameters())
        self.sentence_encoder_optimizer.step()
        self.sentence_encoder.zero_grad()

        ce_grad_norm = _recurrent_gan.get_grad_norm(
            self.condition_encoder.parameters())
        ie_grad_norm = _recurrent_gan.get_grad_norm(
            self.image_encoder.parameters())
        self.feature_encoders_optimizer.step()
        self.condition_encoder.zero_grad()
        self.image_encoder.zero_grad()
        return rnn_grad_norm, gru_grad_norm, ce_grad_norm, ie_grad_norm

    def _discriminator_masked_loss(self, d_real, d_fake, d_wrong, aux_real,
                                   aux_fake, objects, aux_reg, mask):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding"""
        d_loss = []
        aux_losses = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.discriminator(
                    d_real[b], d_fake[b], d_wrong[b],
                    self.cfg.wrong_fake_ratio)
                if aux_reg > 0:
                    aux_loss = aux_reg * (
                        self.aux_criterion(aux_real[b], objects[b]).mean() +
                        self.aux_criterion(aux_fake[b], objects[b]).mean())
                    sample_loss += aux_loss
                    aux_losses.append(aux_loss)

                d_loss.append(sample_loss)

        d_loss = torch.stack(d_loss).mean()

        if len(aux_losses) > 0:
            aux_losses = torch.stack(aux_losses).mean()
        else:
            aux_losses = 0

        return d_loss, aux_losses

    def _generator_masked_loss(self,
                               d_fake,
                               aux_fake,
                               objects,
                               aux_reg,
                               mu,
                               logvar,
                               mask,
                               seg_reg=0,
                               seg_fake=None,
                               seg_gt=None):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding
        Append the segmentation loss to the model.
        seg_fake: (1*C*H*W)
        seg_gt: (1*H*W)
        """
        g_loss = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.generator(d_fake[b])
                if aux_reg > 0:
                    aux_loss = aux_reg * \
                        self.aux_criterion(aux_fake[b], objects[b]).mean()
                else:
                    aux_loss = 0
                if mu is not None:
                    kl_loss = self.cfg.cond_kl_reg * \
                        kl_penalty(mu[b], logvar[b])
                else:
                    kl_loss = 0
                # Append the segmentation loss to the total generator loss
                if seg_reg > 0:
                    # CrossEntropyLoss reduces to a mean over pixels by default
                    seg_loss = seg_reg * self.seg_criterion(
                        seg_fake[b], seg_gt[b])
                else:
                    seg_loss = 0

                g_loss.append(sample_loss + aux_loss + kl_loss + seg_loss)

        g_loss = torch.stack(g_loss)
        return g_loss.mean()

    def _masked_gradient_penalty(self, d_real, real_images, mask):
        gp_reg = gradient_penalty(d_real, real_images).mean()
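        # NOTE: the mask argument is currently unused; the penalty is averaged
        # over the whole batch, including sequences that have already ended.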
        return gp_reg

    # region Helpers
    def _plot_losses(self, visualizer, g_loss, d_loss, aux_loss, iteration):
        _recurrent_gan._plot_losses(self, visualizer, g_loss, d_loss, aux_loss,
                                    iteration)

    def _plot_gradients(self, visualizer, rnn, gen, disc, gru, ce, ie,
                        iteration):
        _recurrent_gan._plot_gradients(self, visualizer, rnn, gen, disc, gru,
                                       ce, ie, iteration)

    def _draw_images(self, visualizer, real, fake, nrow):
        _recurrent_gan.draw_images_gandraw(self, visualizer, real, fake,
                                           nrow)  # Changed by Mingyang Zhou

    def _save(self, fake, path, epoch, iteration):
        _recurrent_gan._save(self, fake, path, epoch, iteration)

    def save_model(self, path, epoch, iteration):
        _recurrent_gan.save_model(self, path, epoch, iteration)

    def load_model(self, snapshot_path):
        _recurrent_gan.load_model(self, snapshot_path)

    def unormalize(self, x):
        """
        unormalize the image
        """
        new_x = self.unorm(x)
        new_x = transforms.ToPILImage()(new_x).convert('RGB')
        # new_x = np.array(new_x)[..., ::-1]
        new_x = np.moveaxis(np.array(new_x), -1, 0)
        return new_x

    def unormalize_segmentation(self, x):
        new_x = (x + 1) * 127.5
        # new_x = new_x.transpose(1, 2, 0)[..., ::-1]
        return new_x

    def unormalize_segmentation_onehot(self, x):
        """
        Convert the segmentation into image
        """

        LABEL2COLOR = {
            0: {"name": "sky", "color": np.array([134, 193, 46])},
            1: {"name": "dirt", "color": np.array([30, 22, 100])},
            2: {"name": "gravel", "color": np.array([163, 164, 153])},
            3: {"name": "mud", "color": np.array([35, 90, 74])},
            4: {"name": "sand", "color": np.array([196, 15, 241])},
            5: {"name": "clouds", "color": np.array([198, 182, 115])},
            6: {"name": "fog", "color": np.array([76, 60, 231])},
            7: {"name": "hill", "color": np.array([190, 128, 82])},
            8: {"name": "mountain", "color": np.array([122, 101, 17])},
            9: {"name": "river", "color": np.array([97, 140, 33])},
            10: {"name": "rock", "color": np.array([90, 90, 81])},
            11: {"name": "sea", "color": np.array([255, 252, 51])},
            12: {"name": "snow", "color": np.array([51, 255, 252])},
            13: {"name": "stone", "color": np.array([106, 107, 97])},
            14: {"name": "water", "color": np.array([0, 255, 0])},
            15: {"name": "bush", "color": np.array([204, 113, 46])},
            16: {"name": "flower", "color": np.array([0, 0, 255])},
            17: {"name": "grass", "color": np.array([255, 0, 0])},
            18: {"name": "straw", "color": np.array([255, 51, 252])},
            19: {"name": "tree", "color": np.array([255, 51, 175])},
            20: {"name": "wood", "color": np.array([66, 18, 120])},
            21: {"name": "road", "color": np.array([255, 255, 0])},
        }
        seg_map = np.argmax(x, axis=0)
        new_x = np.zeros((3, seg_map.shape[0], seg_map.shape[1]),
                         dtype=np.uint8)
        for i in range(seg_map.shape[0]):
            for j in range(seg_map.shape[1]):
                new_x[:, i, j] = LABEL2COLOR[seg_map[i, j]]["color"]
        return new_x
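
# --- Hedged toy sketch (not part of the original class) ---
# The recurrence documented in the class docstring above,
#   c_t = RNN(h_{t-1}, q_t, x~_{t-1})  followed by  x~_t = G(z | c_t),
# reduced to plain tensors; all sizes below are made up for illustration:
import torch
from torch import nn


def recurrent_conditioning_sketch(steps=3, batch=2, q_dim=8, img_dim=16, hid=32):
    rnn = nn.GRU(q_dim + img_dim, hid, batch_first=False)
    generator = nn.Linear(hid + 4, img_dim)  # stand-in for G(z | c_t)
    hidden = torch.zeros(1, batch, hid)
    prev_image = torch.zeros(batch, img_dim)
    images = []
    for _ in range(steps):
        q_t = torch.randn(batch, q_dim)  # current question embedding
        cond = torch.cat([q_t, prev_image], dim=-1).unsqueeze(0)
        c_t, hidden = rnn(cond, hidden)  # condition on history and last image
        z = torch.randn(batch, 4)  # noise vector
        prev_image = generator(torch.cat([c_t.squeeze(0), z], dim=-1))
        images.append(prev_image)
    return images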
Example #12
class Trainer(object):
    """
    Trainer class
    """
    def __init__(self,
                 chkpt_path: str,
                 config: TrainerConfig,
                 train: str,
                 test: str,
                 dev: str = None,
                 disable_dataparallel: bool = False):
        """
        Instantiate trainer

        :param str chkpt_path: Path to checkpoint the model, optimizer and scheduler
        :param TrainerConfig config: Configuration instance for Trainer
        :param str train: Path to the JSON Lines file containing the training set
        :param str test: Path to the JSON Lines file containing the evaluation set
        :param str dev: Path to the JSON Lines file containing the development set (optional)
        :param bool disable_dataparallel:
            True if module should not be parallelized across different GPU devices. False by default.
        """
        # Register configuration
        self._config = config
        self.disable_dataparallel = disable_dataparallel

        # Prepare internal states
        self._best_on_dev = 0.0  #: Best score on the development set
        self._ema_on_dev = None  #: Exponential Moving Average score on the development set.
        self._random_restored = False  #: Whether the RNG state was restored or not

        # Epoch & step information
        self._epoch = 0
        self._steps_to_go = 0
        self._step_per_epoch = 0
        self._minibatch_per_epoch = 0

        # Dictionary that records the last performance metrics
        self._last_performances = {}
        self._last_metrics = {}

        # Prepare checkpointing
        self._chkpt_path = Path(chkpt_path)
        if not self._chkpt_path.exists():
            self._chkpt_path.mkdir(parents=True)

        # Logging file handler
        file_handler = logging.FileHandler(filename=Path(
            chkpt_path, 'train.log'),
                                           encoding='UTF-8')
        file_handler.setFormatter(
            logging.Formatter(
                '[%(asctime)s] %(levelname)s %(name)s: %(message)s',
                datefmt='%m/%d/%Y %H:%M:%S'))
        file_handler.setLevel(logging.INFO)

        # Set the logger
        self._logger = logging.getLogger(self.__class__.__name__ +
                                         '_%s' % id(self))
        self._logger.addHandler(file_handler)
        self._logger.setLevel(logging.INFO)

        # If DEBUG is on, turn on the anomaly detection
        if 'DEBUG' in ENV:
            torch.autograd.set_detect_anomaly(True)

        # Prepare Tensorboard if available.
        try:
            from tensorboardX import SummaryWriter
            self._writer = SummaryWriter(logdir=str(self._chkpt_path),
                                         flush_secs=30)
        except ImportError:
            self._writer = None

        # Prepare data-parallel if available.
        if torch.cuda.is_available():
            devices = get_available_device_count()
            cuda_keys = list(range(devices))
            random.shuffle(cuda_keys)

            self.main_device = torch.device('cuda', cuda_keys[0])
            self.device_order = cuda_keys
        else:
            self.main_device = torch.device('cpu')
            self.device_order = [self.main_device]
        self._logger.info(
            "We will use [%s] as the main device for training, with ordering [%s]",
            self.main_device, self.device_order)

        # Read the datasets
        self.set_seed(
        )  #: Set seed before loading the datasets (because of shuffling in training set)
        self.trainset, self.devset, self.evalset = self._config.read_datasets(
            train=train, dev=dev, test=test)
        self._trainit = iter(self.trainset)

        # Log dataset statistics
        self._logger.info('From %s, we loaded %s mini-batch(es)', train,
                          len(self.trainset))
        self._logger.info('From %s, we loaded %s mini-batch(es)', dev,
                          len(self.devset))
        self._logger.info('From %s, we loaded %s mini-batch(es)', test,
                          len(self.evalset))
        self.trainset.print_item_statistics(self._logger)

        # Build or restore module
        self._module = None
        self._module_init = {}
        self._optimizer = None
        self._answer_checker = None
        self.restore_checkpoint()

    @property
    def checkpoints(self) -> List[Path]:
        """
        :rtype: List[Path]
        :return: List of checkpointed steps (dictionaries)
        """
        checkpoints = sorted(Path(self._chkpt_path).glob('*'))
        checkpoints = [
            x for x in checkpoints if x.is_dir() and x.name.isnumeric()
        ]
        return checkpoints

    @property
    def last_checkpoint(self) -> Path:
        """
        :rtype: Path
        :return: The last checkpoint if exists. Otherwise, None
        """
        return self.checkpoints[-1] if len(self.checkpoints) else None

    @property
    def current_epoch(self) -> int:
        """
        :rtype: int
        :return: Current epoch index
        """
        return self._epoch

    @property
    def is_done(self) -> bool:
        """
        :rtype: bool
        :return: True if trainer already reached maximum epoch specified.
        """
        return self._epoch == self._config.epoch

    def close(self):
        """
        Close and clean-up the trainer.
        """
        if self._writer is not None:
            # Close the TensorboardX
            self._writer.close()
            self._writer = None
        if self._answer_checker is not None:
            # Kill the answer checker child processes
            self._answer_checker.close()
            self._answer_checker = None

    def rotate_checkpoint(self, max_item: int = 10):
        """
        Rotate checkpoints

        :param int max_item: Maximum number of allowed checkpoints
        """
        # Check if we should delete older checkpoint(s)
        if len(self.checkpoints) <= max_item:
            return

        for chkpt in self.checkpoints[:-max_item]:
            # Remove old checkpoints
            self._logger.info("Deleting old checkpoint [%s]", chkpt)
            shutil.rmtree(chkpt)

    def checkpoint(self):
        """
        Make a checkpoint
        """
        # Build a zero-padded directory-name format so that directory names
        # sort in the same order as epoch indices.
        directory_format = '%%0%dd' % int(
            math.ceil(math.log10(self._config.epoch + 1)))
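        # e.g. config.epoch == 500 gives ceil(log10(501)) == 3, i.e. '%03d',
        # so epoch 7 is checkpointed under directory '007'.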
        # If directory exists, exit the method.
        output_dir = Path(self._chkpt_path, directory_format % self._epoch)
        if output_dir.exists():
            return

        # Prepare the directory for checkpointing
        self._logger.info("Save checkpoint to [%s]", output_dir)
        output_dir.mkdir(parents=True)

        # Save the all RNG states used in this trainer.
        torch.save(
            {
                'numpy': numpy.random.get_state(),
                'random': random.getstate(),
                'trainset': self.trainset.get_rng_state(),
                'torch': {
                    'cpu':
                    torch.get_rng_state(),
                    'cuda':
                    torch.cuda.get_rng_state_all()
                    if torch.cuda.is_available() else None
                }
            }, Path(output_dir, 'random.pt'))

        # Save Trainer's internal states
        torch.save(
            {
                '_best_on_dev': self._best_on_dev,
                '_ema_on_dev': self._ema_on_dev,
                '_last_performances': self._last_performances,
                '_last_metrics': self._last_metrics
            }, Path(output_dir, 'internal.pt'))

        # Save the model
        _unwrap_parallel(self._module).save_pretrained(output_dir)
        # Save the optimizer
        torch.save(self._optimizer.state_dict(),
                   Path(output_dir, 'optimizer.pt'))

        # Save the scheduler if available.
        if hasattr(self, '_scheduler'):
            torch.save(self._scheduler.state_dict(),
                       Path(output_dir, 'scheduler.pt'))

        # Write configuration that has been used.
        self._config.save_pretrained(output_dir)
        # Rotate checkpoints.
        self.rotate_checkpoint()

    def restore_checkpoint(self):
        """
        Restore from the last checkpoint if available. Otherwise, configure this trainer from scratch.
        """
        # Check if there exists any checkpoints.
        chkpt_path = self.last_checkpoint
        if chkpt_path:
            # reload configuration from the checkpoint
            self._config = TrainerConfig.from_pretrained(str(chkpt_path))
            self._logger.info("TrainerConfig at [%s] is restored.", chkpt_path)

            # Recover random number generator states
            self.set_seed()  # Set seed before restoring RNG
            random_path = Path(chkpt_path, 'random.pt')
            random_states = torch.load(random_path)
            numpy.random.set_state(random_states['numpy'])
            random.setstate(random_states['random'])
            self.trainset.set_rng_state(random_states['trainset'])

            torch.set_rng_state(random_states['torch']['cpu'])
            if torch.cuda.is_available():
                torch.cuda.set_rng_state_all(random_states['torch']['cuda'])

            # Record that the RNG is restored.
            self._logger.info(
                "State of random number generator is restored from [%s]",
                random_path)
            self._random_restored = True

            # Recover the trainer's internal states
            internal_states = torch.load(Path(chkpt_path, 'internal.pt'))
            for key, value in internal_states.items():
                if hasattr(self, key):
                    setattr(self, key, value)
        else:
            self.set_seed()  # Set seed.

        # Build/restore model
        self._config.model.set_chkpt_path(chkpt_path)
        self._module = Solver.from_pretrained(config=self._config.model)
        self._module_init = {
            id(p): p.clone()
            for p in self._module.parameters()
        }
        self._module.to(self.main_device)
        self._logger.info("A network at [%s] is restored.", chkpt_path)

        # Compute the epoch/step information
        self._minibatch_per_epoch = len(self.trainset)
        self._step_per_epoch = int(
            math.ceil(self._minibatch_per_epoch /
                      self._config.gradient_accumulation_steps))
        self._steps_to_go = self._step_per_epoch * self._config.epoch
        self._logger.info("Steps / Epoch = %5d", self._step_per_epoch)
        self._logger.info("We will run %3d epoch(s) or %6d step(s)",
                          self._config.epoch, self._steps_to_go)
        self._logger.info(
            "Per step, %2d gradient(s) will be accumulated. (Total %2d mini-batch(es)/epoch)",
            self._config.gradient_accumulation_steps,
            self._minibatch_per_epoch)
        self._logger.info(
            "We will report TRAINING loss/accuracy for every %3d epoch(s)",
            self._config.epoch_report)
        self._logger.info(
            "We will report DEV ACC. and save CHKPTs for every %3d epoch(s)",
            self._config.epoch_chkpt)

        # Restore the number of steps that were passed before
        if chkpt_path:
            self._epoch = int(chkpt_path.name)
            self._logger.info("Attempt to restore from the checkpoint [%s]",
                              chkpt_path)
            self._logger.info("Resume training from epoch %s", self._epoch)

        # Classify parameters to form parameter groups to build optimizer
        no_w_decay = {'bias', 'norm', 'Norm', '_embedding'}
        parameters = [((2 if 'text_model.model.embeddings' in n else
                        (1 if 'text_model' in n else 0),
                        any(t in n for t in no_w_decay)), p)
                      for n, p in self._module.named_parameters()]
        parameters = groupby(sorted(parameters, key=lambda t: t[0]),
                             key=lambda t: t[0])
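        # itertools.groupby only merges *consecutive* items, hence the sorted()
        # above; each group key is (encoder_type_flag, is_without_wd).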

        # Build optimizer groups
        optimizer_grouped_parameters = []
        for (encoder_type_flag, is_without_wd), group in parameters:
            group = {'params': [p for _, p in group]}

            if is_without_wd:
                group['weight_decay'] = 0.0

            if encoder_type_flag == 2 and self._config.fix_encoder_embedding:
                group['lr'] = 0.0
            elif encoder_type_flag == 1:
                group['lr'] = self._config.optimizer.kwargs[
                    'lr'] * self._config.lr_multiplier_encoder

            optimizer_grouped_parameters.append(group)

        # Build optimizer before restoration
        self._optimizer = self._config.optimizer.build(
            optimizer_grouped_parameters)
        self._logger.info("We will use the following optimizer: %s",
                          self._optimizer)

        # Restore the optimizer if available.
        if chkpt_path:
            # Check if saved optimizer exists
            optimizer_file = Path(chkpt_path, 'optimizer.pt')
            if optimizer_file.is_file():
                self._optimizer.load_state_dict(torch.load(optimizer_file))
                self._logger.info(
                    "An optimizer for module at [%s] is restored.",
                    optimizer_file)

        # Specify warmup strategy if warmup value is not negative
        warmup_steps = int(self._step_per_epoch * self._config.epoch_warmup)
        if warmup_steps >= 0:
            # Build scheduler before restoration
            self._scheduler = get_linear_schedule_with_warmup(
                self._optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=self._steps_to_go)
            self._logger.info(
                "We will use linear scheduling: warm up %s epochs or %s steps",
                self._config.epoch_warmup, warmup_steps)

            # Restore the scheduler if available
            if chkpt_path:
                # Check if saved scheduler exists
                scheduler_file = Path(chkpt_path, 'scheduler.pt')
                if scheduler_file.is_file():
                    self._scheduler.load_state_dict(torch.load(scheduler_file))
                    self._logger.info(
                        "A scheduler for module at [%s] is restored.",
                        scheduler_file)

        # Log the threshold of gradient clipping.
        if self._config.gradient_clip > 0:
            self._logger.info("We will use gradient clipping at %.3f",
                              self._config.gradient_clip)
        else:
            self._logger.info("We will not use gradient clipping")

        # Log the structure of the network.
        parameters_size = sum(p.numel() for p in self._module.parameters())
        disk_space = sum(
            required_space_param(p) for p in self._module.parameters())
        self._logger.info('==== [Network Structure] ====\n%s',
                          str(self._module))
        self._logger.info(
            'There are %12d parameters in a network. Required space for checkpointing is %.3fMB.',
            parameters_size, disk_space / 1048576)

        # Wrap data parallel if we can use more than one GPU
        if len(self.device_order) > 1 and not self.disable_dataparallel:
            self._module = DataParallel(self._module,
                                        device_ids=self.device_order,
                                        output_device=self.device_order[0])
            self._logger.info(
                "We identified [%s] devices for parallel training",
                len(self.device_order))
        else:
            self._logger.info("We don't use DataParallel.")

        # Set answer checker
        self._answer_checker = AnswerChecker(
            is_expression_type=_unwrap_parallel(
                self._module).is_expression_type,
            logger=self._logger)

    def set_seed(self):
        """
        Set the random seeds
        """
        if self._random_restored:
            # Skip seeding: the RNG state was restored from a checkpoint.
            return

        seed = self._config.seed
        self._logger.info("Seed for random number generation = %s", seed)

        random.seed(seed)
        numpy.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def get_evaluation_output(self, key: str):
        """
        Get the evaluation output of specified key.

        :param str key: metric key to read
        :return: metric value of specified key
        """
        return self._last_performances[key]

    def get_metrics(self) -> dict:
        """
        :return: The latest metric dictionary.
        """
        return self._last_metrics

    def run_a_chkpt_iter(self):
        """
        Run epochs until checkpointing
        """

        try:
            accumulated_values = {}

            for _ in range(self._config.epoch_chkpt):
                # For each epoch (at most the number of checkpointing epoch)
                self._epoch += 1

                all_grad_applied = True
                for batch_step in range(self._minibatch_per_epoch):
                    # For each minibatch
                    self._module.eval()
                    self._module.zero_grad()

                    # Load a minibatch
                    batch = next(self._trainit)
                    batch = self._load_batch(batch)

                    # Execute training
                    self._module.train()
                    reported_values = self._step(**batch)
                    reported_values['Loss/generate'] = reported_values[
                        'total_loss']
                    reported_values['total_loss'].backward()
                    all_grad_applied = False

                    # Accumulate statistics and update gradient
                    _accumulate_stats(reported_values, accumulated_values)
                    if (batch_step +
                            1) % self._config.gradient_accumulation_steps == 0:
                        self._update_grad()
                        all_grad_applied = True
                else:
                    # for-else: runs once the minibatch loop completes; apply
                    # any gradients that were accumulated but not yet stepped
                    if not all_grad_applied:
                        self._update_grad()

                if self._config.epoch_report > 0 and self._epoch % self._config.epoch_report == 0:
                    # Log metrics
                    if self._writer is not None:
                        for name, val in accumulated_values.items():
                            self._writer.add_scalar(name,
                                                    sum(val) / len(val),
                                                    self._epoch)
                        # Report current optimizer status
                        self._report_optimizer()

                    accumulated_values.clear()

            # Evaluate current result on development set
            self.evaluate()
            self.checkpoint()
        except Exception as e:
            self._logger.error('Exception occurred!', exc_info=e)
            raise e

    def train(self):
        """
        Do full-length training (until the maximum epoch)
        """
        # Set seed
        self.set_seed()

        # Prepare estimated time calculator class
        eta = ExpectedTimeToFinishCalculator(self._config.epoch,
                                             current=self._epoch)
        while self._epoch < self._config.epoch:
            self.run_a_chkpt_iter()
            eta_time = eta.step(increase=self._config.epoch_chkpt)
            self._logger.info('Expected time to finish: %s', eta_time)

        # Evaluate performance on the evaluation set
        try:
            self.evaluate(is_development=False)
        except Exception as e:
            self._logger.error('Exception occurred!', exc_info=e)
            raise e
        finally:
            # Remove old checkpoints and close Tensorboard writer
            self.rotate_checkpoint(1)

    def _update_grad(self):
        """
        Update accumulated gradients
        """
        if self._config.gradient_clip > 0:
            # If clipping threshold is set, then clip the gradient
            torch.nn.utils.clip_grad_norm_(self._module.parameters(),
                                           self._config.gradient_clip)

        if self._config.gradient_normalize:
            # If normalizing gradient is set, then normalize the gradient
            _normalize_gradients(*self._module.parameters())

        # Apply optimizer & scheduler
        self._optimizer.step()
        if hasattr(self, '_scheduler'):
            self._scheduler.step()

        # Reset the gradient
        self._module.zero_grad()
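        # With gradient_accumulation_steps = N, backward() runs on every
        # minibatch while step()/zero_grad() run once per N minibatches, so
        # the effective batch size is N times the loader's batch size.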

    def _load_batch(self,
                    batch: ProblemInstance,
                    is_training=True,
                    max_len=0) -> dict:
        """
        Load a batch instance into a dictionary that can be fed into the model.

        :param ProblemInstance batch: A mini-batch
        :param bool is_training: True if this batch is used for training. True by default.
        :param int max_len: Maximum length of equation to be generated. 0 by default (i.e. depends on the current batch)
        :rtype: dict
        :return: Dictionary representing mini-batch
        """
        # Prepare dictionary
        batch_dict = {
            'max_numbers':
            max(len(numbers) for numbers in batch.text.number_value),
            IN_TXT: batch.text.token,
            IN_TPAD: batch.text.pad,
            IN_TNUM: batch.text.number
        }

        # Retrieve information about the target field
        required_field = _unwrap_parallel(self._module).required_field
        # Get equation in terms of the target field
        equation = getattr(batch, required_field)
        if is_training:
            # If this is training, then directly provide target equation for teacher-forcing
            batch_dict[IN_EQN] = equation
        else:
            # Otherwise, just provide information about maximum length of generation & arity of operators
            batch_dict['max_len'] = max(equation.shape[-2], max_len) + 1
            if required_field.startswith('tuple'):
                batch_dict['function_arities'] = getattr(
                    self.evalset, required_field + '_field').function_arities

        if not isinstance(self._module, DataParallel):
            # Without DataParallel we must move tensors to the main device
            # ourselves; DataParallel scatters its inputs across devices.
            batch_dict = {
                k: v.to(self.main_device) if isinstance(v, torch.Tensor) else v
                for k, v in batch_dict.items()
            }

        # Returned value is a dict.
        return batch_dict

    def _step(self, training: bool = True, **kwargs):
        """
        Execute forward computation of the module

        :param bool training: True if this execution is for training. True by default.
        :param kwargs: Keyword arguments to execute the module.
        :return: Result of execution.
            - If training is True, return value will be a dictionary mapping from string to accuracy/loss Tensors.
            - Otherwise, return value will be a LongTensor indicating the generated tokens
        """
        result = self._module(**kwargs)
        if training and isinstance(result, dict):
            # Average loss/accuracy tensors gathered across devices
            return {k: v.mean() for k, v in result.items()}
        return result

    def _report_optimizer(self):
        """
        Report the current state of the optimizer
        """
        # Classify parameters by their types
        param_type = {
            id(p): ('Enc' if 'text_model.' in n else 'Dec') +
            ('Embed' if '_embedding' in n else 'Trans')
            for n, p in _unwrap_parallel(self._module).named_parameters()
        }
        # Dictionary for accumulating parameter information
        param_states = {
            key: {'weight_norm': [], 'acc_update': []}
            for key in set(param_type.values())
        }

        with torch.no_grad():
            # Without using gradients, accumulate information about weight and gradient
            for gid, group in enumerate(self._optimizer.param_groups):
                for p in group['params']:
                    id_p = id(p)
                    states = param_states[param_type[id_p]]
                    w_init = self._module_init[id_p]

                    w_elem = p.numel()
                    w_norm = p.norm(2).item() / w_elem
                    delta_norm = (w_init -
                                  p.clone().cpu()).norm(2).item() / w_elem

                    states['weight_norm'].append(w_norm)
                    states['acc_update'].append(delta_norm)

        # Write accumulated results
        if self._writer:
            for part, states in param_states.items():
                prefix = 'Optimizer_%s/%%s' % part

                for key, val in states.items():
                    if not len(val):
                        continue

                    # Track average & histograms
                    val = numpy.array(val)
                    self._writer.add_scalar(prefix % key, val.mean(),
                                            self._epoch)
                    self._writer.add_scalar(prefix % (key + '_std'), val.std(),
                                            self._epoch)

    def _check_equation(self, checker: AnswerChecker, outputs: torch.Tensor,
                        batch: ProblemInstance):
        """
        Verify whether the outputted equation is correct or not.

        :param AnswerChecker checker: AnswerChecker instance to compute equation and check answer
        :param torch.Tensor outputs:
            LongTensor containing generated equations.
            - If the model should generate op-tokens, Shape = [B, M, T], where B = batch size, M = beams, and T = length
            - Otherwise, Shape = [B, M, T, 1+2A], where A = maximum arity.
        :param batch:
        :return:
        """
        # Retrieve size information
        batch_sz, beam_sz = outputs.shape[:2]

        # Get the target field information
        required_field = _unwrap_parallel(self._module).required_field
        # Retrieve the target field
        field = getattr(self.evalset, required_field + '_field')
        # Recover string representation of gold set and generated beams
        golds = field.convert_ids_to_equations(getattr(batch, required_field))
        beams = [
            field.convert_ids_to_equations(outputs[i]) for i in range(batch_sz)
        ]

        outputs = []
        for i in range(batch_sz):
            # For each batch, retrieve information about written numbers and expected answer tuples
            numbers = batch.text.number_value[i]
            expected = batch.expected[i]

            # Test whether the produced equation in each beam is correct
            results = [
                checker.check(beam, numbers, expected) for beam in beams[i]
            ]
            # Record outputs: (index, goldset output, generated output, correctness)
            outputs.append((i, golds[i], beams[i], results))

        return outputs

    def evaluate(self, is_development: bool = True):
        """
        Evaluate the current model.

        :param bool is_development: True if current evaluation is done on development set. True by default.
        """
        # Shortcut for beam size
        beam_size = self._config.model.beam_size
        # Accumulator for output
        accumulator = []

        # Define log storage for information
        set_type = 'Dev' if is_development else 'Test'
        errored_path = Path(self._chkpt_path, 'error_sample_%s.log' % set_type)
        correct_path = Path(self._chkpt_path,
                            'correct_sample_%s.log' % set_type)
        result_path = Path(self._chkpt_path, 'results.csv')

        # Check whether we should write header or not.
        first_result_output = not result_path.exists()

        # Open file handlers
        errored_fp = errored_path.open('w+t', encoding='UTF-8')
        correct_fp = correct_path.open('w+t', encoding='UTF-8')
        result_fp = result_path.open('a+t', encoding='UTF-8')

        # Set the module to evaluation mode
        self._module.eval()

        # Load dataset
        dataset = self.devset if is_development else self.evalset
        max_len = 0 if is_development else MEM_MAX
        for batch in dataset:
            # For each batch item, load it and produce outputs
            kwargs = self._load_batch(batch,
                                      is_training=False,
                                      max_len=max_len)
            outputs = self._step(**kwargs, training=False, beam=beam_size)

            # Convert text into strings (for printing purposes)
            texts = dataset.problem_field.convert_ids_to_string(
                batch.text.token)

            # Check the result and print the result for each item.
            for i, gold, beams, results in self._check_equation(
                    self._answer_checker, outputs, batch):
                # Record the best output of the beam search results
                result_dict = {
                    'Index': batch.index[i],
                    'Error': str(type(results[0][2])),
                    'correct': results[0][0],
                    'error_1_Parse': results[0][2] is not None,
                    'error_2_Empty': len(results[0][1]) == 0
                    and results[0][2] is None,
                    'error_3_Match': not results[0][0]
                    and len(results[0][1]) > 0 and results[0][2] is None,
                    'correct_in_beam': any(r[0] for r in results)
                }

                # Accumulate the test result.
                accumulator.append(result_dict)

                # Select appropriate file handler
                fp = errored_fp if not result_dict['correct'] else correct_fp
                # Write problem & result
                fp.writelines([
                    '[Q] ', batch.index[i], '\n', texts[i], '\n',
                    '---------------------------------------\n',
                    '[EXPECTED]\t%s\n' % ' '.join(gold),
                    '---ANSWER:\t%s\n' % batch.expected[i],
                    '---------------------------------------\n'
                ])
                fp.writelines([
                    '[BEAM#%3d]\t%s\n'
                    '---ANSWER:\t%s\n%s' %
                    (b, ' '.join(beam), res[1],
                     '' if res[2] is None else '----ERROR:\t%s %s\n' %
                     (type(res[2]), str(res[2])))
                    for b, (beam, res) in enumerate(zip(beams, results))
                ])
                fp.write('\n')

        # Close file handlers
        errored_fp.close()
        correct_fp.close()

        # Collect CSV column keys
        sorted_keys = sorted(accumulator[0].keys())
        # Write CSV header
        if first_result_output:
            _write_csv_line(result_fp, 'Set', 'GlobalStep', 'Beam',
                            *sorted_keys)

        # Write CSV results
        for values in accumulator:
            _write_csv_line(result_fp, set_type, self._epoch, beam_size,
                            *[values[key] for key in sorted_keys])

        # Close CSV handler
        result_fp.close()

        # Average metric across items (correctness & errors)
        metric_dict = {}
        for key in sorted_keys:
            value = [item[key] for item in accumulator]

            if not isinstance(value[0], str):
                average = sum(value) / len(value)

                # Write accumulated results
                self._logger.info('Evaluating on %s (beam %s): %s = %.6f',
                                  set_type, beam_size, key, average)
                metric_dict[set_type + '/' + key] = average

        # Reset the dataset (since dataset reached EOF)
        dataset.reset()

        # Write exponential moving average & maximum value into metric dict
        if is_development:
            self._best_on_dev = max(self._best_on_dev,
                                    metric_dict['Dev/correct'])
            if self._ema_on_dev is None:
                self._ema_on_dev = metric_dict['Dev/correct']
            else:
                self._ema_on_dev = metric_dict[
                    'Dev/correct'] * 0.6 + self._ema_on_dev * 0.4

            metric_dict['Dev/correct_max'] = self._best_on_dev
            metric_dict['Dev/correct_ema'] = self._ema_on_dev

        # Record last output
        self._last_performances[set_type] = [
            item['correct']
            for item in sorted(accumulator, key=lambda d: d['Index'])
        ]
        self._last_metrics.update(metric_dict)
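
# A minimal sketch of the module-level `_write_csv_line` helper used in
# `evaluate` above. The body is an assumption (comma-joined values plus a
# newline, with proper quoting omitted), not the actual implementation.
def _write_csv_line_sketch(fp, *values):
    fp.write(','.join(str(v) for v in values) + '\n')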
Example #13
class RecurrentGAN:
    def __init__(self, cfg):
        """A recurrent GAN model, each time step a generated image
        (x'_{t-1}) and the current question q_{t} are fed to the RNN
        to produce the conditioning vector for the GAN.
        The following equations describe this model:

            - c_{t} = RNN(h_{t-1}, q_{t}, x^{~}_{t-1})
            - x^{~}_{t} = G(z | c_{t})
        """
        super(RecurrentGAN, self).__init__()

        # region Models-Instantiation

        self.generator = DataParallel(
            GeneratorFactory.create_instance(cfg)).cuda()

        self.discriminator = DataParallel(
            DiscriminatorFactory.create_instance(cfg)).cuda()

        self.rnn = nn.DataParallel(nn.GRU(cfg.input_dim,
                                          cfg.hidden_dim,
                                          batch_first=False),
                                   dim=1).cuda()

        self.layer_norm = nn.DataParallel(nn.LayerNorm(cfg.hidden_dim)).cuda()

        self.image_encoder = DataParallel(ImageEncoder(cfg)).cuda()

        self.condition_encoder = DataParallel(ConditionEncoder(cfg)).cuda()

        self.sentence_encoder = nn.DataParallel(SentenceEncoder(cfg)).cuda()

        # endregion

        # region Optimizers

        self.generator_optimizer = OPTIM[cfg.generator_optimizer](
            self.generator.parameters(), cfg.generator_lr, cfg.generator_beta1,
            cfg.generator_beta2, cfg.generator_weight_decay)

        self.discriminator_optimizer = OPTIM[cfg.discriminator_optimizer](
            self.discriminator.parameters(), cfg.discriminator_lr,
            cfg.discriminator_beta1, cfg.discriminator_beta2,
            cfg.discriminator_weight_decay)

        self.rnn_optimizer = OPTIM[cfg.rnn_optimizer](self.rnn.parameters(),
                                                      cfg.rnn_lr)

        self.sentence_encoder_optimizer = OPTIM[cfg.gru_optimizer](
            self.sentence_encoder.parameters(), cfg.gru_lr)

        self.use_image_encoder = cfg.use_fg
        feature_encoding_params = list(self.condition_encoder.parameters())
        if self.use_image_encoder:
            feature_encoding_params += list(self.image_encoder.parameters())

        self.feature_encoders_optimizer = OPTIM['adam'](
            feature_encoding_params, cfg.feature_encoder_lr)

        # endregion

        # region Criterion

        self.criterion = LOSSES[cfg.criterion]()
        self.aux_criterion = DataParallel(torch.nn.BCELoss()).cuda()

        # endregion

        self.cfg = cfg
        self.logger = Logger(cfg.log_path, cfg.exp_name)

        # CTR
        self.ctr = CTR(cfg)

    def train_ctr(self,
                  batch,
                  epoch,
                  iteration,
                  visualizer,
                  logger,
                  is_eval=False,
                  is_infer=False):

        batch_size = len(batch['image'])
        max_seq_len = batch['image'].size(1)

        prev_image = torch.FloatTensor(batch['background'])
        prev_image = prev_image.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        disc_prev_image = prev_image

        hidden = torch.zeros(1, batch_size, self.cfg.hidden_dim)

        rec_loss, rec_out = [], []

        for t in range(max_seq_len):
            image = batch['image'][:, t]
            turns_word_embedding = batch['turn_word_embedding'][:, t]
            turns_word = batch['turn_word'][:, t]
            turns_lengths = batch['turn_lengths'][:, t]
            objects = batch['objects'][:, t]
            seq_ended = t > (batch['dialog_length'] - 1)

            # update ctr
            out_ctr, loss_ctr = self.ctr.update_ctr(disc_prev_image, image,
                                                    hidden,
                                                    turns_word_embedding,
                                                    turns_word, is_eval)
            rec_loss.append(loss_ctr.detach().cpu().numpy())
            rec_out.append(out_ctr.unsqueeze(1))

            image_feature_map, image_vec, object_detections = self.image_encoder(
                prev_image)
            _, current_image_feat, _ = self.image_encoder(image)
            turn_embedding = self.sentence_encoder(turns_word_embedding,
                                                   turns_lengths)
            rnn_condition, current_image_feat = self.condition_encoder(
                turn_embedding, image_vec, current_image_feat)
            rnn_condition = rnn_condition.unsqueeze(0)

            output, hidden = self.rnn(rnn_condition, hidden)

            prev_image = image
            disc_prev_image = image

        if iteration % self.cfg.save_rate == 0:
            path = os.path.join(self.cfg.log_path, self.cfg.exp_name)
            self.ctr.save_ctr(path, iteration)

        loss = np.average(rec_loss)

        if is_infer:
            rec_out = torch.cat(rec_out, dim=1)
            rec_out = rec_out.detach().cpu().numpy()

            return rec_out, loss

        else:
            return loss

    def infer_gen(self, batch):

        batch_size = len(batch['image'])
        max_seq_len = batch['image'].size(1)

        prev_image = torch.FloatTensor(batch['background'])
        prev_image = prev_image.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        disc_prev_image = prev_image

        hidden = torch.zeros(1, batch_size, self.cfg.hidden_dim)

        rec_out = []

        for t in range(max_seq_len):
            image = batch['image'][:, t]
            turns_word_embedding = batch['turn_word_embedding'][:, t]
            turns_word = batch['turn_word'][:, t]
            turns_lengths = batch['turn_lengths'][:, t]
            objects = batch['objects'][:, t]
            seq_ended = t > (batch['dialog_length'] - 1)

            image_feature_map, image_vec, object_detections = self.image_encoder(
                prev_image)
            _, current_image_feat, _ = self.image_encoder(image)
            turn_embedding = self.sentence_encoder(turns_word_embedding,
                                                   turns_lengths)
            rnn_condition, current_image_feat = self.condition_encoder(
                turn_embedding, image_vec, current_image_feat)
            rnn_condition = rnn_condition.unsqueeze(0)

            output, hidden = self.rnn(rnn_condition, hidden)
            output = output.squeeze(0)
            output = self.layer_norm(output)

            fake_image, mu, logvar, sigma = self._forward_generator(
                batch_size, output.detach(), image_feature_map)

            prev_image = fake_image
            disc_prev_image = image

            rec_out.append(fake_image.unsqueeze(1))

        rec_out = torch.cat(rec_out, dim=1)
        rec_out = rec_out.detach().cpu().numpy()

        return rec_out

    def train_batch_with_ctr(self, batch, epoch, iteration, visualizer,
                             logger):
        """
        The training scheme follows the following:
            - Discriminator and Generator is updated every time step.
            - RNN, SentenceEncoder and ImageEncoder parameters are
            updated every sequence
        """

        batch_size = len(batch['image'])
        max_seq_len = batch['image'].size(1)

        prev_image = torch.FloatTensor(batch['background'])
        prev_image = prev_image.unsqueeze(0) \
            .repeat(batch_size, 1, 1, 1)
        disc_prev_image = prev_image

        # Initial inputs for the RNN set to zeros
        hidden = torch.zeros(1, batch_size, self.cfg.hidden_dim)
        prev_objects = torch.zeros(batch_size, self.cfg.num_objects)

        teller_images = []
        drawer_images = []
        added_entities = []

        for t in range(max_seq_len):
            image = batch['image'][:, t]
            turns_word_embedding = batch['turn_word_embedding'][:, t]
            turns_word = batch['turn_word'][:, t]
            turns_lengths = batch['turn_lengths'][:, t]
            objects = batch['objects'][:, t]
            seq_ended = t > (batch['dialog_length'] - 1)

            old_hidden = hidden

            image_feature_map, image_vec, object_detections = self.image_encoder(
                prev_image)
            _, current_image_feat, _ = self.image_encoder(image)

            turn_embedding = self.sentence_encoder(turns_word_embedding,
                                                   turns_lengths)
            rnn_condition, current_image_feat = \
                self.condition_encoder(turn_embedding,
                                       image_vec,
                                       current_image_feat)

            rnn_condition = rnn_condition.unsqueeze(0)

            output, hidden = self.rnn(rnn_condition, hidden)

            output = output.squeeze(0)
            output = self.layer_norm(output)

            fake_image, mu, logvar, sigma = self._forward_generator(
                batch_size, output.detach(), image_feature_map)

            _, loss_ctr = self.ctr.update_ctr(disc_prev_image,
                                              fake_image,
                                              old_hidden,
                                              turns_word_embedding,
                                              turns_word,
                                              is_eval=True)

            visualizer.track_sigma(sigma)

            hamming = objects - prev_objects
            hamming = torch.clamp(hamming, min=0)

            d_loss, d_real, d_fake, aux_loss, discriminator_gradient = \
                self._optimize_discriminator(image,
                                             fake_image.detach(),
                                             disc_prev_image,
                                             output,
                                             seq_ended,
                                             hamming,
                                             self.cfg.gp_reg,
                                             self.cfg.aux_reg)

            g_loss, generator_gradient = \
                self._optimize_generator(fake_image,
                                         disc_prev_image.detach(),
                                         output.detach(),
                                         objects,
                                         self.cfg.aux_reg,
                                         seq_ended,
                                         mu,
                                         logvar, loss_ctr=loss_ctr)

            if self.cfg.teacher_forcing:
                prev_image = image
            else:
                prev_image = fake_image

            disc_prev_image = image
            prev_objects = objects

            if (t + 1) % 2 == 0:
                prev_image = prev_image.detach()

            rnn_grads = []
            gru_grads = []
            condition_encoder_grads = []
            img_encoder_grads = []

            if t == max_seq_len - 1:
                rnn_gradient, gru_gradient, condition_gradient,\
                    img_encoder_gradient = self._optimize_rnn()

                rnn_grads.append(rnn_gradient.data.cpu().numpy())
                gru_grads.append(gru_gradient.data.cpu().numpy())
                condition_encoder_grads.append(
                    condition_gradient.data.cpu().numpy())

                if self.use_image_encoder:
                    img_encoder_grads.append(
                        img_encoder_gradient.data.cpu().numpy())

                visualizer.track(d_real, d_fake)

            hamming = hamming.data.cpu().numpy()[0]
            teller_images.extend(image[:4].data.cpu().numpy())
            drawer_images.extend(fake_image[:4].data.cpu().numpy())
            entities = ','.join(list(batch['entities'][hamming > 0]))
            added_entities.append(entities)

        if iteration % self.cfg.vis_rate == 0:
            visualizer.histogram()
            self._plot_losses(visualizer, g_loss, d_loss, aux_loss, iteration)
            rnn_gradient = np.array(rnn_grads).mean()
            gru_gradient = np.array(gru_grads).mean()
            condition_gradient = np.array(condition_encoder_grads).mean()
            img_encoder_gradient = np.array(img_encoder_grads).mean()
            rnn_grads, gru_grads = [], []
            condition_encoder_grads, img_encoder_grads = [], []
            self._plot_gradients(visualizer, rnn_gradient, generator_gradient,
                                 discriminator_gradient, gru_gradient,
                                 condition_gradient, img_encoder_gradient,
                                 iteration)
            self._draw_images(visualizer, teller_images, drawer_images, nrow=4)
            self.logger.write(epoch, iteration, d_real, d_fake, d_loss, g_loss)

            if isinstance(batch['turn'], list):
                batch['turn'] = np.array(batch['turn']).transpose()

            visualizer.write(batch['turn'][0])
            visualizer.write(added_entities, var_name='entities')
            teller_images = []
            drawer_images = []

        if iteration % self.cfg.save_rate == 0:
            path = os.path.join(self.cfg.log_path, self.cfg.exp_name)

            self._save(fake_image[:4], path, epoch, iteration)
            if not self.cfg.debug:
                self.save_model(path, epoch, iteration)

    def train_batch(self, batch, epoch, iteration, visualizer, logger):
        """
        The training scheme follows the following:
            - Discriminator and Generator is updated every time step.
            - RNN, SentenceEncoder and ImageEncoder parameters are
            updated every sequence
        """

        batch_size = len(batch['image'])
        max_seq_len = batch['image'].size(1)

        prev_image = torch.FloatTensor(batch['background'])
        prev_image = prev_image.unsqueeze(0) \
            .repeat(batch_size, 1, 1, 1)
        disc_prev_image = prev_image

        # Initial inputs for the RNN set to zeros
        hidden = torch.zeros(1, batch_size, self.cfg.hidden_dim)
        prev_objects = torch.zeros(batch_size, self.cfg.num_objects)

        teller_images = []
        drawer_images = []
        added_entities = []

        for t in range(max_seq_len):
            image = batch['image'][:, t]
            turns_word_embedding = batch['turn_word_embedding'][:, t]
            turns_word = batch['turn_word'][:, t]
            turns_lengths = batch['turn_lengths'][:, t]
            objects = batch['objects'][:, t]
            seq_ended = t > (batch['dialog_length'] - 1)

            image_feature_map, image_vec, object_detections = self.image_encoder(
                prev_image)
            _, current_image_feat, _ = self.image_encoder(image)

            turn_embedding = self.sentence_encoder(turns_word_embedding,
                                                   turns_lengths)
            rnn_condition, current_image_feat = \
                self.condition_encoder(turn_embedding,
                                       image_vec,
                                       current_image_feat)

            rnn_condition = rnn_condition.unsqueeze(0)

            output, hidden = self.rnn(rnn_condition, hidden)

            output = output.squeeze(0)
            output = self.layer_norm(output)

            fake_image, mu, logvar, sigma = self._forward_generator(
                batch_size, output.detach(), image_feature_map)

            visualizer.track_sigma(sigma)

            hamming = objects - prev_objects
            hamming = torch.clamp(hamming, min=0)

            d_loss, d_real, d_fake, aux_loss, discriminator_gradient = \
                self._optimize_discriminator(image,
                                             fake_image.detach(),
                                             disc_prev_image,
                                             output,
                                             seq_ended,
                                             hamming,
                                             self.cfg.gp_reg,
                                             self.cfg.aux_reg)

            g_loss, generator_gradient = \
                self._optimize_generator(fake_image,
                                         disc_prev_image.detach(),
                                         output.detach(),
                                         objects,
                                         self.cfg.aux_reg,
                                         seq_ended,
                                         mu,
                                         logvar)

            if self.cfg.teacher_forcing:
                prev_image = image
            else:
                prev_image = fake_image

            disc_prev_image = image
            prev_objects = objects

            if (t + 1) % 2 == 0:
                prev_image = prev_image.detach()

            rnn_grads = []
            gru_grads = []
            condition_encoder_grads = []
            img_encoder_grads = []

            if t == max_seq_len - 1:
                rnn_gradient, gru_gradient, condition_gradient,\
                    img_encoder_gradient = self._optimize_rnn()

                rnn_grads.append(rnn_gradient.data.cpu().numpy())
                gru_grads.append(gru_gradient.data.cpu().numpy())
                condition_encoder_grads.append(
                    condition_gradient.data.cpu().numpy())

                if self.use_image_encoder:
                    img_encoder_grads.append(
                        img_encoder_gradient.data.cpu().numpy())

                visualizer.track(d_real, d_fake)

            hamming = hamming.data.cpu().numpy()[0]
            teller_images.extend(image[:4].data.cpu().numpy())
            drawer_images.extend(fake_image[:4].data.cpu().numpy())
            entities = ','.join(list(batch['entities'][hamming > 0]))
            added_entities.append(entities)

        if iteration % self.cfg.vis_rate == 0:
            visualizer.histogram()
            self._plot_losses(visualizer, g_loss, d_loss, aux_loss, iteration)
            rnn_gradient = np.array(rnn_grads).mean()
            gru_gradient = np.array(gru_grads).mean()
            condition_gradient = np.array(condition_encoder_grads).mean()
            img_encoder_gradient = np.array(img_encoder_grads).mean()
            rnn_grads, gru_grads = [], []
            condition_encoder_grads, img_encoder_grads = [], []
            self._plot_gradients(visualizer, rnn_gradient, generator_gradient,
                                 discriminator_gradient, gru_gradient,
                                 condition_gradient, img_encoder_gradient,
                                 iteration)
            self._draw_images(visualizer, teller_images, drawer_images, nrow=4)
            self.logger.write(epoch, iteration, d_real, d_fake, d_loss, g_loss)

            if isinstance(batch['turn'], list):
                batch['turn'] = np.array(batch['turn']).transpose()

            visualizer.write(batch['turn'][0])
            visualizer.write(added_entities, var_name='entities')
            teller_images = []
            drawer_images = []

        if iteration % self.cfg.save_rate == 0:
            path = os.path.join(self.cfg.log_path, self.cfg.exp_name)

            self._save(fake_image[:4], path, epoch, iteration)
            if not self.cfg.debug:
                self.save_model(path, epoch, iteration)

    def _forward_generator(self, batch_size, condition, image_feature_maps):
        noise = torch.FloatTensor(batch_size,
                                  self.cfg.noise_dim).normal_(0, 1).cuda()

        fake_images, mu, logvar, sigma = self.generator(
            noise, condition, image_feature_maps)

        return fake_images, mu, logvar, sigma

    def _optimize_discriminator(self,
                                real_images,
                                fake_images,
                                prev_image,
                                condition,
                                mask,
                                objects,
                                gp_reg=0,
                                aux_reg=0):
        """Discriminator is updated every step independent of batch_size
        RNN and the generator
        """
        wrong_images = torch.cat((real_images[1:], real_images[0:1]), dim=0)
        wrong_prev = torch.cat((prev_image[1:], prev_image[0:1]), dim=0)

        self.discriminator.zero_grad()
        real_images.requires_grad_()

        d_real, aux_real, _ = self.discriminator(real_images, condition,
                                                 prev_image)
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        d_wrong, _, _ = self.discriminator(wrong_images, condition, wrong_prev)

        d_loss, aux_loss = self._discriminator_masked_loss(
            d_real, d_fake, d_wrong, aux_real, aux_fake, objects, aux_reg,
            mask)

        d_loss.backward(retain_graph=True)
        if gp_reg:
            reg = gp_reg * self._masked_gradient_penalty(
                d_real, real_images, mask)
            reg.backward(retain_graph=True)

        grad_norm = _recurrent_gan.get_grad_norm(
            self.discriminator.parameters())
        self.discriminator_optimizer.step()

        d_loss_scalar = d_loss.item()
        d_real_np = d_real.cpu().data.numpy()
        d_fake_np = d_fake.cpu().data.numpy()
        aux_loss_scalar = aux_loss.item() if isinstance(
            aux_loss, torch.Tensor) else aux_loss
        grad_norm_scalar = grad_norm.item()
        del d_loss
        del d_real
        del d_fake
        del aux_loss
        del grad_norm
        gc.collect()

        return d_loss_scalar, d_real_np, d_fake_np, aux_loss_scalar, grad_norm_scalar

    def _optimize_generator(self,
                            fake_images,
                            prev_image,
                            condition,
                            objects,
                            aux_reg,
                            mask,
                            mu,
                            logvar,
                            loss_ctr=None):

        self.generator.zero_grad()
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        g_loss = self._generator_masked_loss(d_fake, aux_fake, objects,
                                             aux_reg, mu, logvar, mask)

        g_loss.backward(retain_graph=True)

        if loss_ctr is not None:
            loss_ctr.backward(retain_graph=True)

            del loss_ctr

        gen_grad_norm = _recurrent_gan.get_grad_norm(
            self.generator.parameters())

        self.generator_optimizer.step()

        g_loss_scalar = g_loss.item()
        gen_grad_norm_scalar = gen_grad_norm.item()

        del g_loss
        del gen_grad_norm
        gc.collect()

        return g_loss_scalar, gen_grad_norm_scalar

    def _optimize_rnn(self):
        torch.nn.utils.clip_grad_norm_(self.rnn.parameters(),
                                       self.cfg.grad_clip)
        rnn_grad_norm = _recurrent_gan.get_grad_norm(self.rnn.parameters())
        self.rnn_optimizer.step()
        self.rnn.zero_grad()

        torch.nn.utils.clip_grad_norm_(self.sentence_encoder.parameters(),
                                       self.cfg.grad_clip)
        gru_grad_norm = _recurrent_gan.get_grad_norm(
            self.sentence_encoder.parameters())
        self.sentence_encoder_optimizer.step()
        self.sentence_encoder.zero_grad()

        ce_grad_norm = _recurrent_gan.get_grad_norm(
            self.condition_encoder.parameters())
        ie_grad_norm = _recurrent_gan.get_grad_norm(
            self.image_encoder.parameters())
        self.feature_encoders_optimizer.step()
        self.condition_encoder.zero_grad()
        self.image_encoder.zero_grad()
        return rnn_grad_norm, gru_grad_norm, ce_grad_norm, ie_grad_norm

    def _discriminator_masked_loss(self, d_real, d_fake, d_wrong, aux_real,
                                   aux_fake, objects, aux_reg, mask):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding"""
        d_loss = []
        aux_losses = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.discriminator(
                    d_real[b], d_fake[b], d_wrong[b],
                    self.cfg.wrong_fake_ratio)
                if aux_reg > 0:
                    aux_loss = aux_reg * (
                        self.aux_criterion(aux_real[b], objects[b]).mean() +
                        self.aux_criterion(aux_fake[b], objects[b]).mean())
                    sample_loss += aux_loss
                    aux_losses.append(aux_loss)

                d_loss.append(sample_loss)

        d_loss = torch.stack(d_loss).mean()

        if len(aux_losses) > 0:
            aux_losses = torch.stack(aux_losses).mean()
        else:
            aux_losses = 0

        return d_loss, aux_losses

    def _generator_masked_loss(self, d_fake, aux_fake, objects, aux_reg, mu,
                               logvar, mask):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding"""
        g_loss = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.generator(d_fake[b])
                if aux_reg > 0:
                    aux_loss = aux_reg * self.aux_criterion(
                        aux_fake[b], objects[b]).mean()
                else:
                    aux_loss = 0
                if mu is not None:
                    kl_loss = self.cfg.cond_kl_reg * kl_penalty(
                        mu[b], logvar[b])
                else:
                    kl_loss = 0

                g_loss.append(sample_loss + aux_loss + kl_loss)

        g_loss = torch.stack(g_loss)
        return g_loss.mean()
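
    # A minimal sketch of the module-level `kl_penalty` helper used above.
    # The formula is an assumption (the standard KL divergence of
    # N(mu, exp(logvar)) from a unit Gaussian), not the actual implementation.
    @staticmethod
    def _kl_penalty_sketch(mu, logvar):
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)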

    def _masked_gradient_penalty(self, d_real, real_images, mask):
        gp_reg = gradient_penalty(d_real, real_images).mean()
        return gp_reg
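
    # A minimal sketch of the module-level `gradient_penalty` helper used
    # above. The body is an assumption in the style of a WGAN-GP penalty on
    # the real samples (which had requires_grad_() set before the forward
    # pass), not the repository's actual implementation.
    @staticmethod
    def _gradient_penalty_sketch(d_out, x_in):
        grads = torch.autograd.grad(outputs=d_out.sum(), inputs=x_in,
                                    create_graph=True, retain_graph=True)[0]
        grads = grads.view(grads.size(0), -1)
        # One penalty value per sample; the caller takes the mean
        return (grads.norm(2, dim=1) - 1) ** 2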

    # region Helpers
    def _plot_losses(self, visualizer, g_loss, d_loss, aux_loss, iteration):
        _recurrent_gan._plot_losses(self, visualizer, g_loss, d_loss, aux_loss,
                                    iteration)

    def _plot_gradients(self, visualizer, rnn, gen, disc, gru, ce, ie,
                        iteration):
        _recurrent_gan._plot_gradients(self, visualizer, rnn, gen, disc, gru,
                                       ce, ie, iteration)

    def _draw_images(self, visualizer, real, fake, nrow):
        _recurrent_gan.draw_images(self, visualizer, real, fake, nrow)

    def _save(self, fake, path, epoch, iteration):
        _recurrent_gan._save(self, fake, path, epoch, iteration)

    def save_model(self, path, epoch, iteration):
        _recurrent_gan.save_model(self, path, epoch, iteration)

    def load_model(self, snapshot_path):
        _recurrent_gan.load_model(self, snapshot_path)
Example #14
class Solver(object):
    def __init__(self, opt):
        self.opt = opt
        self.name = opt.name
        self.output_dir = Path(opt.output_dir) / self.name
        self.preddump_dir = self.output_dir / 'preddump'
        self.preddump_dir.mkdir(parents=True, exist_ok=True)
        self.sample_dir = self.output_dir / 'sample'
        self.sample_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir = self.output_dir / 'tensorboard'
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.ckpt_dir = self.output_dir / 'ckpt'
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)

        self.global_iter = 0
        self.init_loss_functions()
        self.init_colorize()
        self.init_models_optimizers_data()
        self.load_states()

        if not opt.inference:
            self.writer = SummaryWriter(self.log_dir,
                                        purge_step=self.global_iter)

    def init_models_optimizers_data(self):
        opt = self.opt
        device = opt.device
        self.encoder, self.decoder = get_autoencoder(opt)
        self.frame_predictor = DeterministicConvLSTM(opt.g_dim + opt.z_dim,
                                                     opt.g_dim, opt.rnn_size,
                                                     opt.predictor_rnn_layers,
                                                     opt.batch_size, opt.M)
        self.posterior = GaussianConvLSTM(opt.g_dim, opt.z_dim, opt.rnn_size,
                                          opt.posterior_rnn_layers,
                                          opt.batch_size, opt.M)
        self.prior = GaussianConvLSTM(opt.g_dim, opt.z_dim, opt.rnn_size,
                                      opt.prior_rnn_layers, opt.batch_size,
                                      opt.M)
        if not opt.deepspeed:
            self.encoder = self.encoder.to(device)
            self.decoder = self.decoder.to(device)
            self.frame_predictor = self.frame_predictor.to(device)
            self.posterior = self.posterior.to(device)
            self.prior = self.prior.to(device)

        self.frame_predictor_optimizer = optim.Adam(
            self.frame_predictor.parameters(),
            lr=opt.lr,
            betas=(opt.beta1, 0.999))
        self.posterior_optimizer = optim.Adam(self.posterior.parameters(),
                                              lr=opt.lr,
                                              betas=(opt.beta1, 0.999))
        self.prior_optimizer = optim.Adam(self.prior.parameters(),
                                          lr=opt.lr,
                                          betas=(opt.beta1, 0.999))
        self.encoder_optimizer = optim.Adam(self.encoder.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))
        self.decoder_optimizer = optim.Adam(self.decoder.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, 0.999))

        self.frame_predictor.apply(init_weights)
        self.posterior.apply(init_weights)
        self.prior.apply(init_weights)
        self.encoder.apply(init_weights)
        self.decoder.apply(init_weights)

        encoder_params = filter(lambda p: p.requires_grad,
                                self.encoder.parameters())
        decoder_params = filter(lambda p: p.requires_grad,
                                self.decoder.parameters())
        frame_predictor_params = filter(lambda p: p.requires_grad,
                                        self.frame_predictor.parameters())
        posterior_params = filter(lambda p: p.requires_grad,
                                  self.posterior.parameters())
        prior_params = filter(lambda p: p.requires_grad,
                              self.prior.parameters())

        if opt.load_dp_ckpt:
            self.load_dp_ckpt()
        if opt.load_ds_ckpt:
            self.load_ds_ckpt()

        train_data, test_data = load_dataset(opt)
        if opt.inference:
            # use pytorch loaders for both train/test loader in inference mode
            train_loader = DataLoader(train_data,
                                      num_workers=opt.data_threads,
                                      batch_size=opt.batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      pin_memory=True)
            test_loader = DataLoader(test_data,
                                     num_workers=opt.data_threads,
                                     batch_size=1,
                                     shuffle=False,
                                     drop_last=False,
                                     pin_memory=True)
        elif not opt.inference and not opt.deepspeed:
            # use pytorch loaders for both train/test loader when not using deepspeed
            train_loader = DataLoader(train_data,
                                      num_workers=opt.data_threads,
                                      batch_size=opt.batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      pin_memory=True)
            test_loader = DataLoader(test_data,
                                     num_workers=opt.data_threads,
                                     batch_size=opt.batch_size,
                                     shuffle=True,
                                     drop_last=True,
                                     pin_memory=True)
        elif not opt.inference and opt.deepspeed:
            # use deepspeed train loader when training with deepspeed.
            # use pytorch test loader when testing
            test_loader = DataLoader(test_data,
                                     num_workers=opt.data_threads,
                                     batch_size=opt.batch_size,
                                     shuffle=True,
                                     drop_last=True,
                                     pin_memory=True)

        if opt.deepspeed:
            if not opt.inference:
                self.encoder, self.encoder_optimizer, train_loader, _ = ds.initialize(
                    opt,
                    model=self.encoder,
                    model_parameters=encoder_params,
                    dist_init_required=True,
                    training_data=train_data)
            else:
                self.encoder, self.encoder_optimizer, _, _ = ds.initialize(
                    opt,
                    model=self.encoder,
                    model_parameters=encoder_params,
                    dist_init_required=True)
            self.decoder, self.decoder_optimizer, _, _ = ds.initialize(
                opt,
                model=self.decoder,
                model_parameters=decoder_params,
                dist_init_required=False)
            self.frame_predictor, self.frame_predictor_optimizer, _, _ = ds.initialize(
                opt,
                model=self.frame_predictor,
                model_parameters=frame_predictor_params,
                dist_init_required=False)
            self.posterior, self.posterior_optimizer, _, _ = ds.initialize(
                opt,
                model=self.posterior,
                model_parameters=posterior_params,
                dist_init_required=False)
            self.prior, self.prior_optimizer, _, _ = ds.initialize(
                opt,
                model=self.prior,
                model_parameters=prior_params,
                dist_init_required=False)

            def normalize_data_ds(opt, sequence):
                data, data_path = sequence
                data.transpose_(0, 1)
                return data.to(self.encoder.local_rank), data_path

            def get_batch(loader):
                while True:
                    for sequence in loader:
                        batch = normalize_data_ds(opt, sequence)
                        yield batch

            def get_dump_batch(loader):
                for sequence in loader:
                    batch = normalize_data_ds(opt, sequence)
                    yield batch
        else:
            self.encoder = DataParallel(self.encoder)
            self.decoder = DataParallel(self.decoder)
            self.frame_predictor = DataParallel(self.frame_predictor)
            self.posterior = DataParallel(self.posterior)
            self.prior = DataParallel(self.prior)

            if opt.device == 'cuda':
                dtype = torch.cuda.FloatTensor
            else:
                dtype = torch.FloatTensor

            def get_batch(loader):
                while True:
                    for sequence in loader:
                        batch = normalize_data_dp(opt, dtype, sequence)
                        yield batch

            def get_dump_batch(loader):
                for sequence in loader:
                    batch = normalize_data_dp(opt, dtype, sequence)
                    yield batch

        self.training_batch_generator = get_batch(train_loader)
        if opt.inference:
            self.testing_batch_generator = get_dump_batch(test_loader)
        else:
            self.testing_batch_generator = get_batch(test_loader)
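
    # A minimal sketch of the module-level `normalize_data_dp` helper used
    # above. The body is an assumption, mirroring the `normalize_data_ds`
    # closure defined earlier but casting to the chosen dtype instead of
    # moving to a deepspeed local rank; the actual helper may differ.
    @staticmethod
    def _normalize_data_dp_sketch(opt, dtype, sequence):
        data, data_path = sequence
        data.transpose_(0, 1)  # [B, T, ...] -> [T, B, ...] (time-major)
        return data.type(dtype), data_path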

    def init_colorize(self):
        if self.opt.dataset in [
                'KITTI_64', 'KITTI_128', 'KITTI_256', 'Cityscapes_128x256'
        ]:
            self.opt.n_class = n_class = 19
            self.pallette = return_colormap('KITTI').byte().numpy().reshape(
                -1).tolist()
            self.colorize = Colorize(n_class, return_colormap('KITTI'))
        elif self.opt.dataset in ['Pose_64', 'Pose_128']:
            self.opt.n_class = n_class = 25
            self.pallette = return_colormap(
                N=25).byte().numpy().reshape(-1).tolist()
            self.colorize = Colorize(n_class, return_colormap(N=25))
        else:
            raise ValueError()

    def load_states(self, idx=None):
        if self.opt.deepspeed and not self.opt.load_dp_ckpt:
            if idx is None:
                idx = 'last'
            savedir = self.ckpt_dir / str(idx)
            if savedir is not None:
                try:
                    _, _ = self.encoder.load_checkpoint(savedir, 'encoder')
                    _, _ = self.decoder.load_checkpoint(savedir, 'decoder')
                    _, _ = self.frame_predictor.load_checkpoint(
                        savedir, 'frame_predictor')
                    _, _ = self.posterior.load_checkpoint(savedir, 'posterior')
                    _, client_state = self.prior.load_checkpoint(
                        savedir, 'prior')
                    self.global_iter = client_state['step']
                except Exception:
                    printstr = 'ckpt is not found at: %s' % savedir
                    print_rank_0(printstr)
                    return
                else:
                    printstr = 'ckpt is loaded from: %s' % savedir
                    print_rank_0(printstr)

        if not self.opt.deepspeed and not self.opt.load_ds_ckpt:
            idx = 'last.pth' if idx is None else '%d.pth' % idx
            path = self.ckpt_dir / idx
            try:
                ckpt = torch.load(path)

                self.global_iter = ckpt['global_iter']

                self.frame_predictor.load_state_dict(ckpt['frame_predictor'])
                self.posterior.load_state_dict(ckpt['posterior'])
                self.prior.load_state_dict(ckpt['prior'])
                self.encoder.load_state_dict(ckpt['encoder'])
                self.decoder.load_state_dict(ckpt['decoder'])
            except Exception:
                printstr = 'failed to load ckpt from: %s' % path
                print(printstr)
            else:
                printstr = 'ckpt is loaded from: %s' % path
                print(printstr)

    def dump_states(self, idx=None):
        if self.opt.deepspeed:
            if idx is None:
                idx = 'last'
            savedir = self.ckpt_dir / str(idx)
            client_state = {'step': self.global_iter, 'opt': self.opt}
            self.encoder.save_checkpoint(savedir, 'encoder', client_state)
            self.decoder.save_checkpoint(savedir, 'decoder', client_state)
            self.frame_predictor.save_checkpoint(savedir, 'frame_predictor',
                                                 client_state)
            self.posterior.save_checkpoint(savedir, 'posterior', client_state)
            self.prior.save_checkpoint(savedir, 'prior', client_state)
        else:
            torch.save(
                {
                    'global_iter': self.global_iter,
                    'encoder': self.encoder.state_dict(),
                    'encoder_optimizer': self.encoder_optimizer.state_dict(),
                    'decoder': self.decoder.state_dict(),
                    'decoder_optimizer': self.decoder_optimizer.state_dict(),
                    'frame_predictor': self.frame_predictor.state_dict(),
                    'frame_predictor_optimizer':
                    self.frame_predictor_optimizer.state_dict(),
                    'posterior': self.posterior.state_dict(),
                    'posterior_optimizer':
                    self.posterior_optimizer.state_dict(),
                    'prior': self.prior.state_dict(),
                    'prior_optimizer': self.prior_optimizer.state_dict(),
                    'opt': self.opt
                }, '%s/%s.pth' % (self.ckpt_dir, idx))

    def load_dp_ckpt(self, idx=None):
        idx = 'last.pth' if idx is None else '%d.pth' % idx
        path = self.ckpt_dir / idx
        try:
            ckpt = torch.load(path)
        except FileNotFoundError as e:
            print(e)
        else:
            self.global_iter = ckpt['global_iter']

            self.encoder = DataParallel(self.encoder)
            self.decoder = DataParallel(self.decoder)
            self.frame_predictor = DataParallel(self.frame_predictor)
            self.posterior = DataParallel(self.posterior)
            self.prior = DataParallel(self.prior)

            self.frame_predictor.load_state_dict(ckpt['frame_predictor'])
            self.posterior.load_state_dict(ckpt['posterior'])
            self.prior.load_state_dict(ckpt['prior'])
            self.encoder.load_state_dict(ckpt['encoder'])
            self.decoder.load_state_dict(ckpt['decoder'])

            self.encoder = self.encoder.module
            self.decoder = self.decoder.module
            self.frame_predictor = self.frame_predictor.module
            self.posterior = self.posterior.module
            self.prior = self.prior.module

            printstr = 'ckpt is loaded from: %s' % path
            print(printstr)

    def load_ds_ckpt(self, idx=None):
        idx = 'last' if idx is None else str(idx)
        path = str(self.ckpt_dir / idx / '%s/mp_rank_00_model_states.pt')

        try:
            encoder_ckpt = torch.load(path % 'encoder')
            decoder_ckpt = torch.load(path % 'decoder')
            frame_predictor_ckpt = torch.load(path % 'frame_predictor')
            posterior_ckpt = torch.load(path % 'posterior')
            prior_ckpt = torch.load(path % 'prior')
        except FileNotFoundError as e:
            print(e)
        else:
            self.encoder.load_state_dict(encoder_ckpt['module'])
            self.decoder.load_state_dict(decoder_ckpt['module'])
            self.frame_predictor.load_state_dict(
                frame_predictor_ckpt['module'])
            self.posterior.load_state_dict(posterior_ckpt['module'])
            self.prior.load_state_dict(prior_ckpt['module'])

            self.encoder_optimizer.load_state_dict(encoder_ckpt['optimizer'])
            self.decoder_optimizer.load_state_dict(decoder_ckpt['optimizer'])
            self.frame_predictor_optimizer.load_state_dict(
                frame_predictor_ckpt['optimizer'])
            self.posterior_optimizer.load_state_dict(
                posterior_ckpt['optimizer'])
            self.prior_optimizer.load_state_dict(prior_ckpt['optimizer'])
            self.global_iter = encoder_ckpt['step']
            printstr = 'ckpt is loaded from: %s' % path
            print(printstr)

    def init_loss_functions(self):
        self.kl_criterion = kl_criterion
        self.nll = nn.NLLLoss()
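
    # A minimal sketch of the module-level `kl_criterion` assigned above.
    # The formula is an assumption (the closed-form KL divergence between
    # two diagonal Gaussians, KL(N(mu1, var1) || N(mu2, var2)), summed over
    # latent dimensions and averaged over the batch); the repository's exact
    # reduction may differ.
    @staticmethod
    def _kl_criterion_sketch(mu1, logvar1, mu2, logvar2):
        kld = 0.5 * (logvar2 - logvar1 +
                     (logvar1.exp() + (mu1 - mu2).pow(2)) / logvar2.exp() - 1)
        return kld.sum(dim=-1).mean()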

    def train(self, x):
        self.encoder.zero_grad()
        self.decoder.zero_grad()
        self.frame_predictor.zero_grad()
        self.posterior.zero_grad()
        self.prior.zero_grad()

        kld = 0
        nll = 0
        prior_hidden = None
        posterior_hidden = None
        frame_predictor_hidden = None
        for i in range(1, self.opt.n_past + self.opt.n_future):
            x_in = x[i - 1]
            x_target = x[i]

            h = self.encoder(x_in)
            h_target = self.encoder(x_target)[0]

            if self.opt.last_frame_skip or i < self.opt.n_past + 1:
                h, skip = h
            else:
                h = h[0]

            z_t, mu, logvar, posterior_hidden = self.posterior(
                h_target, posterior_hidden)
            _, mu_p, logvar_p, prior_hidden = self.prior(h, prior_hidden)
            h_pred, frame_predictor_hidden = self.frame_predictor(
                torch.cat([h, z_t], 1), frame_predictor_hidden)

            x_pred = self.decoder([h_pred, skip])
            nll += self.nll(x_pred, x_target.squeeze(1).long())
            kld += self.kl_criterion(mu, logvar, mu_p, logvar_p)

        loss = nll + kld * self.opt.beta
        loss.backward()

        self.encoder_optimizer.step()
        self.decoder_optimizer.step()
        self.frame_predictor_optimizer.step()
        self.posterior_optimizer.step()
        self.prior_optimizer.step()

        output = dict()
        normalizer = self.opt.n_past + self.opt.n_future
        output['nll'] = nll.item() / normalizer
        output['kld'] = kld.item() / normalizer

        return output

    @torch.no_grad()
    def validate(self, x):
        kld = 0
        nll = 0
        prior_hidden = None
        posterior_hidden = None
        frame_predictor_hidden = None
        for i in range(1, self.opt.n_past + self.opt.n_future):
            x_in = x[i - 1]
            x_target = x[i]

            h = self.encoder(x_in)
            h_target = self.encoder(x_target)[0]

            if self.opt.last_frame_skip or i < self.opt.n_past + 1:
                h, skip = h
            else:
                h = h[0]

            z_t, mu, logvar, posterior_hidden = self.posterior(
                h_target, posterior_hidden)
            _, mu_p, logvar_p, prior_hidden = self.prior(h, prior_hidden)
            h_pred, frame_predictor_hidden = self.frame_predictor(
                torch.cat([h, z_t], 1), frame_predictor_hidden)

            x_pred = self.decoder([h_pred, skip])
            nll += self.nll(x_pred, x_target.squeeze(1).long())
            kld += self.kl_criterion(mu, logvar, mu_p, logvar_p)

        output = dict()
        normalizer = self.opt.n_past + self.opt.n_future
        output['nll'] = nll.item() / normalizer
        output['kld'] = kld.item() / normalizer

        return output

    def solve(self):
        pbar = tqdm(range(self.global_iter, self.opt.max_iter))
        start_time = time.time()
        for _ in pbar:
            self.global_iter += 1

            self.frame_predictor.train()
            self.posterior.train()
            self.prior.train()
            self.encoder.train()
            self.decoder.train()

            x, _ = next(self.training_batch_generator)

            # train
            output = self.train(x)
            nll = output['nll']
            kld = output['kld']

            if self.global_iter % self.opt.log_ckpt_iter == 0:
                # periodic checkpoint: one tagged by iteration plus a
                # rolling 'last' copy
                self.dump_states(self.global_iter)
                self.dump_states('last')

            if time.time() - start_time > self.opt.log_ckpt_sec:
                # time-based refresh of the rolling 'last' checkpoint
                self.dump_states('last')
                start_time = time.time()

            if self.global_iter % self.opt.print_iter == 0:
                printstr = '[%02d] nll: %.5f | kld: %.5f' % (
                    self.global_iter,
                    nll,
                    kld,
                )
                pbar.set_description(printstr)

            if self.global_iter % self.opt.log_line_iter == 0:
                self.writer.add_scalar('train_nll',
                                       nll,
                                       global_step=self.global_iter)
                self.writer.add_scalar('train_kld',
                                       kld,
                                       global_step=self.global_iter)

            if self.global_iter % self.opt.log_img_iter == 0:
                # render qualitative predictions (only on rank 0 when
                # running distributed)
                self.frame_predictor.eval()
                self.posterior.eval()
                self.prior.eval()
                self.encoder.eval()
                self.decoder.eval()

                x, _ = next(self.testing_batch_generator)
                if torch.distributed.is_initialized():
                    if torch.distributed.get_rank() == 0:
                        plot(x, self)
                else:
                    plot(x, self)

            if self.global_iter % self.opt.validate_iter == 0:
                nll = 0
                kld = 0
                nvalsample = 0
                for _ in range(100):  # a fixed budget of held-out batches
                    x, _ = next(self.testing_batch_generator)
                    output = self.validate(x)
                    nll += output['nll']
                    kld += output['kld']
                    nvalsample += x[0].size(0)

                # average the accumulated per-batch mean losses over the
                # total number of validation samples seen
                nll /= nvalsample
                kld /= nvalsample
                self.writer.add_scalar('test_nll',
                                       nll,
                                       global_step=self.global_iter)
                self.writer.add_scalar('test_kld',
                                       kld,
                                       global_step=self.global_iter)
        pbar.close()

    @torch.no_grad()
    def inference(self):
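        """Unroll the model on test sequences and dump the ground truth plus
        `n_prediction` sampled futures as paletted PNG segmentation maps."""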
        topil = transforms.ToPILImage()

        n_prediction = self.opt.n_prediction

        self.frame_predictor.eval()
        self.posterior.eval()
        self.prior.eval()
        self.encoder.eval()
        self.decoder.eval()
        for batch_idx, (x_seqs, paths) in tqdm(
                enumerate(self.testing_batch_generator)):

            # Pad the sequence when the unrolling horizon extends beyond the
            # available ground-truth frames: repeat the last frame and
            # synthesize a correspondingly advanced file path.
            for _ in range(self.opt.n_past + self.opt.n_eval - len(x_seqs)):
                x_seqs.append(x_seqs[-1])
                path_parts = paths[-1][0].split('/')
                name = path_parts[-1]
                if 'KITTI' in self.opt.dataset:
                    # names look like '<sequence>_<frame>.png'
                    stem = name[:-len('.png')]
                    newname = '%s_%010d.png' % (
                        '_'.join(stem.split('_')[:-1]),
                        int(stem.split('_')[-1]) +
                        self.opt.frame_sampling_rate)
                elif 'Cityscapes' in self.opt.dataset:
                    # the frame index is the second-to-last '_' field
                    new_idx = '%06d' % (int(name.split('_')[-2]) +
                                        self.opt.frame_sampling_rate)
                    parts = name.split('_')
                    parts[-2] = new_idx
                    newname = '_'.join(parts)
                elif 'Pose' in self.opt.dataset:
                    # names look like 'frame<idx>_IUV.png'
                    stem = name[:-len('.png')]
                    newname = 'frame%06d_IUV.png' % (
                        int(stem[len('frame'):].split('_')[0]) +
                        self.opt.frame_sampling_rate)
                newpath = ['/'.join(path_parts[:-1] + [newname])]
                paths.append(newpath)

            x_pred_seqs = []
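            # draw n_prediction independent futures; each sample repeats the
            # full conditioning-then-prediction unroll from the first frame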
            for s in range(n_prediction):
                skip = None
                prior_hidden = None
                posterior_hidden = None
                frame_predictor_hidden = None
                x_in = x_seqs[0]
                x_pred_seq = [x_in.data.cpu().byte()]
                for i in range(1, self.opt.n_past + self.opt.n_eval):

                    h = self.encoder(x_in)
                    if self.opt.last_frame_skip or i < self.opt.n_past + 1:
                        h, skip = h
                    else:
                        h = h[0]

                    # conditioning phase: ground-truth frames drive the
                    # recurrent states
                    if i < self.opt.n_past:
                        x_target = x_seqs[i]
                        h_target = self.encoder(x_target)[0]
                        z_t, _, _, posterior_hidden = self.posterior(
                            h_target, posterior_hidden)
                        _, _, _, prior_hidden = self.prior(h, prior_hidden)
                        _, frame_predictor_hidden = self.frame_predictor(
                            torch.cat([h, z_t], 1), frame_predictor_hidden)
                        x_in = x_target
                    else:
                        # prediction phase: sample z from the learned prior
                        # and feed the argmax segmentation back in as the
                        # next input
                        z_t, _, _, prior_hidden = self.prior(h, prior_hidden)
                        h_pred, frame_predictor_hidden = self.frame_predictor(
                            torch.cat([h, z_t], 1), frame_predictor_hidden)
                        x_in = self.decoder([h_pred,
                                             skip]).argmax(dim=1, keepdim=True)

                    x_pred_seq.append(x_in.data.cpu().byte())

                x_pred_seq = torch.stack(x_pred_seq, dim=1)
                x_pred_seqs.append(x_pred_seq)

            x_seqs = torch.cat(
                x_seqs).data.cpu().byte()  # (n_past+n_eval, 1, H, W)
            x_pred_seqs = torch.cat(x_pred_seqs).transpose(
                0, 1)  # (n_past+n_eval, n_prediction, 1, H, W)

            # zipping frames against the per-timestep path list assumes a
            # test batch size of 1
            for x_gt, x_preds, path in zip(x_seqs, x_pred_seqs, paths):
                path = Path(path[0])
                # sample_00000 holds the ground truth; KITTI and Cityscapes
                # share the same output layout
                if ('KITTI' in self.opt.dataset
                        or 'Cityscapes' in self.opt.dataset):
                    maskpath = self.preddump_dir.joinpath(
                        str(self.global_iter), 'batch_%05d' % (batch_idx + 1),
                        'sample_%05d' % 0, Path(path.parts[3], path.name))
                elif 'Pose' in self.opt.dataset:
                    vidname, clipname = path.parts[-3:-1]
                    maskpath = self.preddump_dir.joinpath(
                        str(self.global_iter), 'batch_%05d' % (batch_idx + 1),
                        'sample_%05d' % (0), vidname + '_' + clipname,
                        path.name)

                maskpath.parent.mkdir(exist_ok=True, parents=True)
                x_gt = topil(x_gt).convert('P', colors=self.opt.n_class)
                x_gt.putpalette(self.pallette)
                x_gt.save(maskpath)

                for num_x_pred, x_pred in enumerate(x_preds):
                    if ('KITTI' in self.opt.dataset
                            or 'Cityscapes' in self.opt.dataset):
                        maskpath = self.preddump_dir.joinpath(
                            str(self.global_iter),
                            'batch_%05d' % (batch_idx + 1),
                            'sample_%05d' % (num_x_pred + 1),
                            Path(path.parts[3], path.name))
                    elif 'Pose' in self.opt.dataset:
                        maskpath = self.preddump_dir.joinpath(
                            str(self.global_iter),
                            'batch_%05d' % (batch_idx + 1),
                            'sample_%05d' % (num_x_pred + 1),
                            vidname + '_' + clipname, path.name)
                    maskpath.parent.mkdir(exist_ok=True, parents=True)
                    x_pred = topil(x_pred).convert('P',
                                                   colors=self.opt.n_class)
                    x_pred.putpalette(self.pallette)
                    x_pred.save(maskpath)