def make_pairs(img_original, arg):
    bn, c, h, w = img_original.shape

    # Make image and grid
    tps_param_dic = tps_parameters(bn, arg.scal, 0., 0., 0., 0., arg.augm_scal)
    coord, vector = make_input_tps_param(tps_param_dic)
    coord, vector = coord.to(arg.device), vector.to(arg.device)
    img, mesh = ThinPlateSpline(img_original, coord, vector, arg.reconstr_dim, device=arg.device)

    # Make transformed image and grid
    tps_param_dic_rot = tps_parameters(bn, arg.scal, arg.tps_scal, arg.rot_scal, arg.off_scal,
                                       arg.scal_var, arg.augm_scal)
    coord_rot, vector_rot = make_input_tps_param(tps_param_dic_rot)
    coord_rot, vector_rot = coord_rot.to(arg.device), vector_rot.to(arg.device)
    img_rot, mesh_rot = ThinPlateSpline(img_original, coord_rot, vector_rot, arg.reconstr_dim,
                                        device=arg.device)

    # Make augmentation
    img_stack = torch.cat([img, img_rot], dim=0)
    img_stack_augm = augm(img_stack, arg, arg.device)
    img_augm, img_rot_augm = img_stack_augm[:bn], img_stack_augm[bn:]

    # Make input stack
    input_images = F.interpolate(torch.cat([img_augm, img_rot], dim=0),
                                 size=arg.reconstr_dim).clamp(min=0., max=1.)
    reconstr_images = F.interpolate(torch.cat([img, img_rot_augm], dim=0),
                                    size=arg.reconstr_dim).clamp(min=0., max=1.)
    mesh_stack = torch.cat([mesh, mesh_rot], dim=0)

    return input_images, reconstr_images, mesh_stack
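# Usage sketch (illustrative, not part of the source): driving make_pairs with a
# dummy batch. The hyperparameter values below are hypothetical placeholders, and
# the namespace assumes augm() only needs the colour-jitter fields listed here;
# extend it if augm() reads further attributes.
if __name__ == '__main__':
    import types
    arg_demo = types.SimpleNamespace(
        scal=0.95, tps_scal=0.05, rot_scal=0.1, off_scal=0.05, scal_var=0.05,
        augm_scal=1.0, reconstr_dim=128,
        brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3,
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    img = torch.rand(4, 3, 128, 128, device=arg_demo.device)  # dummy batch
    inputs, targets, meshes = make_pairs(img, arg_demo)
    # Each stack holds 2 * bn images: (augmented, rotated) inputs are paired
    # with (clean, rotated-augmented) reconstruction targets.
    assert inputs.shape[0] == 2 * img.shape[0]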
def forward(self, x):
    # tps
    image_orig = x.repeat(2, 1, 1, 1)
    tps_param_dic = tps_parameters(image_orig.shape[0], self.scal, self.tps_scal, self.rot_scal,
                                   self.off_scal, self.scal_var, self.augm_scal)
    coord, vector = make_input_tps_param(tps_param_dic)
    coord, vector = coord.to(self.device), vector.to(self.device)
    t_images, t_mesh = ThinPlateSpline(image_orig, coord, vector, 128, device=self.device)
    image_in, image_rec = prepare_pairs(t_images, self.arg)
    transform_mesh = F.interpolate(t_mesh, size=64)
    volume_mesh = AbsDetJacobian(transform_mesh, self.device)

    # encoding
    _, part_maps, sum_part_maps = self.E_sigma(image_in)
    mu, L_inv = get_mu_and_prec(part_maps, self.device, self.L_inv_scal)
    heat_map = get_heat_map(mu, L_inv, self.device)
    raw_features = self.E_alpha(sum_part_maps)
    features = get_local_part_appearances(raw_features, part_maps)

    # transform
    integrant = (part_maps.unsqueeze(-1) * volume_mesh.unsqueeze(-1)).squeeze()
    integrant = integrant / torch.sum(integrant, dim=[2, 3], keepdim=True)
    mu_t = contract('akij, alij -> akl', integrant, transform_mesh)
    transform_mesh_out_prod = contract('amij, anij -> amnij', transform_mesh, transform_mesh)
    mu_out_prod = contract('akm, akn -> akmn', mu_t, mu_t)
    stddev_t = contract('akij, amnij -> akmn', integrant, transform_mesh_out_prod) - mu_out_prod

    # processing
    encoding = feat_mu_to_enc(features, mu, L_inv, self.device, self.covariance)
    reconstruct_same_id = self.decoder(encoding)

    loss = nn.MSELoss()(image_rec, reconstruct_same_id)

    if self.mode == 'predict':
        return image_in, image_rec, mu, heat_map
    elif self.mode == 'train':
        return reconstruct_same_id, loss
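# Standalone sketch of the moment computation in the "transform" block above,
# with torch.einsum standing in for opt_einsum's contract. It shows that mu_t is
# the part-weighted mean of the warped grid coordinates and stddev_t the matching
# per-part 2x2 second moment, E[x x^T] - E[x] E[x]^T. Shapes are hypothetical.
import torch

def part_moments(integrant, transform_mesh):
    # integrant:      (b, k, h, w) part weights, normalized over (h, w)
    # transform_mesh: (b, 2, h, w) x/y coordinates of the warped grid
    mu_t = torch.einsum('akij,alij->akl', integrant, transform_mesh)
    mesh_outer = torch.einsum('amij,anij->amnij', transform_mesh, transform_mesh)
    mu_outer = torch.einsum('akm,akn->akmn', mu_t, mu_t)
    stddev_t = torch.einsum('akij,amnij->akmn', integrant, mesh_outer) - mu_outer
    return mu_t, stddev_t  # (b, k, 2), (b, k, 2, 2)

w = torch.rand(2, 16, 64, 64)
w = w / w.sum(dim=[2, 3], keepdim=True)  # mimic the normalization above
grid = torch.rand(2, 2, 64, 64)
mu_t_demo, stddev_t_demo = part_moments(w, grid)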
def main(arg):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(arg.gpu)
    model_save_dir = "./experiments/" + arg.name + "/"

    with tf.variable_scope("Data_prep"):
        if arg.mode == 'train':
            raw_dataset = dataset_map_train[arg.dataset](arg)
        elif arg.mode == 'predict':
            raw_dataset = dataset_map_test[arg.dataset](arg)
        dataset = raw_dataset.map(load_and_preprocess_image,
                                  num_parallel_calls=arg.data_parallel_calls)
        dataset = dataset.batch(arg.bn, drop_remainder=True).repeat(arg.epochs)
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        b_images = next_element
        orig_images = tf.tile(b_images, [2, 1, 1, 1])

    scal = tf.placeholder(dtype=tf.float32, shape=(), name='scal_placeholder')
    tps_scal = tf.placeholder(dtype=tf.float32, shape=(), name='tps_placeholder')
    rot_scal = tf.placeholder(dtype=tf.float32, shape=(), name='rot_scal_placeholder')
    off_scal = tf.placeholder(dtype=tf.float32, shape=(), name='off_scal_placeholder')
    scal_var = tf.placeholder(dtype=tf.float32, shape=(), name='scal_var_placeholder')
    augm_scal = tf.placeholder(dtype=tf.float32, shape=(), name='augm_scal_placeholder')

    tps_param_dic = tps_parameters(2 * arg.bn, scal, tps_scal, rot_scal, off_scal, scal_var)
    tps_param_dic.augm_scal = augm_scal

    ctr = 0
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.95

    with tf.Session(config=config) as sess:
        model = Model(orig_images, arg, tps_param_dic)
        tvar = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        saver = tf.train.Saver(var_list=tvar)
        merged = tf.summary.merge_all()

        if arg.mode == 'train':
            if arg.load:
                ckpt, ctr = find_ckpt(model_save_dir + 'saved_model/')
                saver.restore(sess, ckpt)
            else:
                save_python_files(save_dir=model_save_dir + 'bin/')
                write_hyperparameters(arg.toDict(), model_save_dir)
                sess.run(tf.global_variables_initializer())
            writer = tf.summary.FileWriter("./summaries/" + arg.name, graph=sess.graph)
        elif arg.mode == 'predict':
            ckpt, ctr = find_ckpt(model_save_dir + 'saved_model/')
            saver.restore(sess, ckpt)
            initialize_uninitialized(sess)

        while True:
            print('iteration %d' % ctr)
            try:
                feed = transformation_parameters(arg, ctr, no_transform=(arg.mode == 'predict'))  # no transform if arg.visualize
                trf = {scal: feed.scal, tps_scal: feed.tps_scal, scal_var: feed.scal_var,
                       rot_scal: feed.rot_scal, off_scal: feed.off_scal, augm_scal: feed.augm_scal}
                ctr += 1

                if arg.mode == 'train':
                    if np.mod(ctr, arg.summary_interval) == 0:
                        merged_summary = sess.run(merged, feed_dict=trf)
                        writer.add_summary(merged_summary, global_step=ctr)
                    _, loss = sess.run([model.optimize, model.loss], feed_dict=trf)
                    if np.mod(ctr, arg.save_interval) == 0:
                        saver.save(sess, model_save_dir + '/saved_model/' + 'save_net.ckpt',
                                   global_step=ctr)
                elif arg.mode == 'predict':
                    img, img_rec, mu, heat_raw = sess.run(
                        [model.image_in, model.reconstruct_same_id, model.mu,
                         batch_colour_map(model.part_maps)],
                        feed_dict=trf)
                    save(img, mu, ctr)

            except tf.errors.OutOfRangeError:
                print("End of training.")
                break
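# Sketch (illustrative) of the scalar-placeholder pattern used above: feeding the
# transformation strengths through placeholders lets transformation_parameters()
# change them per iteration without rebuilding the TF1 graph. The ramp below is a
# hypothetical schedule, not the one the script uses.
import tensorflow as tf  # TF 1.x, as above

tf.reset_default_graph()
strength = tf.placeholder(dtype=tf.float32, shape=(), name='strength')
warped = tf.random_normal([4]) * strength  # stands in for the TPS warp

with tf.Session() as sess:
    for ctr in range(3):
        ramp = min(1.0, ctr / 2.0)  # hypothetical annealing schedule
        print(sess.run(warped, feed_dict={strength: ramp}))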
def forward(self, x):
    batch_size = x.shape[0]
    batch_size2 = 2 * x.shape[0]

    # tps
    image_orig = x.repeat(2, 1, 1, 1)
    tps_param_dic = tps_parameters(batch_size2, self.scal, self.tps_scal, self.rot_scal,
                                   self.off_scal, self.scal_var, self.augm_scal)
    coord, vector = make_input_tps_param(tps_param_dic)
    coord, vector = coord.to(self.device), vector.to(self.device)
    t_images, t_mesh = ThinPlateSpline(image_orig, coord, vector, self.reconstr_dim,
                                       device=self.device)
    image_in, image_rec = prepare_pairs(t_images, self.arg, self.device)
    transform_mesh = F.interpolate(t_mesh, size=64)
    volume_mesh = AbsDetJacobian(transform_mesh, self.device)

    # encoding
    part_maps_raw, part_maps_norm, sum_part_maps = self.E_sigma(image_in)
    mu, L_inv = get_mu_and_prec(part_maps_norm, self.device, self.L_inv_scal)
    raw_features = self.E_alpha(sum_part_maps)
    features = get_local_part_appearances(raw_features, part_maps_norm)
    heat_map = get_heat_map(mu, L_inv, self.device, self.background)
    norm = torch.sum(heat_map, 1, keepdim=True) + 1
    heat_map_norm = heat_map / norm

    # transform
    integrant = (part_maps_norm.unsqueeze(-1) * volume_mesh.unsqueeze(-1)).squeeze()
    integrant = integrant / torch.sum(integrant, dim=[2, 3], keepdim=True)
    mu_t = contract('akij, alij -> akl', integrant, transform_mesh)
    transform_mesh_out_prod = contract('amij, anij -> amnij', transform_mesh, transform_mesh)
    mu_out_prod = contract('akm, akn -> akmn', mu_t, mu_t)
    stddev_t = contract('akij, amnij -> akmn', integrant, transform_mesh_out_prod) - mu_out_prod

    # processing
    encoding = feat_mu_to_enc(features, mu, L_inv, self.device, self.reconstr_dim, self.background)
    reconstruct_same_id = self.decoder(encoding)

    total_loss, rec_loss, transform_loss, precision_loss = loss_fn(
        batch_size, mu, L_inv, mu_t, stddev_t, reconstruct_same_id, image_rec,
        self.l_2_scal, self.l_2_threshold, self.L_mu, self.L_cov, self.L_rec,
        self.device, self.background, True)

    # norms
    original_part_maps_raw, original_part_maps_norm, original_sum_part_maps = self.E_sigma(x)
    mu_original, L_inv_original = get_mu_and_prec(original_part_maps_norm, self.device,
                                                  self.L_inv_scal)

    if self.mode == 'predict':
        return (image_rec, reconstruct_same_id, mu, L_inv, part_maps_norm,
                heat_map, heat_map_norm, total_loss)
    elif self.mode == 'train':
        return (image_rec, reconstruct_same_id, total_loss, rec_loss, transform_loss,
                precision_loss, mu[:, :-1], L_inv[:, :-1], mu_original[:, :-1])
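# Usage sketch (illustrative): driving this forward pass in both modes, assuming
# the surrounding Model takes the same `arg` namespace used by the training
# scripts; `arg` itself is configured elsewhere in the repo. Note that in 'train'
# mode the background part is dropped from mu and L_inv via the [:, :-1] slices.
model = Model(arg).to(arg.device)
x = torch.rand(8, 3, arg.reconstr_dim, arg.reconstr_dim, device=arg.device)

model.mode = 'train'
(image_rec, reconstruct_same_id, total_loss, rec_loss, transform_loss,
 precision_loss, mu, L_inv, mu_original) = model(x)
total_loss.backward()

model.mode = 'predict'
with torch.no_grad():
    outputs = model(x)  # part maps, heat maps, and the total loss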
def main(arg):
    # Set random seeds
    torch.manual_seed(7)
    torch.cuda.manual_seed(7)
    np.random.seed(7)

    # Get args
    bn = arg.bn
    mode = arg.mode
    name = arg.name
    load_from_ckpt = arg.load_from_ckpt
    lr = arg.lr
    epochs = arg.epochs
    device = torch.device('cuda:' + str(arg.gpu) if torch.cuda.is_available() else 'cpu')
    arg.device = device

    if mode == 'train':
        # Make new directory
        model_save_dir = '../results/' + name
        if not os.path.exists(model_save_dir):
            os.makedirs(model_save_dir)
            os.makedirs(model_save_dir + '/summary')

        # Save Hyperparameters
        write_hyperparameters(arg.toDict(), model_save_dir)

        # Define Model & Optimizer
        model = Model(arg).to(device)
        if load_from_ckpt:
            model = load_model(model, model_save_dir).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        # Log with wandb
        wandb.init(project='Disentanglement', config=arg, name=arg.name)
        wandb.watch(model, log='all')

        # Load Datasets and DataLoader
        train_data, test_data = load_deep_fashion_dataset()
        train_dataset = ImageDataset(np.array(train_data))
        test_dataset = ImageDataset(np.array(test_data))
        train_loader = DataLoader(train_dataset, batch_size=bn, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_dataset, batch_size=bn, num_workers=4)

        # Make Training
        with torch.autograd.set_detect_anomaly(False):
            for epoch in range(epochs + 1):
                # Train on Train Set
                model.train()
                model.mode = 'train'
                for step, original in enumerate(train_loader):
                    original = original.to(device)

                    # Make transformations
                    tps_param_dic = tps_parameters(original.shape[0], arg.scal, arg.tps_scal,
                                                   arg.rot_scal, arg.off_scal, arg.scal_var,
                                                   arg.augm_scal)
                    coord, vector = make_input_tps_param(tps_param_dic)
                    coord, vector = coord.to(device), vector.to(device)
                    image_spatial_t, _ = ThinPlateSpline(original, coord, vector,
                                                         original.shape[3], device)
                    image_appearance_t = K.ColorJitter(arg.brightness, arg.contrast,
                                                       arg.saturation, arg.hue)(original)
                    image_spatial_t, image_appearance_t = normalize(image_spatial_t), normalize(image_appearance_t)

                    reconstruction, loss, rec_loss, equiv_loss, mu, L_inv = model(
                        original, image_spatial_t, image_appearance_t, coord, vector)

                    mu_norm = torch.mean(torch.norm(mu, p=1, dim=2)).cpu().detach().numpy()
                    L_inv_norm = torch.mean(torch.linalg.norm(L_inv, ord='fro',
                                                              dim=[2, 3])).cpu().detach().numpy()
                    wandb.log({"Part Means": mu_norm})
                    wandb.log({"Precision Matrix": L_inv_norm})

                    # Zero out gradients, backpropagate, and update
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # Track Loss (detach so the log does not retain the graph)
                    if step == 0:
                        loss_log = loss.detach().cpu().unsqueeze(0)
                        rec_loss_log = rec_loss.detach().cpu().unsqueeze(0)
                    else:
                        loss_log = torch.cat([loss_log, loss.detach().cpu().unsqueeze(0)])
                        rec_loss_log = torch.cat([rec_loss_log, rec_loss.detach().cpu().unsqueeze(0)])

                training_loss = torch.mean(loss_log)
                training_rec_loss = torch.mean(rec_loss_log)
                wandb.log({"Training Loss": training_loss})
                wandb.log({"Training Rec Loss": training_rec_loss})
                print(f'Epoch: {epoch}, Train Loss: {training_loss}')

                # Evaluate on Test Set
                model.eval()
                for step, original in enumerate(test_loader):
                    with torch.no_grad():
                        original = original.to(device)
                        tps_param_dic = tps_parameters(original.shape[0], arg.scal, arg.tps_scal,
                                                       arg.rot_scal, arg.off_scal, arg.scal_var,
                                                       arg.augm_scal)
                        coord, vector = make_input_tps_param(tps_param_dic)
                        coord, vector = coord.to(device), vector.to(device)
                        image_spatial_t, _ = ThinPlateSpline(original, coord, vector,
                                                             original.shape[3], device)
                        image_appearance_t = K.ColorJitter(arg.brightness, arg.contrast,
                                                           arg.saturation, arg.hue)(original)
                        image_spatial_t, image_appearance_t = normalize(image_spatial_t), normalize(image_appearance_t)

                        reconstruction, loss, rec_loss, equiv_loss, mu, L_inv = model(
                            original, image_spatial_t, image_appearance_t, coord, vector)

                        if step == 0:
                            loss_log = loss.detach().cpu().unsqueeze(0)
                        else:
                            loss_log = torch.cat([loss_log, loss.detach().cpu().unsqueeze(0)])

                evaluation_loss = torch.mean(loss_log)
                wandb.log({"Evaluation Loss": evaluation_loss})
                print(f'Epoch: {epoch}, Test Loss: {evaluation_loss}')

                # Track Progress
                model.mode = 'predict'
                original, fmap_shape, fmap_app, reconstruction = model(
                    original, image_spatial_t, image_appearance_t, coord, vector)
                make_visualization(original, reconstruction, image_spatial_t, image_appearance_t,
                                   fmap_shape, fmap_app, model_save_dir, epoch, device)
                save_model(model, model_save_dir)

    elif mode == 'predict':
        # Make Directory for Predictions
        model_save_dir = '../results/' + name
        if not os.path.exists(model_save_dir + '/predictions'):
            os.makedirs(model_save_dir + '/predictions')

        # Load Model and Dataset
        model = Model(arg).to(device)
        model = load_model(model, model_save_dir).to(device)
        data = load_deep_fashion_dataset()
        test_data = np.array(data[-4:])
        test_dataset = ImageDataset(test_data)
        test_loader = DataLoader(test_dataset, batch_size=bn)
        model.mode = 'predict'
        model.eval()

        # Predict on Dataset
        for step, original in enumerate(test_loader):
            with torch.no_grad():
                original = original.to(device)
                tps_param_dic = tps_parameters(original.shape[0], arg.scal, arg.tps_scal,
                                               arg.rot_scal, arg.off_scal, arg.scal_var,
                                               arg.augm_scal)
                coord, vector = make_input_tps_param(tps_param_dic)
                coord, vector = coord.to(device), vector.to(device)
                image_spatial_t, _ = ThinPlateSpline(original, coord, vector,
                                                     original.shape[3], device)
                image_appearance_t = K.ColorJitter(arg.brightness, arg.contrast,
                                                   arg.saturation, arg.hue)(original)
                image, reconstruction, mu, shape_stream_parts, heat_map = model(
                    original, image_spatial_t, image_appearance_t, coord, vector)
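# Sketch (illustrative): the predict branch above stops at the forward pass; a
# natural follow-up is mapping the part means `mu` back to pixel positions for
# plotting. This assumes mu lies in [-1, 1] grid coordinates, as with
# ThinPlateSpline grids; adjust if mu is already expressed in pixels.
def mu_to_pixels(mu, h, w):
    # mu: (bn, k, 2) in [-1, 1] -> (bn, k, 2) pixel coordinates (x, y)
    xy = (mu + 1.0) / 2.0
    return torch.stack([xy[..., 0] * (w - 1), xy[..., 1] * (h - 1)], dim=-1)

# e.g. keypoints = mu_to_pixels(mu, original.shape[2], original.shape[3])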