def main(*args):
    dim = 5000
    intrinsic_dim = 50
    num_points = 10000
    num_queries = 5
    num_neighbours = 10    # The k in k-NN

    # Guide for tuning hyperparameters:
    # num_comp_indices trades off accuracy vs. construction and query time - high values lead to more accurate results, but slower construction and querying
    # num_simp_indices trades off accuracy vs. construction and query time - high values lead to more accurate results, but slower construction and querying; if num_simp_indices is increased, may need to increase num_comp_indices
    # num_levels trades off construction time vs. query time - higher values lead to faster querying, but slower construction; if num_levels is increased, may need to increase query_field_of_view and construction_field_of_view
    # construction_field_of_view trades off accuracy/query time vs. construction time - higher values lead to *slightly* more accurate results and/or *slightly* faster querying, but *slightly* slower construction
    # construction_prop_to_retrieve trades off accuracy vs. construction time - higher values lead to *slightly* more accurate results, but slower construction
    # query_field_of_view trades off accuracy vs. query time - higher values lead to more accurate results, but *slightly* slower querying
    # query_prop_to_retrieve trades off accuracy vs. query time - higher values lead to more accurate results, but slower querying
    num_comp_indices = 2
    num_simp_indices = 7
    num_levels = 2
    construction_field_of_view = 10
    construction_prop_to_retrieve = 0.002
    query_field_of_view = 100
    query_prop_to_retrieve = 0.05

    print("Generating Data... ")
    t0 = time()
    data_and_queries = gen_data(dim, intrinsic_dim, num_points + num_queries)
    data = np.copy(data_and_queries[:num_points, :])
    queries = data_and_queries[num_points:, :]
    print("Took %.4fs" % (time() - t0))

    print("Constructing Data Structure... ")
    t0 = time()
    dci_db = DCI(dim, num_comp_indices, num_simp_indices)
    dci_db.add(data, num_levels=num_levels, field_of_view=construction_field_of_view,
               prop_to_retrieve=construction_prop_to_retrieve)
    print("Took %.4fs" % (time() - t0))

    print("Querying... ")
    t0 = time()
    nearest_neighbour_idx, nearest_neighbour_dists = dci_db.query(
        queries, num_neighbours=num_neighbours, field_of_view=query_field_of_view,
        prop_to_retrieve=query_prop_to_retrieve, blind=True)
    print("Took %.4fs" % (time() - t0))

    print(nearest_neighbour_idx)
    print(nearest_neighbour_dists)
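
#-------------------------------------------------------------------------------
# For reference, the DCI API exercised by the demo above reduces to three calls:
# construct, add, query. A minimal sketch, assuming the same `dci` module is
# importable; the random matrices below are stand-ins for the gen_data output.
#-------------------------------------------------------------------------------
import numpy as np
from dci import DCI

_data = np.random.randn(1000, 64).astype(np.float64)     # 1000 points in 64 dimensions
_queries = np.random.randn(3, 64).astype(np.float64)     # 3 query points

_db = DCI(64, num_comp_indices=2, num_simp_indices=7)
_db.add(_data, num_levels=2, field_of_view=10, prop_to_retrieve=0.002)
_idx, _dists = _db.query(_queries, num_neighbours=5,
                         field_of_view=100, prop_to_retrieve=0.05)
print(_idx)    # per-query lists of nearest-neighbour indices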
class IMLE():
    def __init__(self, z_dim, Sx_dim):
        self.z_dim = z_dim
        self.Sx_dim = Sx_dim
        self.model = ConvolutionalImplicitModel(z_dim + Sx_dim, 0.5).cuda()
        self.model.apply(self.model.get_initializer())
        self.dci_db = None

        #-----------------------------------------------------------------------
        # load pre-trained model
        state_dict = torch.load("../net_weights_2D_scattering_J=6_L=2_times=10.pth")
        self.model.load_state_dict(state_dict)

    #---------------------------------------------------------------------------
    #def train(self, data_np, data_Sx, name_JL, base_lr=1e-4, batch_size=128, num_epochs=3000,
    #          decay_step=25, decay_rate=0.95, staleness=100, num_samples_factor=100):
    def train(self, data_np, data_Sx, name_JL, base_lr=1e-5, batch_size=128, num_epochs=3000,
              decay_step=25, decay_rate=0.98, staleness=100, num_samples_factor=100):

        # define metric
        loss_fn = nn.MSELoss().cuda()
        self.model.train()

        # train in batches
        num_batches = data_np.shape[0] // batch_size

        # truncate data to fit the batch size
        num_data = num_batches * batch_size
        data_np = data_np[:num_data]
        data_Sx = data_Sx[:num_data]

        # flatten images to 1D vectors for DCI
        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        #-----------------------------------------------------------------------
        # make empty arrays to store results
        samples_predict = np.empty(data_np.shape)
        samples_np = np.empty((num_samples_factor,) + data_np.shape[1:])

        #-----------------------------------------------------------------------
        # make global torch variables
        data_all = torch.from_numpy(data_np).float().cuda()
        Sx = torch.from_numpy(np.repeat(data_Sx, num_samples_factor, axis=0)).float().cuda()

        #-----------------------------------------------------------------------
        # initialize DCI
        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        #=======================================================================
        # train over epochs
        for epoch in range(num_epochs):

            # decay the learning rate
            if epoch % decay_step == 0:
                lr = base_lr * decay_rate ** (epoch // decay_step)
                optimizer = optim.Adam(self.model.parameters(), lr=lr,
                                       betas=(0.5, 0.999), weight_decay=1e-5)

            #-------------------------------------------------------------------
            # update the closest samples routinely
            if epoch % staleness == 0:

                # draw random z
                z = torch.randn(num_data * num_samples_factor, self.z_dim, 1, 1).cuda()
                z_Sx_all = torch.cat((z, Sx), axis=1)

                # find the closest sample for each data point
                nearest_indices = np.empty((num_data)).astype("int")
                for i in range(num_data):
                    samples = self.model(z_Sx_all[i * num_samples_factor:(i + 1) * num_samples_factor])
                    samples_np[:] = samples.cpu().data.numpy()
                    samples_flat_np = np.reshape(samples_np,
                                                 (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                    #-----------------------------------------------------------
                    # find the nearest neighbours
                    self.dci_db.reset()
                    self.dci_db.add(np.copy(samples_flat_np),
                                    num_levels=2, field_of_view=10, prop_to_retrieve=0.002)
                    nearest_indices_temp, _ = self.dci_db.query(data_flat_np[i:i + 1],
                                                                num_neighbours=1, field_of_view=20,
                                                                prop_to_retrieve=0.02)
                    nearest_indices[i] = nearest_indices_temp[0][0] + i * num_samples_factor

                if epoch == 0:
                    print(np.percentile(data_flat_np, 25), np.percentile(data_flat_np, 50),
                          np.percentile(data_flat_np, 75))
                    print(np.percentile(samples_flat_np, 25), np.percentile(samples_flat_np, 50),
                          np.percentile(samples_flat_np, 75))

                # restrict latent parameters to the nearest neighbours
                z_Sx = z_Sx_all[nearest_indices]

            #===================================================================
            # gradient descent
            err = 0.

            # loop over all batches
            for i in range(num_batches):
                self.model.zero_grad()
                cur_samples = self.model(z_Sx[i * batch_size:(i + 1) * batch_size])

                # save the mock samples
                if (epoch + 1) % staleness == 0:
                    samples_predict[i * batch_size:(i + 1) * batch_size] = cur_samples.cpu().data.numpy()

                # gradient descent
                loss = loss_fn(cur_samples, data_all[i * batch_size:(i + 1) * batch_size])
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            #-------------------------------------------------------------------
            # save the mock samples
            if (epoch + 1) % staleness == 0:
                #np.savez("../results_2D_times=10_J=4_L=2_epoch=" + str(epoch) + ".npz", data_np=data_np,
                #         z_Sx_np=z_Sx.cpu().data.numpy(),
                #         samples_np=samples_predict)

                # generate random mock samples
                samples_random = self.model(z_Sx_all[:10**4][::100]).cpu().data.numpy()
                np.savez("../results_2D_random_times=10_" + name_JL + "_epoch=" + str(epoch) + ".npz",
                         samples_np=samples_random,
                         mse_err=err / num_batches)

                # save network weights
                torch.save(self.model.state_dict(),
                           '../net_weights_2D_times=10_' + name_JL + '_epoch=' + str(epoch) + '.pth')
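
#-------------------------------------------------------------------------------
# Hypothetical driver for the class above (not from the source). It assumes the
# images and their scattering coefficients are stored as NumPy arrays whose
# shapes match the torch.cat along axis=1 in train(); the file names below are
# placeholders.
#-------------------------------------------------------------------------------
import numpy as np

data_np = np.load("images.npy")        # assumed shape, e.g. (N, 1, 64, 64)
data_Sx = np.load("scattering.npy")    # assumed shape, e.g. (N, Sx_dim, 1, 1)

imle = IMLE(z_dim=64, Sx_dim=data_Sx.shape[1])
imle.train(data_np, data_Sx, name_JL="J=6_L=2")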
class IMLE():
    def __init__(self, z_dim):
        self.z_dim = z_dim
        #self.model = ConvolutionalImplicitModel(z_dim).cpu()
        self.model = BigGAN.Generator(resolution=Resolution, dim_z=z_dim, n_classes=Classes).cpu()
        self.dci_db = None

    def train(self, data_np, label_np, hyperparams, shuffle_data=True, path='results'):
        loss_fn = nn.MSELoss().cpu()
        self.model.train()

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        num_samples = num_batches * hyperparams.num_samples_factor

        grid_size = (5, 5)
        grid_z = torch.randn(np.prod(grid_size), self.z_dim).cpu()
        grid_y = torch.randint(low=0, high=Classes, size=(np.prod(grid_size), ),
                               dtype=torch.int64).cpu()

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]
            label_np = label_np[data_ordering]

        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):
            if epoch % 10 == 0:
                print('Saving net_weights.pth...')
                torch.save(self.model.state_dict(), os.path.join(path, 'net_weights.pth'))
            if epoch % 5 == 0:
                print('Saving grid...')
                save_grid(path=path, index=epoch, count=np.prod(grid_size),
                          imle=self, z=grid_z, y=grid_y, z_dim=self.z_dim)
                print('Saved')

            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(epoch // hyperparams.decay_step)
                optimizer = optim.Adam(self.model.parameters(), lr=lr,
                                       betas=(0.5, 0.999), weight_decay=1e-5)

            if epoch % hyperparams.staleness == 0:
                z_np = np.empty((num_samples * batch_size, self.z_dim))
                y_np = np.empty((num_samples * batch_size, 1), dtype=np.int64)
                samples_np = np.empty((num_samples * batch_size, ) + data_np.shape[1:])
                for i in range(num_samples):
                    z = torch.randn(batch_size, self.z_dim, requires_grad=False).cpu()
                    y = torch.randint(low=0, high=Classes, size=(batch_size, 1),
                                      dtype=torch.int64, requires_grad=False).cpu()
                    samples = self.model(z, self.model.shared(y))
                    #import pdb; pdb.set_trace()
                    z_np[i * batch_size:(i + 1) * batch_size] = z.cpu().data.numpy()
                    y_np[i * batch_size:(i + 1) * batch_size] = y.cpu().data.numpy()
                    samples_np[i * batch_size:(i + 1) * batch_size] = samples.cpu().data.numpy()

                samples_flat_np = np.reshape(samples_np,
                                             (samples_np.shape[0], np.prod(samples_np.shape[1:]))).copy()

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np, num_levels=2, field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np, num_neighbours=1,
                                                       field_of_view=20, prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)
                y_np = y_np[nearest_indices]

                del samples_np, samples_flat_np

            err = 0.
            for i in range(num_batches):
                self.model.zero_grad()
                cur_z = torch.from_numpy(z_np[i * batch_size:(i + 1) * batch_size]).float().cpu()
                cur_y = torch.from_numpy(y_np[i * batch_size:(i + 1) * batch_size]).long().cpu()
                cur_data = torch.from_numpy(data_np[i * batch_size:(i + 1) * batch_size]).float().cpu()
                cur_samples = self.model(cur_z, self.model.shared(cur_y))
                loss = loss_fn(cur_samples, cur_data)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))
def get_code_for_data(model, data, opt):
    options = opt['train']
    dci_num_comp_indices = int(options['dci_num_comp_indices'])
    dci_num_simp_indices = int(options['dci_num_simp_indices'])
    dci_num_levels = int(options['dci_num_levels'])
    dci_construction_field_of_view = int(options['dci_construction_field_of_view'])
    dci_query_field_of_view = int(options['dci_query_field_of_view'])
    dci_prop_to_visit = float(options['dci_prop_to_visit'])
    dci_prop_to_retrieve = float(options['dci_prop_to_retrieve'])
    sample_perturbation_magnitude = float(options['sample_perturbation_magnitude'])

    code_nc = int(opt['network_G']['in_code_nc'])
    pull_num_sample_per_img = int(options['num_code_per_img'])
    show_message = False if 'show_message' not in options else options['show_message']

    pull_gen_img = data['LR']
    real_gen_img = data['HR']
    pull_gen_code_0 = torch.empty(pull_gen_img.shape[0], code_nc,
                                  pull_gen_img.shape[2], pull_gen_img.shape[3])

    if show_message:
        print("Generating Pull Samples")
    data_length = pull_gen_img.shape[0]
    out_feature_shape = model.netF(data['HR'][:1]).shape[1:]

    # initialize the DCI database
    pull_samples_dci_db = DCI(np.prod(out_feature_shape),
                              dci_num_comp_indices, dci_num_simp_indices)

    for sample_index in range(data_length):
        if (sample_index + 1) % 10 == 0 and show_message:
            print_without_newline('\rFinding first stack code: Processed %d out of %d instances' % (
                sample_index + 1, data_length))

        if 'zero_code' in options and options['zero_code']:
            pull_gen_code_pool_0 = torch.zeros(pull_num_sample_per_img, code_nc,
                                               pull_gen_img.shape[2], pull_gen_img.shape[3])
        elif 'rand_code' in options and options['rand_code']:
            pull_gen_code_pool_0 = torch.rand(pull_num_sample_per_img, code_nc,
                                              pull_gen_img.shape[2], pull_gen_img.shape[3])
        else:
            pull_gen_code_pool_0 = torch.randn(pull_num_sample_per_img, code_nc,
                                               pull_gen_img.shape[2], pull_gen_img.shape[3])

        pull_img = pull_gen_img[sample_index].expand(pull_num_sample_per_img, -1, -1, -1)
        # cur_data = {'LR': pull_img, 'HR': real_gen_img[sample_index: sample_index + 1]}
        cur_data = {'LR': pull_img,
                    'HR': real_gen_img[sample_index: sample_index + 1].expand(
                        (pull_num_sample_per_img,) + real_gen_img.shape[1:])}
        model.feed_data(cur_data, code=pull_gen_code_pool_0)
        feature_output = model.get_features()
        # NB: both the candidate pool and the query target are read from 'gen_feat' here
        pull_gen_features_pool = feature_output['gen_feat']
        target_feature = feature_output['gen_feat']
        pull_gen_features_pool = pull_gen_features_pool.reshape(
            -1, np.prod(pull_gen_features_pool.shape[1:])).double().numpy().copy()
        target_feature = target_feature.reshape(-1, np.prod(target_feature.shape[1:]))

        pull_samples_dci_db.add(pull_gen_features_pool, num_levels=dci_num_levels,
                                field_of_view=dci_construction_field_of_view,
                                prop_to_visit=dci_prop_to_visit,
                                prop_to_retrieve=dci_prop_to_retrieve)
        pull_sample_idx_for_img, _ = pull_samples_dci_db.query(
            target_feature.numpy(), num_neighbours=1,
            field_of_view=dci_query_field_of_view,
            prop_to_visit=dci_prop_to_visit, prop_to_retrieve=dci_prop_to_retrieve)
        pull_gen_code_0[sample_index, :] = pull_gen_code_pool_0[int(pull_sample_idx_for_img[0][0]), :]

        # clear the db for the next query
        pull_samples_dci_db.clear()

    if show_message:
        print('\rFinding first stack code: Processed %d out of %d instances' % (
            data_length, data_length))

    if 'zero_code' in options and options['zero_code']:
        pull_gen_code_0 += sample_perturbation_magnitude * torch.zeros(
            pull_gen_img.shape[0], code_nc, pull_gen_img.shape[2], pull_gen_img.shape[3])
    elif 'rand_code' in options and options['rand_code']:
        pull_gen_code_0 += sample_perturbation_magnitude * torch.rand(
            pull_gen_img.shape[0], code_nc, pull_gen_img.shape[2], pull_gen_img.shape[3])
    else:
        pull_gen_code_0 += sample_perturbation_magnitude * torch.randn(
            pull_gen_img.shape[0], code_nc, pull_gen_img.shape[2], pull_gen_img.shape[3])

    return pull_gen_code_0
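
#-------------------------------------------------------------------------------
# Illustrative config for get_code_for_data. The key names are exactly those
# read by the function above, but the values are placeholders, not the repo's
# defaults; `model` and `data` are assumed to come from the surrounding
# training framework.
#-------------------------------------------------------------------------------
opt_example = {
    'train': {
        'dci_num_comp_indices': 2,
        'dci_num_simp_indices': 7,
        'dci_num_levels': 2,
        'dci_construction_field_of_view': 10,
        'dci_query_field_of_view': 20,
        'dci_prop_to_visit': 1.0,
        'dci_prop_to_retrieve': 0.02,
        'sample_perturbation_magnitude': 0.01,
        'num_code_per_img': 100,
        'show_message': True,
    },
    'network_G': {'in_code_nc': 1},
}
# code = get_code_for_data(model, {'LR': lr_batch, 'HR': hr_batch}, opt_example)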
def get_code_for_data_three(model, data, opt):
    options = opt['train']
    dci_num_comp_indices = int(options['dci_num_comp_indices'])
    dci_num_simp_indices = int(options['dci_num_simp_indices'])
    dci_num_levels = int(options['dci_num_levels'])
    dci_construction_field_of_view = int(options['dci_construction_field_of_view'])
    dci_query_field_of_view = int(options['dci_query_field_of_view'])
    dci_prop_to_visit = float(options['dci_prop_to_visit'])
    dci_prop_to_retrieve = float(options['dci_prop_to_retrieve'])
    sample_perturbation_magnitude = float(options['sample_perturbation_magnitude'])

    code_nc = int(opt['network_G']['in_code_nc'])
    pull_num_sample_per_img = int(options['num_code_per_img'])

    pull_gen_img = data['LR']
    d1_gen_img = data['D1']
    d2_gen_img = data['D2']
    real_gen_img = data['HR']
    pull_gen_code_0 = torch.empty(pull_gen_img.shape[0], code_nc,
                                  pull_gen_img.shape[2] * 2, pull_gen_img.shape[3] * 2)
    forward_bs = 10

    print("Generating Pull Samples")
    data_length = pull_gen_img.shape[0]
    out_feature_shape = model.netF(data['D1'][:1])[-1].shape[1:]

    # ============ ADD POINT ==================
    pull_samples_dci_db = DCI(np.prod(out_feature_shape),
                              dci_num_comp_indices, dci_num_simp_indices)
    for sample_index in range(data_length):
        if (sample_index + 1) % 10 == 0:
            print_without_newline('\rFinding first stack code: Processed %d out of %d instances' % (
                sample_index + 1, data_length))
        pull_gen_code_pool_0 = torch.randn(pull_num_sample_per_img, code_nc,
                                           pull_gen_img.shape[2] * 2, pull_gen_img.shape[3] * 2)
        pull_gen_code_pool_1 = torch.randn(pull_num_sample_per_img, code_nc,
                                           pull_gen_img.shape[2] * 4, pull_gen_img.shape[3] * 4)
        pull_gen_code_pool_2 = torch.randn(pull_num_sample_per_img, code_nc,
                                           pull_gen_img.shape[2] * 8, pull_gen_img.shape[3] * 8)
        pull_gen_features_pool = []
        for i in range(0, pull_num_sample_per_img, forward_bs):
            pull_img = pull_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_target = real_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_d1 = d1_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_d2 = d2_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            cur_data = {'LR': pull_img, 'HR': pull_target, 'D1': pull_d1, 'D2': pull_d2}
            start = i
            end = i + forward_bs
            model.feed_data(cur_data, code=[pull_gen_code_pool_0[start:end],
                                            pull_gen_code_pool_1[start:end],
                                            pull_gen_code_pool_2[start:end]])
            feature_output = model.get_features()
            pull_gen_features_pool.append(feature_output['gen_feat_D1'].double().numpy())
        pull_gen_features_pool = np.concatenate(pull_gen_features_pool, axis=0)
        pull_gen_features_pool = pull_gen_features_pool.reshape(
            -1, np.prod(pull_gen_features_pool.shape[1:]))
        pull_samples_dci_db.add(pull_gen_features_pool.copy(), num_levels=dci_num_levels,
                                field_of_view=dci_construction_field_of_view,
                                prop_to_visit=dci_prop_to_visit,
                                prop_to_retrieve=dci_prop_to_retrieve)
        target_feature = feature_output['real_feat_D1']
        target_feature = target_feature[0].reshape(
            1, np.prod(target_feature.shape[1:])).double().numpy().copy()
        pull_sample_idx_for_img, _ = pull_samples_dci_db.query(
            target_feature, num_neighbours=1,
            field_of_view=dci_query_field_of_view,
            prop_to_visit=dci_prop_to_visit, prop_to_retrieve=dci_prop_to_retrieve)
        pull_gen_code_0[sample_index, :] = pull_gen_code_pool_0[int(pull_sample_idx_for_img[0][0]), :]
        # clear the db
        pull_samples_dci_db.clear()
    print('\rFinding first stack code: Processed %d out of %d instances' % (
        data_length, data_length))

    # ============ ADD POINT ==================
    pull_gen_code_1 = torch.empty(pull_gen_img.shape[0], code_nc,
                                  pull_gen_img.shape[2] * 4, pull_gen_img.shape[3] * 4)
    out_feature_shape = model.netF(data['D2'][:1])[-1].shape[1:]
    pull_samples_dci_db = DCI(np.prod(out_feature_shape),
                              dci_num_comp_indices, dci_num_simp_indices)
    for sample_index in range(data_length):
        if (sample_index + 1) % 10 == 0:
            print_without_newline('\rFinding second stack code: Processed %d out of %d instances' % (
                sample_index + 1, data_length))
        # ============ ADD POINT ==================
        pull_gen_code_pool_0 = pull_gen_code_0[sample_index].expand(pull_num_sample_per_img, -1, -1, -1)
        pull_gen_code_pool_1 = torch.randn(pull_num_sample_per_img, code_nc,
                                           pull_gen_img.shape[2] * 4, pull_gen_img.shape[3] * 4)
        pull_gen_code_pool_2 = torch.randn(pull_num_sample_per_img, code_nc,
                                           pull_gen_img.shape[2] * 8, pull_gen_img.shape[3] * 8)
        pull_gen_features_pool = []
        for i in range(0, pull_num_sample_per_img, forward_bs):
            pull_img = pull_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_target = real_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_d1 = d1_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_d2 = d2_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            cur_data = {'LR': pull_img, 'HR': pull_target, 'D1': pull_d1, 'D2': pull_d2}
            start = i
            end = i + forward_bs
            model.feed_data(cur_data, code=[pull_gen_code_pool_0[start:end],
                                            pull_gen_code_pool_1[start:end],
                                            pull_gen_code_pool_2[start:end]])
            feature_output = model.get_features()
            pull_gen_features_pool.append(feature_output['gen_feat_D2'].double().numpy())
        pull_gen_features_pool = np.concatenate(pull_gen_features_pool, axis=0)
        pull_gen_features_pool = pull_gen_features_pool.reshape(
            -1, np.prod(pull_gen_features_pool.shape[1:]))
        pull_samples_dci_db.add(pull_gen_features_pool.copy(), num_levels=dci_num_levels,
                                field_of_view=dci_construction_field_of_view,
                                prop_to_visit=dci_prop_to_visit,
                                prop_to_retrieve=dci_prop_to_retrieve)
        target_feature = feature_output['real_feat_D2'].double().numpy().copy()
        target_feature = target_feature[0].reshape(1, np.prod(target_feature.shape[1:]))
        pull_sample_idx_for_img, _ = pull_samples_dci_db.query(
            target_feature, num_neighbours=1,
            field_of_view=dci_query_field_of_view,
            prop_to_visit=dci_prop_to_visit, prop_to_retrieve=dci_prop_to_retrieve)
        pull_gen_code_1[sample_index, :] = pull_gen_code_pool_1[int(pull_sample_idx_for_img[0][0]), :]
        # clear the db
        pull_samples_dci_db.clear()
    print('\rFinding second stack code: Processed %d out of %d instances' % (
        data_length, data_length))

    # ============ ADD POINT ==================
    pull_gen_code_2 = torch.empty(pull_gen_img.shape[0], code_nc,
                                  pull_gen_img.shape[2] * 8, pull_gen_img.shape[3] * 8)
    out_feature_shape = model.netF(data['HR'][:1])[-1].shape[1:]
    pull_samples_dci_db = DCI(np.prod(out_feature_shape),
                              dci_num_comp_indices, dci_num_simp_indices)
    for sample_index in range(data_length):
        if (sample_index + 1) % 10 == 0:
            print_without_newline('\rFinding third stack code: Processed %d out of %d instances' % (
                sample_index + 1, data_length))
        # ============ ADD POINT ==================
        pull_gen_code_pool_0 = pull_gen_code_0[sample_index].expand(pull_num_sample_per_img, -1, -1, -1)
        pull_gen_code_pool_1 = pull_gen_code_1[sample_index].expand(pull_num_sample_per_img, -1, -1, -1)
        pull_gen_code_pool_2 = torch.randn(pull_num_sample_per_img, code_nc,
                                           pull_gen_img.shape[2] * 8, pull_gen_img.shape[3] * 8)
        pull_gen_features_pool = []
        for i in range(0, pull_num_sample_per_img, forward_bs):
            pull_img = pull_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_target = real_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_d1 = d1_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            pull_d2 = d2_gen_img[sample_index].expand(forward_bs, -1, -1, -1)
            cur_data = {'LR': pull_img, 'HR': pull_target, 'D1': pull_d1, 'D2': pull_d2}
            start = i
            end = i + forward_bs
            model.feed_data(cur_data, code=[pull_gen_code_pool_0[start:end],
                                            pull_gen_code_pool_1[start:end],
                                            pull_gen_code_pool_2[start:end]])
            feature_output = model.get_features()
            pull_gen_features_pool.append(feature_output['gen_feat'].double().numpy())
        pull_gen_features_pool = np.concatenate(pull_gen_features_pool, axis=0)
        pull_gen_features_pool = pull_gen_features_pool.reshape(
            -1, np.prod(pull_gen_features_pool.shape[1:]))
        pull_samples_dci_db.add(pull_gen_features_pool.copy(), num_levels=dci_num_levels,
                                field_of_view=dci_construction_field_of_view,
                                prop_to_visit=dci_prop_to_visit,
                                prop_to_retrieve=dci_prop_to_retrieve)
        target_feature = feature_output['real_feat'].double().numpy().copy()
        target_feature = target_feature[0].reshape(1, np.prod(target_feature.shape[1:]))
        pull_sample_idx_for_img, _ = pull_samples_dci_db.query(
            target_feature, num_neighbours=1,
            field_of_view=dci_query_field_of_view,
            prop_to_visit=dci_prop_to_visit, prop_to_retrieve=dci_prop_to_retrieve)
        pull_gen_code_2[sample_index, :] = pull_gen_code_pool_2[int(pull_sample_idx_for_img[0][0]), :]
        # clear the db
        pull_samples_dci_db.clear()
    print('\rFinding third stack code: Processed %d out of %d instances' % (
        data_length, data_length))

    pull_gen_code_0 += sample_perturbation_magnitude * torch.randn(
        pull_gen_img.shape[0], code_nc, pull_gen_img.shape[2] * 2, pull_gen_img.shape[3] * 2)
    pull_gen_code_1 += sample_perturbation_magnitude * torch.randn(
        pull_gen_img.shape[0], code_nc, pull_gen_img.shape[2] * 4, pull_gen_img.shape[3] * 4)
    pull_gen_code_2 += sample_perturbation_magnitude * torch.randn(
        pull_gen_img.shape[0], code_nc, pull_gen_img.shape[2] * 8, pull_gen_img.shape[3] * 8)

    return [pull_gen_code_0, pull_gen_code_1, pull_gen_code_2]
class IMLE():
    def __init__(self, z_dim):
        self.z_dim = z_dim
        self.dci_db = None

    def train(self, data_np, hyperparams, shuffle_data=True):
        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        num_samples = num_batches * hyperparams.num_samples_factor
        num_outputs = hyperparams.num_outputs

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]

        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        input_shape = [64, 1, 1, self.z_dim]
        net = get_model(input_shape)
        train_weights = net.trainable_weights

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):
            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(epoch // hyperparams.decay_step)
                # Optimizer API changes: betas -> beta_1/beta_2, and weight_decay=1e-5 removed
                optimizer = tf.optimizers.Adam(learning_rate=lr, beta_1=0.5, beta_2=0.999)

            if epoch % hyperparams.staleness == 0:
                net.eval()  # switch to eval mode
                z_np = np.empty((num_samples * batch_size, 1, 1, self.z_dim))
                samples_np = np.empty((num_samples * batch_size, ) + data_np.shape[1:])
                for i in range(num_samples):
                    z = tf.random.normal([batch_size, 1, 1, self.z_dim])
                    samples = net(z)
                    z_np[i * batch_size:(i + 1) * batch_size] = z.numpy()
                    samples_np[i * batch_size:(i + 1) * batch_size] = samples.numpy()

                samples_flat_np = np.reshape(samples_np,
                                             (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np, num_levels=2, field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np, num_neighbours=1,
                                                       field_of_view=20, prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)

                del samples_np, samples_flat_np

            err = 0.
            for i in range(num_batches):
                net.eval()
                cur_z = tf.convert_to_tensor(z_np[i * batch_size:(i + 1) * batch_size],
                                             dtype=tf.float32)
                cur_data = tf.convert_to_tensor(data_np[i * batch_size:(i + 1) * batch_size],
                                                dtype=tf.float32)
                cur_samples = net(cur_z)
                _loss = tl.cost.mean_squared_error(cur_data, cur_samples, is_mean=False)
                err += _loss
            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            for i in range(num_batches):
                net.train()
                with tf.GradientTape() as tape:
                    cur_z = tf.convert_to_tensor(z_np[i * batch_size:(i + 1) * batch_size],
                                                 dtype=tf.float32)
                    cur_data = tf.convert_to_tensor(data_np[i * batch_size:(i + 1) * batch_size],
                                                    dtype=tf.float32)
                    cur_samples = net(cur_z)
                    _loss = tl.cost.mean_squared_error(cur_data, cur_samples, is_mean=False)
                grad = tape.gradient(_loss, train_weights)
                optimizer.apply_gradients(zip(grad, train_weights))

            # debug: extremely slow, but worth doing to visualize the process
            # (the following four lines can be deleted directly)
            #cur_z = tf.convert_to_tensor(z_np[0: batch_size], dtype=tf.float32)
            #cur_samples = net(cur_z)
            #pics = np.reshape(cur_samples.numpy(), (batch_size, 28, 28))
            #save_img(1, pics)

            cur_z = tf.random.normal([batch_size, 1, 1, self.z_dim])
            cur_samples = net(cur_z)
            pics = np.reshape(cur_samples.numpy(), (batch_size, 28, 28))
            save_img(2, pics)

            #cur_z = tf.convert_to_tensor(z_np[0: batch_size], dtype=tf.float32)
            cur_z = tf.random.normal([batch_size, 1, 1, self.z_dim])
            cur_samples = net(cur_z)
            pics = np.reshape(cur_samples.numpy(), (batch_size, 28, 28))
            save_img(batch_size, pics)
class IMLE():
    def __init__(self, z_dim, Sx_dim):
        self.z_dim = z_dim
        self.Sx_dim = Sx_dim
        self.model = ConvolutionalImplicitModel(z_dim + Sx_dim).cuda()
        self.dci_db = None

    #---------------------------------------------------------------------------
    def train(self, data_np, data_Sx, base_lr=1e-4, batch_size=256, num_epochs=3000,
              decay_step=25, decay_rate=0.95, staleness=100, num_samples_factor=100):

        # define metric
        # loss_fn = nn.MSELoss().cuda()
        loss_fn = nn.L1Loss().cuda()
        # loss_fn = nn.BCELoss().cuda()
        self.model.train()

        # train in batches
        num_batches = data_np.shape[0] // batch_size

        # truncate data to fit the batch size
        num_data = num_batches * batch_size
        data_np = data_np[:num_data]
        data_Sx = data_Sx[:num_data]

        #-----------------------------------------------------------------------
        # make empty arrays to store results
        samples_predict = np.empty(data_np.shape)
        samples_np = np.empty((num_samples_factor, ) + data_np.shape[1:])
        # samples_np = np.empty((num_data*num_samples_factor,)+data_np.shape[1:])
        nearest_indices = np.empty((num_data)).astype("int")

        # make global torch variables
        data_all = torch.from_numpy(data_np).float().cuda()
        Sx = torch.from_numpy(np.repeat(data_Sx, num_samples_factor, axis=0)).float().cuda()

        # initialize DCI
        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        #=======================================================================
        # train over epochs
        for epoch in range(num_epochs):

            # decay the learning rate
            if epoch % decay_step == 0:
                lr = base_lr * decay_rate**(epoch // decay_step)
                optimizer = optim.Adam(self.model.parameters(), lr=lr,
                                       betas=(0.5, 0.999), weight_decay=1e-5)

            #-------------------------------------------------------------------
            # update the closest samples routinely
            if epoch % staleness == 0:

                # draw random z
                z = torch.randn(num_data * num_samples_factor, self.z_dim).cuda()
                #z_Sx_all = torch.cat((z, Sx), axis=1)
                z_Sx_all = torch.cat((z, Sx), axis=1)[:, :, None]

                #---------------------------------------------------------------
                # find the closest sample for each data point
                nearest_indices = np.empty((num_data)).astype("int")
                for i in range(num_data):
                    samples = self.model(
                        z_Sx_all[i * num_samples_factor:(i + 1) * num_samples_factor])
                    samples_np[:] = samples.cpu().data.numpy()

                    # find the nearest neighbours
                    self.dci_db.reset()
                    self.dci_db.add(np.copy(samples_np),
                                    num_levels=2, field_of_view=10, prop_to_retrieve=0.002)
                    nearest_indices_temp, _ = self.dci_db.query(data_np[i:i+1],
                                                                num_neighbours=1, field_of_view=20,
                                                                prop_to_retrieve=0.02)
                    nearest_indices[i] = nearest_indices_temp[0][0] + i * num_samples_factor

                #---------------------------------------------------------------
                # # find the closest samples for all data at once
                # samples = self.model(z_Sx_all)
                # samples_np[:] = samples.cpu().data.numpy()
                #
                # # find the nearest neighbours
                # self.dci_db.reset()
                # self.dci_db.add(np.copy(samples_np),
                #                 num_levels=2, field_of_view=10, prop_to_retrieve=0.002)
                # nearest_indices_temp, _ = self.dci_db.query(data_np,
                #                                             num_neighbours=1, field_of_view=20,
                #                                             prop_to_retrieve=0.02)
                # nearest_indices[:] = nearest_indices_temp

                #---------------------------------------------------------------
                # restrict latent parameters to the nearest neighbours
                z_Sx = z_Sx_all[nearest_indices]

            #===================================================================
            # gradient descent
            err = 0.

            # loop over all batches
            for i in range(num_batches):
                self.model.zero_grad()
                cur_samples = self.model(z_Sx[i * batch_size:(i + 1) * batch_size])

                # save the mock samples
                if (epoch + 1) % staleness == 0:
                    samples_predict[i * batch_size:(i + 1) * batch_size] = \
                        cur_samples.cpu().data.numpy()

                # gradient descent
                loss = loss_fn(cur_samples, data_all[i * batch_size:(i + 1) * batch_size])
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            #-------------------------------------------------------------------
            # save the mock samples
            if (epoch + 1) % staleness == 0:

                # save the closest models
                np.savez("../results_spectra_deconv_256x2_" + str(epoch) + ".npz",
                         data_np=data_np,
                         z_Sx_np=z_Sx.cpu().data.numpy(),
                         samples_np=samples_predict)
                np.savez("../mse_err_deconv_256x2_" + str(epoch) + ".npz",
                         mse_err=err / num_batches)

                # save network weights
                torch.save(self.model.state_dict(),
                           '../net_weights_spectra_deconv_256x2_epoch=' + str(epoch) + '.pth')
class IMLE():
    def __init__(self, z_dim):
        self.z_dim = z_dim
        self.model = ConvolutionalImplicitModel(z_dim).cuda()
        self.dci_db = None

    def train(self, data_np, hyperparams, shuffle_data=True):
        loss_fn = nn.MSELoss().cuda()
        self.model.train()

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        num_samples = num_batches * hyperparams.num_samples_factor

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]

        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):
            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(epoch // hyperparams.decay_step)
                optimizer = optim.Adam(self.model.parameters(), lr=lr,
                                       betas=(0.5, 0.999), weight_decay=1e-5)

            if epoch % hyperparams.staleness == 0:
                z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1))
                samples_np = np.empty((num_samples * batch_size, ) + data_np.shape[1:])
                for i in range(num_samples):
                    z = torch.randn(batch_size, self.z_dim, 1, 1).cuda()
                    samples = self.model(z)
                    z_np[i * batch_size:(i + 1) * batch_size] = z.cpu().data.numpy()
                    samples_np[i * batch_size:(i + 1) * batch_size] = samples.cpu().data.numpy()

                samples_flat_np = np.reshape(samples_np,
                                             (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np, num_levels=2, field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np, num_neighbours=1,
                                                       field_of_view=20, prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)

                del samples_np, samples_flat_np

            err = 0.
            for i in range(num_batches):
                self.model.zero_grad()
                cur_z = torch.from_numpy(z_np[i * batch_size:(i + 1) * batch_size]).float().cuda()
                cur_data = torch.from_numpy(
                    data_np[i * batch_size:(i + 1) * batch_size]).float().cuda()
                cur_samples = self.model(cur_z)
                loss = loss_fn(cur_samples, cur_data)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))
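
#-------------------------------------------------------------------------------
# The train() loop above only reads attributes from `hyperparams`, so any object
# exposing them works. A minimal sketch of a driver, with a hypothetical
# namedtuple standing in for whatever container the original repo uses; the
# field values and data shape below are illustrative.
#-------------------------------------------------------------------------------
import collections
import numpy as np

Hyperparams = collections.namedtuple('Hyperparams', [
    'base_lr', 'batch_size', 'num_epochs', 'decay_step',
    'decay_rate', 'staleness', 'num_samples_factor'])

hyperparams = Hyperparams(base_lr=1e-3, batch_size=64, num_epochs=300,
                          decay_step=25, decay_rate=1.0, staleness=5,
                          num_samples_factor=10)

data_np = np.random.randn(1024, 1, 28, 28)   # stand-in for real training images
imle = IMLE(z_dim=64)
imle.train(data_np, hyperparams)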
def runknn(sz=4, num_points=202599, num_queries=1, num_comp_indices=2,
           num_simp_indices=7, num_outer_iterations=202599,
           max_num_candidates=5000, num_neighbours=10, patch_size=5):
    num_levels = 3
    construction_field_of_view = 10
    construction_prop_to_retrieve = 0.02
    query_field_of_view = 12000
    query_prop_to_retrieve = 1.0
    dim = patch_size * patch_size * 5

    #setmodulename('\"../church_npy/church_%d.npy\"' % sz)
    #data_and_queries = get_img(dim, num_points + num_queries)
    #data = np.load("../church_npy/church_%d.npy" % sz, mmap_mode='r')  #get_img(dim, num_points)  #np.copy(data_and_queries[:num_points,:])
    down = int(math.log(256 / sz, 2))

    dci_db = DCI(dim, num_comp_indices, num_simp_indices)
    st = time.time()
    print("before adding")
    dci_db.add(num_points, "../church_npy/church_%d_vgg_12_5.npy" % sz,
               num_levels=num_levels, field_of_view=construction_field_of_view,
               prop_to_retrieve=construction_prop_to_retrieve, load_from_file=1)
    print("construction time:", time.time() - st)

    imgid = 270
    flip = 0
    datamat = np.load("../church_npy/church_%d.npy" % 128, mmap_mode='r')
    rawimg = imread("../results_churchoutdoor/fake_%04d.png" % imgid)
    if flip:
        rawimg = np.flip(rawimg, 1)
    #im = Image.fromarray(target, 'RGB')

    # downsample by a factor of 2 by averaging each 2x2 block
    for j in range(1):
        rawimg = np.mean(np.concatenate([rawimg[0::2, 0::2, None],
                                         rawimg[0::2, 1::2, None],
                                         rawimg[1::2, 0::2, None],
                                         rawimg[1::2, 1::2, None]], axis=2), axis=2)
    rawimg = rawimg.astype(np.uint8)

    if flip:
        fakeimgs = np.load("../church_npy/fakechurch_%d_vgg_12_flip.npy" % sz, mmap_mode='r')
    else:
        fakeimgs = np.load("../church_npy/fakechurch_%d_vgg_12.npy" % sz, mmap_mode='r')
    mularr = np.load("../church_npy/rpvec_64_5.npy")
    fakeimgs_rp = fakeimgs[imgid].copy()
    print(fakeimgs_rp.shape)
    fakeimg_rp = np.dot(fakeimgs_rp, mularr)
    del fakeimgs
    del mularr

    minx = 0
    miny = 0
    maxx = 128
    maxy = 128
    ambient_dim = 32 * 32 * 5
    # integer division (the original Python 2 `/` would otherwise yield floats here)
    numx = (max(minx + 1, maxx - 32 + 1) - minx + 32 - 1) // 32
    numy = (max(miny + 1, maxy - 32 + 1) - miny + 32 - 1) // 32
    print(numx, numy)

    queries = np.empty([25 * 25, ambient_dim], dtype=np.float32)
    d = 32
    st = 0
    for j in range(miny, max(miny + 1, maxy - 32 + 1), 4):
        for k in range(minx, max(minx + 1, maxx - 32 + 1), 4):
            eyeimg = fakeimg_rp[j:j + d, k:k + d, :].flatten().astype(np.float32)
            queries[st] = eyeimg / np.linalg.norm(eyeimg) * 255
            st += 1
    print(queries.shape, queries.dtype, st)

    st = time.time()
    num_neighbours = 200
    query_field_of_view = 10000  # 11000
    query_prop_to_retrieve = 1.0
    nearest_neighbour_idx, nearest_neighbour_dists = dci_db.query(
        queries, num_neighbours, field_of_view=query_field_of_view,
        prop_to_retrieve=query_prop_to_retrieve, blind=False)
    print("query time:", time.time() - st)

    finaldist = np.array(nearest_neighbour_dists)
    rawidx = np.array(nearest_neighbour_idx)
    print(rawidx.shape)
    finalidx = rawidx // (25 * 25)
    finalpos = np.empty([rawidx.shape[0], 200, 2], dtype=int)  # np.int is deprecated
    for i in range(rawidx.shape[0]):
        for j in range(200):
            offset = rawidx[i, j] % (25 * 25)
            finalpos[i, j, 0] = offset // 25 * 4
            finalpos[i, j, 1] = offset % 25 * 4
    np.savez("fake_%04d_all_overlap.npz" % imgid,
             finalidx=finalidx, finaldist=finaldist, finalpos=finalpos)
    return

    # NOTE: unreachable after the early return above; kept from the original
    queries = get_query(dim, 1, down, sz, patch_size, patch_size)  #data_and_queries[num_points:,:]
    print(queries.shape, queries.dtype)
    queries = queries[0:2]
    st = time.time()
    nearest_neighbour_idx, nearest_neighbour_dists = dci_db.query(
        queries, num_neighbours, field_of_view=query_field_of_view,
        prop_to_retrieve=query_prop_to_retrieve, blind=False)
    print("query time:", time.time() - st)
    #gen_res(num_points, queries, nearest_neighbour_idx, down, sz)
    #resfile = "../save_%d/res.txt" % (sz)
    #f = open(resfile, 'w')
    print(np.array(nearest_neighbour_idx)[:, 0])
    print(np.array(nearest_neighbour_dists)[:, 0])
    np.save("churchidx_18_v7.npy", np.array(nearest_neighbour_idx))
    np.save("churchdist_18_v7.npy", np.array(nearest_neighbour_dists))
    dci_db.clear()
class IMLE():
    def __init__(self, z_dim):
        self.z_dim = z_dim
        self.model = ConvolutionalImplicitModel(z_dim).cuda()
        self.dci_db = None

    def train(self, data_np, hyperparams, shuffle_data=True):
        loss_fn = nn.MSELoss().cuda()
        self.model.train()

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        num_samples = num_batches * hyperparams.num_samples_factor  # number of generated samples

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]

        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):
            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(epoch // hyperparams.decay_step)
                optimizer = optim.Adam(self.model.parameters(), lr=lr,
                                       betas=(0.5, 0.999), weight_decay=1e-5)

            # Data generation step - do we re-sample the training set here?
            # It does not appear to select a subset S of samples; it always uses
            # the original dataset data_np.
            if epoch % hyperparams.staleness == 0:
                z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1))
                samples_np = np.empty((num_samples * batch_size, ) + data_np.shape[1:])
                for i in range(num_samples):
                    z = torch.randn(batch_size, self.z_dim, 1, 1).cuda()
                    samples = self.model(z)
                    z_np[i * batch_size:(i + 1) * batch_size] = z.cpu().data.numpy()
                    samples_np[i * batch_size:(i + 1) * batch_size] = samples.cpu().data.numpy()

                samples_flat_np = np.reshape(samples_np,
                                             (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np, num_levels=2, field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np, num_neighbours=1,
                                                       field_of_view=20, prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)

                del samples_np, samples_flat_np

                # z_np now consists of the z values whose generated samples are closest
                # to each data point. Why do it this way? We want to compare each real
                # data point x_i with the closest generated point. Call this point
                # \hat{x}_{sigma(i)}; it was generated from some random noise, say
                # G_theta(z_i). Once this point is determined, how can we change G so
                # that \hat{x}_{sigma(i)} moves closer to x_i? We minimize
                # || G_theta(z_i) - x_i ||.
                # In the conditional case, where we condition on a random variable A,
                # the idea is to generate m samples for each a in A and, out of those
                # m samples, pick the one closest to the data point.

            err = 0.
            # batch gradient descent
            for i in range(num_batches):
                self.model.zero_grad()
                cur_z = torch.from_numpy(z_np[i * batch_size:(i + 1) * batch_size]).float().cuda()
                cur_data = torch.from_numpy(
                    data_np[i * batch_size:(i + 1) * batch_size]).float().cuda()
                cur_samples = self.model(cur_z)
                loss = loss_fn(cur_samples, cur_data)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))
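
#-------------------------------------------------------------------------------
# The comment in the staleness block above describes the IMLE matching step. As
# a sanity check, the same selection can be written as a brute-force
# nearest-neighbour search (a sketch in plain NumPy; the real code uses DCI
# because this is O(N*M) in the number of data points and samples).
#-------------------------------------------------------------------------------
import numpy as np

def match_latents(data_flat, samples_flat, z_all):
    """For each data point, pick the latent z whose sample is nearest in L2.

    data_flat:    (N, D) flattened real data
    samples_flat: (M, D) flattened generated samples
    z_all:        (M, ...) the latents that produced samples_flat
    """
    # dists[i, j] = ||data_i - sample_j||^2
    dists = ((data_flat[:, None, :] - samples_flat[None, :, :]) ** 2).sum(axis=2)
    nearest = dists.argmin(axis=1)       # analogue of nearest_indices above
    return z_all[nearest]                # analogue of z_np = z_np[nearest_indices]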
class CoolSystem(pl.LightningModule):

    def __init__(self, z_dim, hyperparams, shuffle_data=True):
        super(CoolSystem, self).__init__()
        # not the best model...
        #self.l1 = torch.nn.Linear(28 * 28, 10)
        self.z_dim = z_dim
        self.z_grid = None
        self.hyperparams = hyperparams
        self.shuffle_data = shuffle_data
        self.dci_db = None
        self.model = ConvolutionalImplicitModel(z_dim)
        #self.model = Generator128(z_dim, n_class=10)
        #self.squeezeNet = NetPerLayer()
        self.squeezeNet = None
        self.loss_fn = nn.MSELoss()
        self._step = 0

    def regen(self, batch):
        #import pdb; pdb.set_trace()
        imgs, labels = batch
        data_np = imgs.numpy()
        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))
        self.data_np = data_np
        self.data_flat_np = data_flat_np

        hyperparams = self.hyperparams
        batch_size = hyperparams.batch_size
        #num_batches = data_np.shape[0] // batch_size
        num_batches = 1
        num_samples = num_batches * hyperparams.num_samples_factor

        z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1))
        samples_np = np.empty((num_samples * batch_size,) + data_np.shape[1:])
        for i in range(num_samples):
            z = torch.randn(batch_size, self.z_dim, 1, 1).cpu()
            samples = self.model(z, labels)
            z_np[i*batch_size:(i+1)*batch_size] = z.cpu().data.numpy()
            samples_np[i*batch_size:(i+1)*batch_size] = samples.cpu().data.numpy()

        samples_flat_np = np.reshape(samples_np,
                                     (samples_np.shape[0], np.prod(samples_np.shape[1:]))).copy()

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        self.dci_db.reset()
        self.dci_db.add(samples_flat_np, num_levels=2, field_of_view=10, prop_to_retrieve=0.002)
        nearest_indices, _ = self.dci_db.query(data_flat_np, num_neighbours=1,
                                               field_of_view=20, prop_to_retrieve=0.02)
        nearest_indices = np.array(nearest_indices)[:, 0]

        z_np = z_np[nearest_indices]
        z_np += 0.01 * np.random.randn(*z_np.shape)
        self.z_np = z_np
        if self.z_grid is None:
            self.z_grid = self.z_np.copy() * 0.7

        del samples_np, samples_flat_np
        return z_np

    # def forward(self, x):
    #     return torch.relu(self.l1(x.view(x.size(0), -1)))

    def forward(self, z, class_id=None):
        cur_samples = self.model(z, class_id)
        return cur_samples

    def training_step(self, batch, batch_idx):
        # REQUIRED
        #x, y = batch
        #y_hat = self.forward(x)
        #loss = F.cross_entropy(y_hat, y)
        #batch_size = batch.shape[0]
        hyperparams = self.hyperparams
        batch_size = hyperparams.batch_size
        i = batch_idx
        data_np = self.data_np
        #z_np = self.regen()
        z_np = self.z_np
        z_grid = self.z_grid
        #cur_z = torch.from_numpy(z_np[i*batch_size:(i+1)*batch_size]).float().cpu()
        #imgTarget = torch.from_numpy(data_np[i*batch_size:(i+1)*batch_size]).float().cpu()
        cur_z = torch.from_numpy(z_np).float().cpu()
        imgTarget = torch.from_numpy(data_np).float().cpu()
        print(cur_z.shape, batch_idx)

        def nimg(img):
            #img = img + 1
            #img = img / 2
            img = torch.clamp(img, 0, 1)
            return img

        imgInput = batch[0]
        imgLabels = batch[1]
        self.logger.experiment.add_image('imgInput',
                                         torchvision.utils.make_grid(nimg(imgInput)), self._step)
        imgOutput = self.forward(cur_z, imgLabels)
        self.logger.experiment.add_image('imgOutput',
                                         torchvision.utils.make_grid(nimg(imgOutput)), self._step)

        if self._step % me.interp_every == 0 and self._step > -1 and True:
            cur_zgrid = torch.from_numpy(z_grid).float().cpu()
            imgGrid = self.forward(cur_zgrid, imgLabels)
            self.logger.experiment.add_image('imgGrid',
                                             torchvision.utils.make_grid(nimg(imgGrid)), self._step)
            #print('circle...')
            imgs = me.circle_interpolation(self.model, imgInput.shape[0])
            #imgs = np.array(imgs)
            #imgs = torch.from_numpy(imgs).float().cpu()
            self.logger.experiment.add_image('imgInterp',
                                             torchvision.utils.make_grid(nimg(imgs)), self._step)

        if self.squeezeNet is not None:
            #print('squeezeNet(imgTarget)...')
            activationsTarget = self.squeezeNet(imgTarget.repeat(1, 3, 1, 1))
            #print('squeezeNet(imgOutput)...')
            activationsOutput = self.squeezeNet(imgOutput.repeat(1, 3, 1, 1))
            #print('loss...')
            featLoss = None
            #for actTarget, actOutput in zip(activationsTarget[1:3], activationsOutput[1:3]):
            #for actTarget, actOutput in tqdm.tqdm(list(zip(activationsTarget, activationsOutput))):
            for actTarget, actOutput in zip(activationsTarget, activationsOutput):
                #l = F.mse_loss(actTarget, actOutput)
                #l = torch.abs(actTarget - actOutput).sum()
                #l = F.l1_loss(actTarget, actOutput)
                l = -pearsonr(actTarget.view(-1), actOutput.view(-1))
                if featLoss is None:
                    featLoss = l
                else:
                    featLoss += l
        else:
            featLoss = 0.0

        #pixelLoss = F.mse_loss(imgOutput, imgTarget)
        pixelLoss = self.loss_fn(imgOutput, imgTarget)
        loss = featLoss + pixelLoss
        lr = self.lr_fn(self.current_epoch)
        tensorboard_logs = {'train_loss': loss, 'lr': lr, 'epoch': self.current_epoch}
        self._step += 1
        return {'loss': loss, 'log': tensorboard_logs}

    def number_interpolation(self, count, lo, hi):
        hyperparams = self.hyperparams
        batch_size = hyperparams.batch_size

        def gen_latent(pos):
            z = gen_func(pos)
            z = z.reshape([1, -1, 1, 1])
            return z

    # def validation_step(self, batch, batch_idx):
    #     # OPTIONAL
    #     x, y = batch
    #     y_hat = self.forward(x)
    #     return {'val_loss': F.cross_entropy(y_hat, y)}

    # def validation_end(self, outputs):
    #     # OPTIONAL
    #     avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    #     tensorboard_logs = {'val_loss': avg_loss}
    #     return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    # def test_step(self, batch, batch_idx):
    #     # OPTIONAL
    #     x, y = batch
    #     y_hat = self.forward(x)
    #     return {'test_loss': F.cross_entropy(y_hat, y)}

    # def test_end(self, outputs):
    #     # OPTIONAL
    #     avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
    #     tensorboard_logs = {'test_loss': avg_loss}
    #     return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning-rate schedulers
        # (LBFGS is automatically supported; no need for a closure function)
        #return torch.optim.Adam(self.parameters(), lr=0.0004)
        #epoch = 0
        hyperparams = self.hyperparams
        lr = hyperparams.base_lr  # * hyperparams.decay_rate ** (epoch // hyperparams.decay_step)
        self.lr_fn = lambda epoch: hyperparams.decay_rate ** (epoch // hyperparams.decay_step)
        #optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
        #optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, amsgrad=True)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.5, 0.999),
                                     amsgrad=True, weight_decay=1e-5)
        #optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.5, 0.999), amsgrad=True)
        #optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.0, 0.999))
        from torch.optim.lr_scheduler import LambdaLR
        scheduler = LambdaLR(optimizer, lr_lambda=[self.lr_fn])
        return [optimizer], [scheduler]

    def prepare_data(self):
        #MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        #MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
        CIFAR10(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        CIFAR10(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
        #self.data_np = np.random.randn(128, 1, 28, 28)
        # self.data_np = np.array([img.numpy() for img, label in dataset])
        # print('ready')
        # data_np = self.data_np
        # hyperparams = self.hyperparams
        # if self.shuffle_data:
        #     data_ordering = np.random.permutation(data_np.shape[0])
        #     data_np = data_np[data_ordering]
        # batch_size = hyperparams.batch_size
        # num_batches = data_np.shape[0] // batch_size
        # num_samples = num_batches * hyperparams.num_samples_factor
        # self.data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

    def on_batch_start(self, batch):
        self.regen(batch)

    def train_dataloader(self):
        # REQUIRED
        #dataset = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
        #dataset = CIFAR10(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
        #loader = DataLoader(dataset, batch_size=32)

        # Downloading/loading CIFAR10 data
        trainset = CIFAR10(root=os.getcwd(), train=True, download=False)   #, transform=transform_with_aug)
        testset = CIFAR10(root=os.getcwd(), train=False, download=False)   #, transform=transform_no_aug)
        classDict = {'plane': 0, 'car': 1, 'bird': 2, 'cat': 3, 'deer': 4,
                     'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}

        # Separating trainset/testset data/label
        x_train = trainset.data
        x_test = testset.data
        y_train = trainset.targets
        y_test = testset.targets

        # Define a function to separate CIFAR classes by class index
        def get_class_i(x, y, i):
            """
            x: trainset.train_data or testset.test_data
            y: trainset.train_labels or testset.test_labels
            i: class label, a number between 0 and 9
            return: x_i
            """
            # Convert to a numpy array
            y = np.array(y)
            # Locate the positions of labels equal to i
            pos_i = np.argwhere(y == i)
            # Convert the result into a 1-D list
            pos_i = list(pos_i[:, 0])
            # Collect all data that match the desired label
            x_i = [x[j] for j in pos_i]
            return x_i

        class DatasetMaker(Dataset):
            def __init__(self, datasets, transformFunc=transforms.ToTensor()):
                """
                datasets: a list of get_class_i outputs, i.e. a list of lists of
                          images for the selected classes
                """
                self.datasets = datasets
                self.lengths = [len(d) for d in self.datasets]
                self.transformFunc = transformFunc

            def __getitem__(self, i):
                class_label, index_wrt_class = self.index_of_which_bin(self.lengths, i)
                img = self.datasets[class_label][index_wrt_class]
                if self.transformFunc:
                    img = self.transformFunc(img)
                return img, class_label

            def __len__(self):
                return sum(self.lengths)

            def index_of_which_bin(self, bin_sizes, absolute_index, verbose=False):
                """
                Given the absolute index, returns which bin it falls in and which
                element of that bin it corresponds to.
                """
                # Which class/bin does i fall into?
                accum = np.add.accumulate(bin_sizes)
                if verbose:
                    print("accum =", accum)
                bin_index = len(np.argwhere(accum <= absolute_index))
                if verbose:
                    print("class_label =", bin_index)
                # Which element of that class/bin does i correspond to?
                index_wrt_class = absolute_index - np.insert(accum, 0, 0)[bin_index]
                if verbose:
                    print("index_wrt_class =", index_wrt_class)
                return bin_index, index_wrt_class

        # ================== Usage ==================
        # Let's choose cats (class 3 of CIFAR) and dogs (class 5 of CIFAR) as trainset/testset
        cat_dog_trainset = DatasetMaker(
            [get_class_i(x_train, y_train, classDict['cat']),
             get_class_i(x_train, y_train, classDict['dog'])]
            #transform_with_aug
        )
        cat_dog_testset = DatasetMaker(
            [get_class_i(x_test, y_test, classDict['cat']),
             get_class_i(x_test, y_test, classDict['dog'])]
            #transform_no_aug
        )

        kwargs = {'num_workers': 2, 'pin_memory': False}
        hyperparams = self.hyperparams
        batch_size = hyperparams.batch_size

        # Create DataLoaders from trainset and testset
        trainsetLoader = DataLoader(cat_dog_trainset, batch_size=batch_size, shuffle=True, **kwargs)
        testsetLoader = DataLoader(cat_dog_testset, batch_size=batch_size, shuffle=False, **kwargs)
        #loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        # self.data_np = np.array([img.numpy() for img, label in dataset])
        #self._train_loader = loader
        #return loader
        return trainsetLoader
class IMLE():
    def __init__(self, z_dim):
        self.z_dim = z_dim
        self.model = ConvolutionalImplicitModel(z_dim).cuda()
        self.model2 = ConvolutionalImplicitModel(z_dim).cuda()
        self.dci_db = None
        self.dci_db2 = None

    #---------------------------------------------------------------------------
    def train(self, data_np, data_np2, base_lr=1e-3, batch_size=64, num_epochs=10000,
              decay_step=25, decay_rate=1.0, staleness=500, num_samples_factor=100):

        # define metric
        loss_fn = nn.MSELoss().cuda()

        # make the models trainable
        self.model.train()
        self.model2.train()

        # train in batches
        num_batches = data_np.shape[0] // batch_size

        # number of mock samples drawn per refresh
        num_samples = num_batches * num_samples_factor

        #-----------------------------------------------------------------------
        # flatten images to 1D vectors for DCI
        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        # initialize DCI
        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)
        if self.dci_db2 is None:
            self.dci_db2 = DCI(np.prod(data_np2.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        #-----------------------------------------------------------------------
        # train over epochs
        for epoch in range(num_epochs):

            # decay the learning rate
            if epoch % decay_step == 0:
                lr = base_lr * decay_rate ** (epoch // decay_step)
                optimizer = optim.Adam(list(self.model.parameters()) + list(self.model2.parameters()),
                                       lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)

            #-------------------------------------------------------------------
            # re-evaluate the closest samples routinely
            if epoch % staleness == 0:

                # initialize numpy arrays to store the latent draws and the associated samples
                z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1, 1))
                samples_np = np.empty((num_samples * batch_size,) + data_np.shape[1:])
                samples_np2 = np.empty((num_samples * batch_size,) + data_np2.shape[1:])

                # generate samples (in batches to avoid GPU memory problems)
                for i in range(num_samples):
                    z = torch.randn(batch_size, self.z_dim, 1, 1, 1).cuda()
                    samples = self.model(z)
                    samples2 = self.model2(z)
                    z_np[i*batch_size:(i+1)*batch_size] = z.cpu().data.numpy()
                    samples_np[i*batch_size:(i+1)*batch_size] = samples.cpu().data.numpy()
                    samples_np2[i*batch_size:(i+1)*batch_size] = samples2.cpu().data.numpy()

                # flatten to 1D images
                samples_flat_np = np.reshape(samples_np,
                                             (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                #---------------------------------------------------------------
                # find the nearest neighbours
                self.dci_db.reset()
                self.dci_db.add(np.copy(samples_flat_np),
                                num_levels=2, field_of_view=10, prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np, num_neighbours=1,
                                                       field_of_view=20, prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]
                z_np = z_np[nearest_indices]

                # add random noise to the latent space to facilitate training
                z_np += 0.01 * np.random.randn(*z_np.shape)

                # delete to save memory
                del samples_np, samples_np2, samples_flat_np

            #===================================================================
            # permute the data
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]
            data_np2 = data_np2[data_ordering]
            data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))
            z_np = z_np[data_ordering]

            #-------------------------------------------------------------------
            # gradient descent
            err = 0.

            # loop over all batches
            for i in range(num_batches):

                # set up backprop
                self.model.zero_grad()
                self.model2.zero_grad()

                #---------------------------------------------------------------
                # evaluate the models
                cur_z = torch.from_numpy(z_np[i*batch_size:(i+1)*batch_size]).float().cuda()
                cur_data = torch.from_numpy(data_np[i*batch_size:(i+1)*batch_size]).float().cuda()
                cur_data2 = torch.from_numpy(data_np2[i*batch_size:(i+1)*batch_size]).float().cuda()
                cur_samples = self.model(cur_z)
                cur_samples2 = self.model2(cur_z)

                #---------------------------------------------------------------
                # calculate the summed MSE loss of the two images
                loss = loss_fn(cur_samples, cur_data) + loss_fn(cur_samples2, cur_data2)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            # save the mock samples
            if (epoch + 1) % staleness == 0:
                np.savez("../results_3D.npz", data_np=data_np,
                         data_np2=data_np2,
                         samples_np=self.model(torch.from_numpy(z_np).float().cuda()).cpu().data.numpy(),
                         samples_np2=self.model2(torch.from_numpy(z_np).float().cuda()).cpu().data.numpy())
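
#-------------------------------------------------------------------------------
# Hypothetical driver for the two-field model above (not from the source). The
# 5-D latents imply 3-D convolutional outputs, so both arrays are assumed to be
# volumetric fields of matching length; the file names and shapes below are
# placeholders.
#-------------------------------------------------------------------------------
import numpy as np

data_np = np.load("field1.npy")    # assumed shape, e.g. (N, 1, 32, 32, 32)
data_np2 = np.load("field2.npy")   # assumed shape, e.g. (N, 1, 32, 32, 32)

imle = IMLE(z_dim=64)
imle.train(data_np, data_np2)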