def __init__(self, options):
    """Set up stage-2 training: data pipeline, models/optimizers, and
    fixed visualization inputs.

    Expects a completed stage-1 run under options.enc_path: loads the
    encoder codes/statistics it exported and (optionally) its trained
    enc/gen/sty weights. `copy_keys`, `models`, `ImageFolder`,
    `ParallelDataset`, `one_hot`, `gaussian_noise`, `T` (torchvision
    transforms) and `optim` come from the enclosing module.
    """
    super(Stage2Trainer, self).__init__(options, subfolders = ['samples', 'reconstructions'], copy_keys = copy_keys)

    # --- data pipeline ---------------------------------------------------
    # Optional center crop, then resize; 4-channel normalization suggests
    # RGBA images where the alpha channel is left in [0, 1] unnormalized.
    transforms = []
    if options.crop_size is not None:
        transforms.append(T.CenterCrop(options.crop_size))
    transforms.append(T.Resize(options.image_size))
    transforms.append(T.ToTensor())
    transforms.append(T.Normalize((0.5, 0.5, 0.5, 0), (0.5, 0.5, 0.5, 1)))
    image_set = ImageFolder(options.data_root, transform = T.Compose(transforms))
    # Content codes precomputed by the stage-1 encoder; enc_codes[:, 0] and
    # enc_codes[:, 1] are paired per-image tensors (presumably mean and
    # std/noise components — TODO confirm against the stage-1 export code).
    enc_codes = torch.load(os.path.join(options.enc_path, 'codes', '{0}_codes.pt'.format(options.enc_iter)))
    code_set = torch.utils.data.TensorDataset(enc_codes[:, 0], enc_codes[:, 1])
    # Each sample is a 4-tuple: (image, label, code0, code1). This relies on
    # the code file being index-aligned with image_set.
    self.dataset = ParallelDataset(image_set, code_set)
    self.dataloader = torch.utils.data.DataLoader(self.dataset, batch_size = options.batch_size, shuffle = True, drop_last = True, num_workers = options.nloader)
    self.data_iter = iter(self.dataloader)

    # --- content-code statistics from stage 1 ----------------------------
    # Used elsewhere (e.g. noise_to_con_code) to map Gaussian noise into the
    # empirical content-code distribution.
    enc_stats = torch.load(os.path.join(options.enc_path, 'codes', '{0}_stats.pt'.format(options.enc_iter)))
    self.con_full_mean = enc_stats['full_mean']
    self.con_full_std = enc_stats['full_std']
    self.con_eigval = enc_stats['eigval']
    self.con_eigvec = enc_stats['eigvec']
    self.dim_weight = enc_stats['dim_weight']

    # --- models and optimizers -------------------------------------------
    # Frozen stage-1 encoder, only needed when the content-consistency loss
    # is active (note: no optimizer is created for it).
    if self.con_weight > 0:
        self.enc = models.Encoder(options.image_size, options.image_size, options.enc_features, options.enc_blocks, options.enc_adain_features, options.enc_adain_blocks, options.content_size)
        self.enc.to(self.device)
        self.enc.load_state_dict(torch.load(os.path.join(options.enc_path, 'models', '{0}_enc.pt'.format(options.enc_iter)), map_location = self.device))

    # Generator: warm-started from the stage-1 generator unless resuming a
    # stage-2 checkpoint (load_path set) or explicitly reset.
    self.gen = models.TwoPartNestedDropoutGenerator(options.image_size, options.image_size, options.gen_features, options.gen_blocks, options.gen_adain_features, options.gen_adain_blocks, options.content_size, options.style_size)
    self.gen.to(self.device)
    if (self.load_path is None) and not options.reset_gen:
        self.gen.load_state_dict(torch.load(os.path.join(options.enc_path, 'models', '{0}_gen.pt'.format(options.enc_iter)), map_location = self.device))
    self.gen_optim = optim.RMSprop(self.gen.parameters(), lr = self.lr, eps = 1e-4)
    self.add_model('gen', self.gen, self.gen_optim)

    # Classifier: optionally warm-started from a pretrained classifier run.
    self.cla = models.ClassifierOrDiscriminator(options.image_size, options.image_size, options.cla_features, options.cla_blocks, options.cla_adain_features, options.cla_adain_blocks, self.nclass)
    self.cla.to(self.device)
    if (self.load_path is None) and (options.cla_path is not None) and not options.reset_cla:
        self.cla.load_state_dict(torch.load(os.path.join(options.cla_path, 'models', '{0}_cla.pt'.format(options.cla_iter)), map_location = self.device))
    self.cla_optim = optim.RMSprop(self.cla.parameters(), lr = self.lr, eps = 1e-4)
    self.add_model('cla', self.cla, self.cla_optim)

    # Discriminator: deliberately initialized from the *classifier*
    # checkpoint ('_cla.pt'), then convert() turns the classifier head into
    # a discriminator head.
    self.dis = models.ClassifierOrDiscriminator(options.image_size, options.image_size, options.dis_features, options.dis_blocks, options.dis_adain_features, options.dis_adain_blocks, self.nclass)
    self.dis.to(self.device)
    if (self.load_path is None) and (options.cla_path is not None) and not options.reset_dis:
        self.dis.load_state_dict(torch.load(os.path.join(options.cla_path, 'models', '{0}_cla.pt'.format(options.cla_iter)), map_location = self.device))
    self.dis.convert()
    self.dis_optim = optim.RMSprop(self.dis.parameters(), lr = self.lr, eps = 1e-4)
    self.add_model('dis', self.dis, self.dis_optim)

    # Per-class style bank, warm-started from stage 1; uses Adam with the
    # separate style learning rate.
    self.sty = models.NormalizedStyleBank(self.nclass, options.style_size, image_set.get_class_freq())
    self.sty.to(self.device)
    if (self.load_path is None) and not options.reset_sty:
        self.sty.load_state_dict(torch.load(os.path.join(options.enc_path, 'models', '{0}_sty.pt'.format(options.enc_iter)), map_location = self.device))
    self.sty_optim = optim.Adam(self.sty.parameters(), lr = self.sty_lr, eps = 1e-8)
    self.add_model('sty', self.sty, self.sty_optim)

    # --- fixed visualization inputs --------------------------------------
    if self.load_path is not None:
        # Resuming: restore the exact fixtures so visualizations stay
        # comparable across runs. rec_images stays local; it is only needed
        # for the re-save/target-image block below.
        rec_images = torch.load(os.path.join(self.load_path, 'reconstructions', 'images.pt'), map_location = self.device)
        self.rec_codes = torch.load(os.path.join(self.load_path, 'reconstructions', 'codes.pt'), map_location = self.device)
        self.rec_labels = torch.load(os.path.join(self.load_path, 'reconstructions', 'labels.pt'), map_location = self.device)
        self.vis_codes = torch.load(os.path.join(self.load_path, 'samples', 'codes.pt'), map_location = self.device)
        self.vis_style_noise = torch.load(os.path.join(self.load_path, 'samples', 'style_noise.pt'), map_location = self.device)
        self.load(options.load_iter)
    else:
        # Fresh run: sample a fixed vis_row x vis_col grid of dataset items
        # for reconstruction targets, plus random codes/noise for samples.
        rec_images = []
        rec_codes = []
        rec_labels = []
        rec_index = random.sample(range(len(self.dataset)), options.vis_row * options.vis_col)
        for k in rec_index:
            image, label, code, _ = self.dataset[k]
            rec_images.append(image)
            rec_codes.append(code)
            rec_labels.append(label)
        rec_images = torch.stack(rec_images, dim = 0)
        self.rec_codes = torch.stack(rec_codes, dim = 0).to(self.device)
        # NOTE(review): int32 labels fed to the project's one_hot helper —
        # torch.nn.functional.one_hot would require int64; presumably the
        # local helper accepts int32. Confirm against its definition.
        self.rec_labels = one_hot(torch.tensor(rec_labels, dtype = torch.int32), self.nclass).to(self.device)
        self.vis_codes = self.noise_to_con_code(gaussian_noise(options.vis_row * options.vis_col, options.content_size)).to(self.device)
        self.vis_style_noise = gaussian_noise(options.vis_row * options.vis_col, options.style_size).to(self.device)
        self.state.dis_total_batches = 0

    # Persist fixtures into the new save folder (skipped when resuming in
    # place, i.e. save_path == load_path).
    if self.save_path != self.load_path:
        torch.save(rec_images, os.path.join(self.save_path, 'reconstructions', 'images.pt'))
        torch.save(self.rec_codes, os.path.join(self.save_path, 'reconstructions', 'codes.pt'))
        torch.save(self.rec_labels, os.path.join(self.save_path, 'reconstructions', 'labels.pt'))
        torch.save(self.vis_codes, os.path.join(self.save_path, 'samples', 'codes.pt'))
        torch.save(self.vis_style_noise, os.path.join(self.save_path, 'samples', 'style_noise.pt'))
        # Undo the (-1, 1) normalization for viewing.
        save_image(rec_images.add(1).div(2), os.path.join(self.save_path, 'reconstructions', 'target.png'), self.vis_col)

    self.add_periodic_func(self.visualize_fixed, options.visualize_iter)
    self.visualize_fixed()

    # EMA factor for reported losses.
    self.loss_avg_factor = 0.9
    # Nested-dropout keep probabilities: dimension i of the style code is
    # kept with probability style_dropout ** i.
    self.sty_drop_prob = torch.Tensor(options.style_size)
    for i in range(options.style_size):
        self.sty_drop_prob[i] = options.style_dropout ** i
