Example #1
	def train_latent(self, imgs, classes, model_dir, tensorboard_dir, retrain=False):

		# Pack images (NHWC -> NCHW) and their ids into named tensors.
		data = dict(
			img=torch.from_numpy(imgs).permute(0, 3, 1, 2),
			img_id=torch.from_numpy(np.arange(imgs.shape[0])),
			class_id=torch.from_numpy(classes.astype(np.int64))
		)

		dataset = NamedTensorDataset(data)
		data_loader = DataLoader(
			dataset, batch_size=self.config['train']['batch_size'],
			shuffle=True, sampler=None, batch_sampler=None,
			num_workers=1, pin_memory=True, drop_last=True
		)

		# Build and initialize a fresh latent model unless retraining an existing one.
		if not retrain:
			self.latent_model = LatentModel(self.config)
			self.latent_model.init()
		self.latent_model.to(self.device)

		criterion = VGGDistance(self.config['perceptual_loss']['layers']).to(self.device)
		# content_criterion = nn.KLDivLoss()

		# Two parameter groups: generator/modulation and the latent embeddings,
		# each with its own learning rate.
		optimizer = Adam([
			{
				'params': itertools.chain(self.latent_model.modulation.parameters(), self.latent_model.generator.parameters()),
				'lr': self.config['train']['learning_rate']['generator']
			},
			{
				'params': itertools.chain(self.latent_model.content_embedding.parameters(), self.latent_model.class_embedding.parameters()),
				'lr': self.config['train']['learning_rate']['latent']
			}
		], betas=(0.5, 0.999))

		scheduler = CosineAnnealingLR(
			optimizer,
			T_max=self.config['train']['n_epochs'] * len(data_loader),
			eta_min=self.config['train']['learning_rate']['min']
		)

		summary = SummaryWriter(log_dir=tensorboard_dir)

		train_loss = AverageMeter()
		for epoch in range(self.config['train']['n_epochs']):
			self.latent_model.train()
			train_loss.reset()

			pbar = tqdm(iterable=data_loader)
			for batch in pbar:
				batch = {name: tensor.to(self.device) for name, tensor in batch.items()}

				optimizer.zero_grad()
				out = self.latent_model(batch['img_id'], batch['class_id'])

				# L2 penalty on the content codes, regularizing the content space.
				content_penalty = torch.sum(out['content_code'] ** 2, dim=1).mean()
				# content_penalty = content_criterion(out['content_code'], torch.normal(0, self.config['content_std'], size=out['content_code'].shape).to(self.device))
				loss = criterion(out['img'], batch['img']) + self.config['content_decay'] * content_penalty

				loss.backward()
				optimizer.step()
				scheduler.step()

				train_loss.update(loss.item())
				pbar.set_description_str('epoch #{}'.format(epoch))
				pbar.set_postfix(loss=train_loss.avg)

			pbar.close()
			self.save(model_dir, latent=True, amortized=False)

			summary.add_scalar(tag='loss', scalar_value=train_loss.avg, global_step=epoch)

			fixed_sample_img = self.generate_samples(dataset, randomized=False)
			random_sample_img = self.generate_samples(dataset, randomized=True)

			summary.add_image(tag='sample-fixed', img_tensor=fixed_sample_img, global_step=epoch)
			summary.add_image(tag='sample-random', img_tensor=random_sample_img, global_step=epoch)

		summary.close()
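
Example #1 (and Example #3 below) relies on two small helpers that the snippet does not show. The following is a minimal sketch of what `NamedTensorDataset` and `AverageMeter` could look like, assuming the dataset simply indexes a dict of equally-sized tensors and the meter keeps a running mean; the actual classes in the source repository may differ.

from torch.utils.data import Dataset

class NamedTensorDataset(Dataset):
    # Indexes a dict of equally-sized tensors, yielding one dict per sample.
    # With the default collate function, DataLoader batches these back into
    # a dict of stacked tensors, as the training loop above expects.
    def __init__(self, named_tensors):
        sizes = {tensor.size(0) for tensor in named_tensors.values()}
        assert len(sizes) == 1, 'all tensors must share the first dimension'
        self.named_tensors = named_tensors

    def __len__(self):
        return next(iter(self.named_tensors.values())).size(0)

    def __getitem__(self, index):
        return {name: tensor[index] for name, tensor in self.named_tensors.items()}

class AverageMeter:
    # Tracks a running average of a scalar (here, the per-batch training loss).
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count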
Example #2
    def train(self, imgs, classes, model_dir, tensorboard_dir):
        imgs = torch.from_numpy(imgs).permute(0, 3, 1, 2)
        class_ids = torch.from_numpy(classes.astype(int))
        img_ids = torch.arange(imgs.shape[0])

        tensor_dataset = TensorDataset(imgs, img_ids, class_ids)
        data_loader = DataLoader(tensor_dataset,
                                 batch_size=self.config['train']['batch_size'],
                                 shuffle=True,
                                 sampler=None,
                                 batch_sampler=None,
                                 num_workers=1,
                                 pin_memory=True,
                                 drop_last=True)

        self.model.init()
        self.model.to(self.device)

        criterion = VGGDistance(self.config['perceptual_loss']['layers']).to(
            self.device)

        # Two learning rates: one for generator/modulation, one for the
        # per-image and per-class embeddings.
        optimizer = Adam(
            [{
                'params': self.model.generator.parameters(),
                'lr': self.config['train']['learning_rate']['generator']
            }, {
                'params': self.model.modulation.parameters(),
                'lr': self.config['train']['learning_rate']['generator']
            }, {
                'params': self.model.embeddings.parameters(),
                'lr': self.config['train']['learning_rate']['latent']
            }],
            betas=(0.5, 0.999))

        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=self.config['train']['n_epochs'] * len(data_loader),
            eta_min=self.config['train']['learning_rate']['min'])

        with SummaryWriter(
                log_dir=os.path.join(tensorboard_dir, 'stage1')) as summary:
            train_loss = AverageMeter()
            for epoch in range(1, self.config['train']['n_epochs'] + 1):
                self.model.train()
                train_loss.reset()

                with tqdm(iterable=data_loader) as pbar:
                    for batch in pbar:
                        batch_imgs, batch_img_ids, batch_class_ids = (
                            tensor.to(self.device) for tensor in batch)
                        generated_imgs, batch_content_codes, batch_class_codes = self.model(
                            batch_img_ids, batch_class_ids)

                        optimizer.zero_grad()

                        # L2 penalty keeping the content codes small.
                        content_penalty = torch.sum(batch_content_codes**2,
                                                    dim=1).mean()
                        loss = criterion(
                            generated_imgs, batch_imgs
                        ) + self.config['content_decay'] * content_penalty
                        loss.backward()

                        optimizer.step()
                        scheduler.step()

                        train_loss.update(loss.item())
                        pbar.set_description_str('epoch #{}'.format(epoch))
                        pbar.set_postfix(loss=train_loss.avg)

                # Checkpoint each sub-module at the end of every epoch.
                torch.save(self.model.generator.state_dict(),
                           os.path.join(model_dir, 'generator.pth'))
                torch.save(self.model.embeddings.state_dict(),
                           os.path.join(model_dir, 'embeddings.pth'))
                torch.save(self.model.modulation.state_dict(),
                           os.path.join(model_dir, 'class_modulation.pth'))

                self.model.eval()
                fixed_sample_img = self.evaluate(imgs,
                                                 img_ids,
                                                 class_ids,
                                                 randomized=False)
                random_sample_img = self.evaluate(imgs,
                                                  img_ids,
                                                  class_ids,
                                                  randomized=True)

                summary.add_scalar(tag='loss',
                                   scalar_value=train_loss.avg,
                                   global_step=epoch)
                summary.add_image(tag='sample-fixed',
                                  img_tensor=fixed_sample_img,
                                  global_step=epoch)
                summary.add_image(tag='sample-random',
                                  img_tensor=random_sample_img,
                                  global_step=epoch)
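
Examples #1 and #2 read their hyperparameters from a nested `config` dict. The sketch below lists only the keys the two snippets actually access; the values are placeholder assumptions, not the repository's defaults.

config = {
    'train': {
        'batch_size': 64,
        'n_epochs': 200,
        'learning_rate': {
            'generator': 1e-3,  # generator / modulation parameter group
            'latent': 1e-2,     # content / class embedding parameter group
            'min': 1e-5,        # eta_min for CosineAnnealingLR
        },
    },
    'perceptual_loss': {
        'layers': [2, 7, 12, 21, 30],  # assumed VGG feature-layer indices
    },
    'content_decay': 1e-3,  # weight of the content-code L2 penalty
}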
Example #3
	def train_amortized(self, imgs, classes, model_dir, tensorboard_dir):
		# Warm-start the amortized model from the trained latent model's
		# modulation and generator weights.
		self.amortized_model = AmortizedModel(self.config)
		self.amortized_model.modulation.load_state_dict(self.latent_model.modulation.state_dict())
		self.amortized_model.generator.load_state_dict(self.latent_model.generator.state_dict())

		data = dict(
			img=torch.from_numpy(imgs).permute(0, 3, 1, 2),
			img_id=torch.from_numpy(np.arange(imgs.shape[0])),
			class_id=torch.from_numpy(classes.astype(np.int64))
		)

		dataset = NamedTensorDataset(data)
		data_loader = DataLoader(
			dataset, batch_size=self.config['train']['batch_size'],
			shuffle=True, sampler=None, batch_sampler=None,
			num_workers=1, pin_memory=True, drop_last=True
		)

		self.latent_model.to(self.device)
		self.amortized_model.to(self.device)

		reconstruction_criterion = VGGDistance(self.config['perceptual_loss']['layers']).to(self.device)
		embedding_criterion = nn.MSELoss()

		optimizer = Adam(
			params=self.amortized_model.parameters(),
			lr=self.config['train_encoders']['learning_rate']['max'],
			betas=(0.5, 0.999)
		)

		scheduler = CosineAnnealingLR(
			optimizer,
			T_max=self.config['train_encoders']['n_epochs'] * len(data_loader),
			eta_min=self.config['train_encoders']['learning_rate']['min']
		)

		summary = SummaryWriter(log_dir=tensorboard_dir)

		train_loss = AverageMeter()
		for epoch in range(self.config['train_encoders']['n_epochs']):
			self.latent_model.eval()
			self.amortized_model.train()

			train_loss.reset()

			pbar = tqdm(iterable=data_loader)
			for batch in pbar:
				batch = {name: tensor.to(self.device) for name, tensor in batch.items()}

				optimizer.zero_grad()

				# Target codes come from the frozen stage-1 latent embeddings.
				target_content_code = self.latent_model.content_embedding(batch['img_id'])
				target_class_code = self.latent_model.class_embedding(batch['class_id'])

				out = self.amortized_model(batch['img'])

				loss_reconstruction = reconstruction_criterion(out['img'], batch['img'])
				loss_content = embedding_criterion(out['content_code'], target_content_code)
				loss_class = embedding_criterion(out['class_code'], target_class_code)

				# Reconstruction term plus strongly weighted code-matching terms.
				loss = loss_reconstruction + 10 * loss_content + 10 * loss_class

				loss.backward()
				optimizer.step()
				scheduler.step()

				train_loss.update(loss.item())
				pbar.set_description_str('epoch #{}'.format(epoch))
				pbar.set_postfix(loss=train_loss.avg)

			pbar.close()
			self.save(model_dir, latent=False, amortized=True)

			# Note: these scalars log the last batch of the epoch, not an epoch average.
			summary.add_scalar(tag='loss-amortized', scalar_value=loss.item(), global_step=epoch)
			summary.add_scalar(tag='rec-loss-amortized', scalar_value=loss_reconstruction.item(), global_step=epoch)
			summary.add_scalar(tag='content-loss-amortized', scalar_value=loss_content.item(), global_step=epoch)
			summary.add_scalar(tag='class-loss-amortized', scalar_value=loss_class.item(), global_step=epoch)

			fixed_sample_img = self.generate_samples_amortized(dataset, randomized=False)
			random_sample_img = self.generate_samples_amortized(dataset, randomized=True)

			summary.add_image(tag='sample-fixed-amortized', img_tensor=fixed_sample_img, global_step=epoch)
			summary.add_image(tag='sample-random-amortized', img_tensor=random_sample_img, global_step=epoch)

		summary.close()
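
Every example scores reconstructions with `VGGDistance`. A minimal sketch of such a perceptual loss follows, assuming it accumulates L1 distances between VGG16 feature maps at the configured layer indices; the actual implementation (e.g., input normalization) may differ.

import torch
import torch.nn as nn
from torchvision import models

class VGGDistance(nn.Module):
    def __init__(self, layer_ids):
        super().__init__()
        # Frozen VGG16 feature extractor, used only to compare activations.
        self.features = models.vgg16(pretrained=True).features.eval()
        for param in self.features.parameters():
            param.requires_grad = False
        self.layer_ids = set(layer_ids)

    def forward(self, x, y):
        # Pixel-level L1 term plus L1 distances at the selected feature layers.
        loss = torch.abs(x - y).flatten(start_dim=1).mean(dim=1)
        for i, layer in enumerate(self.features):
            x, y = layer(x), layer(y)
            if i in self.layer_ids:
                loss = loss + torch.abs(x - y).flatten(start_dim=1).mean(dim=1)
        return loss.mean()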
Example #4
    def train_encoders(self, imgs, classes, model_dir, tensorboard_dir):
        imgs = torch.from_numpy(imgs).permute(0, 3, 1, 2)
        class_ids = torch.from_numpy(classes.astype(int))
        img_ids = torch.arange(imgs.shape[0])

        tensor_dataset = TensorDataset(imgs, img_ids, class_ids)
        data_loader = DataLoader(
            tensor_dataset,
            batch_size=self.config['train_encoders']['batch_size'],
            shuffle=True,
            sampler=None,
            batch_sampler=None,
            num_workers=1,
            pin_memory=True,
            drop_last=True)

        # Rebuild the stage-1 modules, load their trained weights, and
        # initialize only the new encoders from scratch.
        self.embeddings = LordEmbeddings(self.config)
        self.modulation = LordModulation(self.config)
        self.encoders = LordEncoders(self.config)
        self.generator = LordGenerator(self.config)
        self.embeddings.load_state_dict(
            torch.load(os.path.join(model_dir, 'embeddings.pth')))
        self.modulation.load_state_dict(
            torch.load(os.path.join(model_dir, 'class_modulation.pth')))
        self.generator.load_state_dict(
            torch.load(os.path.join(model_dir, 'generator.pth')))
        self.encoders.init()

        self.model = LordStage2(self.encoders, self.modulation, self.generator)

        self.model.to(self.device)
        self.embeddings.to(self.device)

        criterion = VGGDistance(self.config['perceptual_loss']['layers']).to(
            self.device)

        optimizer = Adam([{
            'params': self.model.encoders.parameters(),
            'lr': self.config['train_encoders']['learning_rate']
        }, {
            'params': self.model.modulation.parameters(),
            'lr': self.config['train_encoders']['learning_rate']
        }, {
            'params': self.model.generator.parameters(),
            'lr': self.config['train_encoders']['learning_rate']
        }],
                         betas=(0.5, 0.999))

        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='min',
                                      factor=0.5,
                                      patience=20,
                                      verbose=True)

        with SummaryWriter(
                log_dir=os.path.join(tensorboard_dir, 'stage2')) as summary:
            train_loss = AverageMeter()
            for epoch in range(1, self.config['train_encoders']['n_epochs'] + 1):
                self.model.train()
                train_loss.reset()

                with tqdm(iterable=data_loader) as pbar:
                    for batch in pbar:
                        batch_imgs, batch_img_ids, batch_class_ids = (
                            tensor.to(self.device) for tensor in batch)
                        # Stage-1 embeddings provide frozen target codes for the encoders.
                        batch_content_codes, batch_class_codes = self.embeddings(
                            batch_img_ids, batch_class_ids)

                        generated_imgs, predicted_content_codes, predicted_class_codes = self.model(
                            batch_imgs)

                        optimizer.zero_grad()

                        perc_loss = criterion(generated_imgs, batch_imgs)
                        loss_content = F.mse_loss(batch_content_codes,
                                                  predicted_content_codes)
                        loss_class = F.mse_loss(batch_class_codes,
                                                predicted_class_codes)
                        loss = perc_loss + 10 * loss_content + 10 * loss_class
                        loss.backward()

                        optimizer.step()

                        train_loss.update(loss.item())
                        pbar.set_description_str('epoch #{}'.format(epoch))
                        pbar.set_postfix(loss=train_loss.avg)

                torch.save(self.model.encoders.state_dict(),
                           os.path.join(model_dir, 'encoders.pth'))
                torch.save(self.model.generator.state_dict(),
                           os.path.join(model_dir, 'generator.pth'))
                torch.save(self.model.modulation.state_dict(),
                           os.path.join(model_dir, 'class_modulation.pth'))

                scheduler.step(train_loss.avg)

                self.model.eval()
                fixed_sample_img = self.encoder_evaluate(imgs,
                                                         randomized=False)
                random_sample_img = self.encoder_evaluate(imgs,
                                                          randomized=True)

                summary.add_scalar(tag='loss',
                                   scalar_value=train_loss.avg,
                                   global_step=epoch)
                summary.add_image(tag='sample-fixed',
                                  img_tensor=fixed_sample_img,
                                  global_step=epoch)
                summary.add_image(tag='sample-random',
                                  img_tensor=random_sample_img,
                                  global_step=epoch)
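
For context, a hypothetical driver showing the order in which the two training stages of Examples #2 and #4 would be invoked. The `Lord` wrapper class, the `config` dict, and the file paths are illustrative assumptions, not the repository's actual entry point.

import numpy as np

# imgs: float32 array of shape (N, H, W, C); classes: int array of shape (N,).
imgs = np.load('imgs.npy')        # illustrative path
classes = np.load('classes.npy')  # illustrative path

trainer = Lord(config)  # hypothetical wrapper exposing the methods above

# Stage 1: optimize per-image content codes and per-class codes directly.
trainer.train(imgs, classes, model_dir='checkpoints', tensorboard_dir='runs')

# Stage 2: distill the learned codes into feed-forward encoders.
trainer.train_encoders(imgs, classes, model_dir='checkpoints', tensorboard_dir='runs')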