Example #1
 def train(self, epoch_idx, batch_size, max_norm):
     logger, model = self.logger, self.model
     logger.info("At %d-th epoch with lr %f.", epoch_idx, self.get_lr())
     model.train()
     sampler, nb_batch = self.iterate_batch(TRAIN, batch_size)
     losses, cnt = 0, 0
     for batch in tqdm(sampler(batch_size), total=nb_batch):
         loss = model.get_loss(batch)
         self.optimizer.zero_grad()
         loss.backward()
         if max_norm > 0:
             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
         logger.debug(
             "loss %f with total grad norm %f",
             loss,
             util.grad_norm(model.parameters()),
         )
         self.optimizer.step()
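         # ReduceLROnPlateau is stepped once per epoch on a validation metric,
         # so only per-batch schedulers are stepped here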
         if not isinstance(self.scheduler, ReduceLROnPlateau):
             self.scheduler.step()
         self.global_steps += 1
         losses += loss.item()
         cnt += 1
     avg_loss = losses / cnt
     logger.info("Running average train loss is %f at epoch %d", avg_loss, epoch_idx)
     return avg_loss
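
The isinstance check above matters: ReduceLROnPlateau is designed to be stepped once per epoch on a validation metric, while per-batch schedulers (warmup, cosine, etc.) are stepped inside the loop. A minimal sketch of the surrounding epoch driver under that convention; fit and trainer.validate are hypothetical names, not this repository's API:

    from torch.optim.lr_scheduler import ReduceLROnPlateau

    def fit(trainer, nb_epochs, batch_size, max_norm):
        for epoch_idx in range(nb_epochs):
            trainer.train(epoch_idx, batch_size, max_norm)
            dev_loss = trainer.validate(batch_size)  # hypothetical validation pass
            if isinstance(trainer.scheduler, ReduceLROnPlateau):
                # the plateau scheduler consumes the epoch-level metric here
                trainer.scheduler.step(dev_loss)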
Example #2
 def train(self, epoch_idx, batch_size, max_norm):
     logger, model, data = self.logger, self.model, self.data
     logger.info('At %d-th epoch with lr %f.', epoch_idx,
                 self.optimizer.param_groups[0]['lr'])
     model.train()
     nb_train_batch = ceil(data.nb_train / batch_size)
     for src, src_mask, trg, _ in tqdm(data.train_batch_sample(batch_size),
                                       total=nb_train_batch):
         out = model(src, src_mask, trg)
         loss = model.loss(out, trg[1:])
         self.optimizer.zero_grad()
         loss.backward()
         if max_norm > 0:
             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
         logger.debug('loss %f with total grad norm %f', loss,
                      util.grad_norm(model.parameters()))
         self.optimizer.step()
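
Here the decoder consumes the full target sequence while the loss is taken against trg[1:]: the prediction at step t is scored against the token at position t+1 (teacher forcing). A minimal sketch of such a shifted loss, assuming time-major logits of shape (length, batch, vocab) and targets of shape (length, batch); the names are illustrative, not this repository's API:

    import torch.nn.functional as F

    def shifted_nll(logits, trg, pad_idx):
        # logits[t] predicts trg[t + 1]: drop the last prediction and the first token
        flat_logits = logits[:-1].reshape(-1, logits.size(-1))
        gold = trg[1:].reshape(-1)
        return F.cross_entropy(flat_logits, gold, ignore_index=pad_idx)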
Example #3
    def train(self,
              dataset,
              epoch,
              batch_size=GAN_BATCH,
              test_sample_every=5,
              hist_avg=False,
              evaluator=None):

        start_time = time.strftime("%Y-%m-%d-%H-%M", time.gmtime())
        os.makedirs(f'{OUT_DIR}/{start_time}')

        if hist_avg:
            # running (moving-average) copy of the generator weights
            avg_g_params = [p.data.clone() for p in self.gen.parameters()]

        loader_config = {
            'batch_size': batch_size,
            'shuffle': True,
            'drop_last': True,
            'collate_fn': dataset.collate_fn
        }

        train_loader = DataLoader(dataset.train, **loader_config)

        metrics = {
            'IS': [],
            'FID': [],
            'loss': {
                'g': [],
                'd': []
            },
            'accuracy': {
                'real': [],
                'fake': [],
                'mismatched': [],
                'unconditional_real': [],
                'unconditional_fake': []
            }
        }

        if evaluator is not None:
            evaluator = evaluator(dataset, self.damsm.img_enc.inception_model,
                                  batch_size, self.device)

        # reusable noise buffer for the generator input
        noise = torch.empty(batch_size, D_Z, device=self.device)
        gen_updates = 0

        for e in tqdm(range(epoch), desc='Epochs', dynamic_ncols=True):
            self.gen.train()
            self.disc.train()
            g_loss = 0
            w_loss = 0
            s_loss = 0
            kl_loss = 0
            g_stage_loss = np.zeros(3, dtype=float)
            d_loss = np.zeros(3, dtype=float)
            real_acc = np.zeros(3, dtype=float)
            fake_acc = np.zeros(3, dtype=float)
            mismatched_acc = np.zeros(3, dtype=float)
            uncond_real_acc = np.zeros(3, dtype=float)
            uncond_fake_acc = np.zeros(3, dtype=float)
            disc_skips = np.zeros(3, dtype=int)

            train_pbar = tqdm(train_loader,
                              desc='Training',
                              leave=False,
                              dynamic_ncols=True)
            for batch in train_pbar:
                real_imgs = [batch['img64'], batch['img128'], batch['img256']]

                with torch.no_grad():
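                    # caption embeddings are computed without tracking gradients,
                    # so the DAMSM text encoder is not updated by this loop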
                    word_embs, sent_embs = self.damsm.txt_enc(batch['caption'])
                attn_mask = torch.tensor(batch['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]

                # Generate images
                noise.normal_(0, 1)
                generated, att, mu, logvar = self.gen(noise, sent_embs,
                                                      word_embs, attn_mask)

                # Discriminator loss (with label smoothing)
                (batch_d_loss, batch_real_acc, batch_fake_acc,
                 batch_mismatched_acc, batch_uncond_real_acc,
                 batch_uncond_fake_acc, batch_disc_skips) = self.discriminator_step(
                     real_imgs, generated, sent_embs, 0.1)

                # per-discriminator gradient norms, reported in the progress bar
                d_grad_norm = [grad_norm(d) for d in self.discriminators]

                d_loss += batch_d_loss
                real_acc += batch_real_acc
                fake_acc += batch_fake_acc
                mismatched_acc += batch_mismatched_acc
                uncond_real_acc += batch_uncond_real_acc
                uncond_fake_acc += batch_uncond_fake_acc
                disc_skips += batch_disc_skips

                # Generator loss
                batch_g_losses = self.generator_step(generated, word_embs,
                                                     sent_embs, mu, logvar,
                                                     batch['label'])
                g_total, batch_g_stage_loss, batch_w_loss, batch_s_loss, batch_kl_loss = batch_g_losses
                g_stage_loss += batch_g_stage_loss
                w_loss += batch_w_loss
                s_loss += batch_s_loss
                kl_loss += batch_kl_loss
                gen_updates += 1

                # accumulate the per-sample generator loss for this batch
                g_loss += g_total.item() / batch_size

                if hist_avg:
                    for p, avg_p in zip(self.gen.parameters(), avg_g_params):
                        avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                    if gen_updates % 1000 == 0:
                        tqdm.write(
                            'Replacing generator weights with their moving average'
                        )
                        for p, avg_p in zip(self.gen.parameters(),
                                            avg_g_params):
                            p.data.copy_(avg_p)

                train_pbar.set_description(
                    f'Training (G: {grad_norm(self.gen):.2f}  '
                    f'D64: {d_grad_norm[0]:.2f}  '
                    f'D128: {d_grad_norm[1]:.2f}  '
                    f'D256: {d_grad_norm[2]:.2f})')

            batches = len(train_loader)

            g_loss /= batches
            g_stage_loss /= batches
            w_loss /= batches
            s_loss /= batches
            kl_loss /= batches
            d_loss /= batches
            real_acc /= batches
            fake_acc /= batches
            mismatched_acc /= batches
            uncond_real_acc /= batches
            uncond_fake_acc /= batches

            metrics['loss']['g'].append(g_loss)
            metrics['loss']['d'].append(d_loss)
            metrics['accuracy']['real'].append(real_acc)
            metrics['accuracy']['fake'].append(fake_acc)
            metrics['accuracy']['mismatched'].append(mismatched_acc)
            metrics['accuracy']['unconditional_real'].append(uncond_real_acc)
            metrics['accuracy']['unconditional_fake'].append(uncond_fake_acc)

            sep = '_' * 10
            tqdm.write(f'{sep}Epoch {e}{sep}')

            if e % test_sample_every == 0:
                self.gen.eval()
                generated_samples = [
                    resolution.unsqueeze(0)
                    for resolution in self.sample_test_set(dataset)
                ]
                self._save_generated(generated_samples, e,
                                     f'{OUT_DIR}/{start_time}')

                if evaluator is not None:
                    scores = evaluator.evaluate(self)
                    for k, v in scores.items():
                        metrics[k].append(v)
                        tqdm.write(f'{k}: {v:.2f}')

            tqdm.write(
                f'Generator avg loss: total({g_loss:.3f})  '
                f'stage0({g_stage_loss[0]:.3f})  stage1({g_stage_loss[1]:.3f})  stage2({g_stage_loss[2]:.3f})  '
                f'w({w_loss:.3f})  s({s_loss:.3f})  kl({kl_loss:.3f})')

            for i, _ in enumerate(self.discriminators):
                tqdm.write(f'Discriminator{i} avg: '
                           f'loss({d_loss[i]:.3f})  '
                           f'r-acc({real_acc[i]:.3f})  '
                           f'f-acc({fake_acc[i]:.3f})  '
                           f'm-acc({mismatched_acc[i]:.3f})  '
                           f'ur-acc({uncond_real_acc[i]:.3f})  '
                           f'uf-acc({uncond_fake_acc[i]:.3f})  '
                           f'skips({disc_skips[i]})')

        return metrics
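
The hist_avg branch keeps an exponential moving average (EMA) of the generator weights: every update blends 0.1% of the live parameters into a running copy, and every 1000 generator updates the live weights are overwritten with that average. The same idea as a small self-contained helper; a sketch under those assumptions, not this repository's API:

    import torch

    class GeneratorEMA:
        def __init__(self, module, decay=0.999):
            self.decay = decay
            # detached copies of the current parameter values
            self.shadow = [p.detach().clone() for p in module.parameters()]

        @torch.no_grad()
        def update(self, module):
            for avg_p, p in zip(self.shadow, module.parameters()):
                avg_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

        @torch.no_grad()
        def copy_to(self, module):
            # replace the live weights with their moving average
            for avg_p, p in zip(self.shadow, module.parameters()):
                p.copy_(avg_p)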