def __init__(self, options):
    """Set up stage-1 training: data pipeline (with optional per-pixel
    weight maps), enc/gen/cla/sty models with optimizers, and fixed
    visualization inputs.

    `copy_keys`, `models`, `ImageFolder`, `ParallelDataset`, `one_hot`,
    `T` (torchvision transforms) and `optim` come from the enclosing
    module.
    """
    super(Stage1Trainer, self).__init__(options, subfolders=['reconstructions'], copy_keys=copy_keys)

    # --- data pipeline ---------------------------------------------------
    # Shared geometric transforms (crop/resize/ToTensor) are built once and
    # reused for both the image set and the optional weight set, so the two
    # stay spatially aligned.
    transforms = []
    if options.crop_size is not None:
        transforms.append(T.CenterCrop(options.crop_size))
    transforms.append(T.Resize(options.image_size))
    transforms.append(T.ToTensor())
    # 4-channel normalization suggests RGBA images; the alpha channel is
    # left in [0, 1] unnormalized.
    image_transforms = transforms + [ T.Normalize((0.5, 0.5, 0.5, 0), (0.5, 0.5, 0.5, 1)) ]
    image_set = ImageFolder(options.data_root, transform=T.Compose(image_transforms))
    if options.weight_root is not None:
        # Weight maps are stored as images mirroring data_root's layout;
        # `lambda x: x[0]` keeps only the first channel as a 2-D weight map.
        self.has_weight = True
        weight_transforms = transforms + [lambda x: x[0]]
        weight_set = ImageFolder(options.weight_root, transform=T.Compose(weight_transforms))
        # Each sample becomes (image, label, weight, weight_label); relies
        # on weight_root being index-aligned with data_root.
        self.dataset = ParallelDataset(image_set, weight_set)
    else:
        self.has_weight = False
        self.dataset = image_set
    self.dataloader = torch.utils.data.DataLoader(
        self.dataset, batch_size=options.batch_size, shuffle=True,
        drop_last=True, num_workers=options.nloader)
    self.data_iter = iter(self.dataloader)

    # --- models and optimizers -------------------------------------------
    # Image -> content-code encoder.
    self.enc = models.Encoder(options.image_size, options.image_size, options.enc_features, options.enc_blocks, options.enc_adain_features, options.enc_adain_blocks, options.content_size)
    self.enc.to(self.device)
    self.enc_optim = optim.Adam(self.enc.parameters(), lr=self.lr, eps=1e-4)
    self.add_model('enc', self.enc, self.enc_optim)

    # (content, style) -> image generator.
    self.gen = models.TwoPartNestedDropoutGenerator(
        options.image_size, options.image_size, options.gen_features, options.gen_blocks,
        options.gen_adain_features, options.gen_adain_blocks, options.content_size, options.style_size)
    self.gen.to(self.device)
    self.gen_optim = optim.Adam(self.gen.parameters(), lr=self.lr, eps=1e-4)
    self.add_model('gen', self.gen, self.gen_optim)

    # Classifier: either a small MLP on content codes (self.mlp) or a
    # convolutional classifier on images.
    if self.mlp:
        self.cla = models.MLPClassifier(options.content_size, options.mlp_features, options.mlp_layers, self.nclass)
    else:
        self.cla = models.ClassifierOrDiscriminator(
            options.image_size, options.image_size, options.cla_features, options.cla_blocks,
            options.cla_adain_features, options.cla_adain_blocks, self.nclass)
    self.cla.to(self.device)
    self.cla_optim = optim.Adam(self.cla.parameters(), lr=self.lr, eps=1e-4)
    self.add_model('cla', self.cla, self.cla_optim)

    # Per-class style bank with its own learning rate.
    self.sty = models.NormalizedStyleBank(self.nclass, options.style_size, image_set.get_class_freq())
    self.sty.to(self.device)
    self.sty_optim = optim.Adam(self.sty.parameters(), lr=self.sty_lr, eps=1e-8)
    self.add_model('sty', self.sty, self.sty_optim)

    # --- fixed visualization inputs --------------------------------------
    if self.load_path is not None:
        # Resuming: restore the exact fixtures so visualizations stay
        # comparable across runs.
        self.vis_images = torch.load(os.path.join(self.load_path, 'reconstructions', 'images.pt'), map_location=self.device)
        self.vis_labels = torch.load(os.path.join(self.load_path, 'reconstructions', 'labels.pt'), map_location=self.device)
        if self.has_weight:
            self.vis_weights = torch.load(os.path.join(
                self.load_path, 'reconstructions', 'weights.pt'), map_location=self.device)
        self.load(options.load_iter)
    else:
        # Fresh run: sample a fixed vis_row x vis_col grid of images (and
        # matching weight maps) as reconstruction targets.
        vis_images = []
        vis_labels = []
        if self.has_weight:
            vis_weights = []
        vis_index = random.sample(range(len(image_set)), options.vis_row * options.vis_col)
        for k in vis_index:
            image, label = image_set[k]
            vis_images.append(image)
            vis_labels.append(label)
            if self.has_weight:
                # Indexing weight_set with the same k keeps weights aligned
                # with their images.
                weight, _ = weight_set[k]
                vis_weights.append(weight)
        self.vis_images = torch.stack(vis_images, dim=0).to(self.device)
        # NOTE(review): int32 labels fed to the project's one_hot helper —
        # torch.nn.functional.one_hot would require int64; presumably the
        # local helper accepts int32. Confirm against its definition.
        self.vis_labels = one_hot(
            torch.tensor(vis_labels, dtype=torch.int32), self.nclass).to(self.device)
        if self.has_weight:
            self.vis_weights = torch.stack(vis_weights, dim=0).to(self.device)

    # Persist fixtures into the new save folder (skipped when resuming in
    # place, i.e. save_path == load_path).
    if self.save_path != self.load_path:
        torch.save(
            self.vis_images, os.path.join(self.save_path, 'reconstructions', 'images.pt'))
        torch.save(
            self.vis_labels, os.path.join(self.save_path, 'reconstructions', 'labels.pt'))
        # Undo the (-1, 1) normalization for viewing.
        save_image(
            self.vis_images.add(1).div(2),
            os.path.join(self.save_path, 'reconstructions', 'target.png'), self.vis_col)
        if self.has_weight:
            torch.save(
                self.vis_weights, os.path.join(self.save_path, 'reconstructions', 'weights.pt'))
            save_image(
                self.vis_weights.unsqueeze(1),
                os.path.join(self.save_path, 'reconstructions', 'weight.png'), self.vis_col)

    self.add_periodic_func(self.visualize_fixed, options.visualize_iter)
    self.visualize_fixed()

    # Nested-dropout keep probabilities: dimension i of the content (resp.
    # style) code is kept with probability content_dropout ** i (resp.
    # style_dropout ** i).
    self.con_drop_prob = torch.Tensor(options.content_size)
    for i in range(options.content_size):
        self.con_drop_prob[i] = options.content_dropout**i
    self.sty_drop_prob = torch.Tensor(options.style_size)
    for i in range(options.style_size):
        self.sty_drop_prob[i] = options.style_dropout**i