class IMLE():
    """Conditional IMLE trainer for a generator fed the concatenation ``[z, Sx]``.

    Every training item is paired with ``num_samples_factor`` latent draws that
    share the item's conditioning vector.  The draw whose generated sample is
    closest to the item (according to the DCI nearest-neighbour index) is kept,
    and the generator is regressed onto the data through those kept inputs.
    """

    def __init__(self, z_dim, Sx_dim,
                 pretrained_path="../net_weights_2D_scattering_J=6_L=2_times=10.pth"):
        """Build the generator on the GPU and optionally warm-start it.

        Args:
            z_dim: number of random latent channels.
            Sx_dim: number of conditioning channels concatenated after ``z``.
            pretrained_path: checkpoint loaded into the generator after it has
                been initialised.  The default is the path that used to be
                hard-coded here; pass ``None`` to keep the fresh initialisation.
        """
        self.z_dim = z_dim
        self.Sx_dim = Sx_dim
        self.model = ConvolutionalImplicitModel(z_dim + Sx_dim, 0.5).cuda()
        self.model.apply(self.model.get_initializer())
        self.dci_db = None

        # load pre-trained model
        if pretrained_path is not None:
            state_dict = torch.load(pretrained_path)
            self.model.load_state_dict(state_dict)

    def train(self, data_np, data_Sx, name_JL, base_lr=1e-5, batch_size=128, num_epochs=3000,
              decay_step=25, decay_rate=0.98, staleness=100, num_samples_factor=100):
        """Fit the generator to ``data_np`` conditioned on ``data_Sx``.

        Args:
            data_np: training targets, first axis indexes the items.
            data_Sx: conditioning array, one row per item of ``data_np``.
                NOTE(review): it is concatenated with a ``(N, z_dim, 1, 1)``
                tensor along axis 1, so it presumably has to be 4-D with
                trailing singleton axes -- confirm against the caller.
            name_JL: tag inserted into the output file names.
            base_lr: learning rate at epoch 0.
            batch_size: mini-batch size; trailing items that do not fill a
                batch are dropped.
            num_epochs: number of passes over the data.
            decay_step: the optimizer is rebuilt with a decayed learning rate
                every ``decay_step`` epochs.
            decay_rate: multiplicative learning-rate decay per ``decay_step``.
            staleness: the latent/data matching is refreshed every
                ``staleness`` epochs; results are saved on the epoch just
                before each refresh.
            num_samples_factor: number of candidate latent draws per item.

        Raises:
            ValueError: if ``data_np`` holds fewer than ``batch_size`` items.
        """
        # define metric
        loss_fn = nn.MSELoss().cuda()
        self.model.train()

        # train in batch
        num_batches = data_np.shape[0] // batch_size
        if num_batches == 0:
            # previously surfaced much later as a ZeroDivisionError
            raise ValueError("data_np must contain at least batch_size samples")

        # truncate data to fit the batch size
        num_data = num_batches * batch_size
        data_np = data_np[:num_data]
        data_Sx = data_Sx[:num_data]

        # make it in 1D data image for DCI
        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        # scratch buffer holding the candidates of one item at a time
        samples_np = np.empty((num_samples_factor,) + data_np.shape[1:])

        # make global torch variables
        data_all = torch.from_numpy(data_np).float().cuda()
        Sx = torch.from_numpy(np.repeat(data_Sx, num_samples_factor, axis=0)).float().cuda()

        # initiate dci
        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        # train through various epochs
        for epoch in range(num_epochs):

            # decay the learning rate (this also resets the Adam statistics,
            # exactly as the original schedule did)
            if epoch % decay_step == 0:
                lr = base_lr * decay_rate ** (epoch // decay_step)
                optimizer = optim.Adam(self.model.parameters(), lr=lr,
                                       betas=(0.5, 0.999), weight_decay=1e-5)

            # update the closest models routinely
            if epoch % staleness == 0:

                # draw random z
                z = torch.randn(num_data * num_samples_factor, self.z_dim, 1, 1).cuda()
                z_Sx_all = torch.cat((z, Sx), dim=1)

                # find the closest object for individual data
                nearest_indices = np.empty(num_data, dtype="int")
                for i in range(num_data):
                    lo = i * num_samples_factor
                    hi = lo + num_samples_factor

                    # the candidates are only compared, never back-propagated
                    # through, so no autograd graph is needed here
                    with torch.no_grad():
                        samples = self.model(z_Sx_all[lo:hi])
                    samples_np[:] = samples.detach().cpu().numpy()
                    samples_flat_np = np.reshape(
                        samples_np, (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                    # find the nearest neighbours
                    self.dci_db.reset()
                    self.dci_db.add(np.copy(samples_flat_np),
                                    num_levels=2, field_of_view=10, prop_to_retrieve=0.002)
                    nearest_indices_temp, _ = self.dci_db.query(
                        data_flat_np[i:i + 1],
                        num_neighbours=1, field_of_view=20, prop_to_retrieve=0.02)
                    # convert the per-item index into an index into z_Sx_all
                    nearest_indices[i] = nearest_indices_temp[0][0] + lo

                # one-off sanity print: quartiles of the data and of the
                # candidates generated for the last item
                if epoch == 0:
                    print(np.percentile(data_flat_np, 25), np.percentile(data_flat_np, 50),
                          np.percentile(data_flat_np, 75))
                    print(np.percentile(samples_flat_np, 25), np.percentile(samples_flat_np, 50),
                          np.percentile(samples_flat_np, 75))

                # restrict latent parameters to the nearest neighbour
                z_Sx = z_Sx_all[nearest_indices]

            # gradient descent
            err = 0.

            # loop over all batches
            for i in range(num_batches):
                batch = slice(i * batch_size, (i + 1) * batch_size)
                self.model.zero_grad()
                cur_samples = self.model(z_Sx[batch])

                loss = loss_fn(cur_samples, data_all[batch])
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            # save the mock sample
            if (epoch + 1) % staleness == 0:

                # make random mock: every 100th of the first 10**4 candidate inputs
                with torch.no_grad():
                    samples_random = self.model(z_Sx_all[:10**4][::100]).cpu().numpy()
                np.savez("../results_2D_random_times=10_" + name_JL + "_epoch=" + str(epoch) + ".npz",
                         samples_np=samples_random,
                         mse_err=err / num_batches)

                # save network
                torch.save(self.model.state_dict(),
                           '../net_weights_2D_times=10_' + name_JL + '_epoch='
                           + str(epoch) + '.pth')
class IMLE():
    """IMLE trainer for a TensorFlow / tensorlayer generator built by ``get_model``."""

    def __init__(self, z_dim):
        """Remember the latent size; the DCI index is created lazily in ``train``."""
        self.z_dim = z_dim
        self.dci_db = None

    def train(self, data_np, hyperparams, shuffle_data=True):
        """Train a freshly built generator on ``data_np``.

        Args:
            data_np: training data, first axis indexes the items.
            hyperparams: object providing ``batch_size``, ``num_samples_factor``,
                ``num_epochs``, ``decay_step``, ``decay_rate``, ``base_lr`` and
                ``staleness``.
            shuffle_data: shuffle the items once before training.

        Raises:
            ValueError: if ``data_np`` holds fewer than ``batch_size`` items.
        """
        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        if num_batches == 0:
            # previously surfaced as a ZeroDivisionError when printing the error
            raise ValueError("data_np must contain at least batch_size samples")
        num_samples = num_batches * hyperparams.num_samples_factor

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]

        data_flat_np = np.reshape(
            data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        # NOTE(review): the leading 64 is hard-coded and independent of
        # batch_size -- confirm what get_model expects here.
        input_shape = [64, 1, 1, self.z_dim]
        net = get_model(input_shape)
        train_weights = net.trainable_weights

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        def batch_tensors(i):
            """Return the (latent, data) tensors of mini-batch ``i``."""
            rows = slice(i * batch_size, (i + 1) * batch_size)
            return (tf.convert_to_tensor(z_np[rows], dtype=tf.float32),
                    tf.convert_to_tensor(data_np[rows], dtype=tf.float32))

        for epoch in range(hyperparams.num_epochs):

            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(
                    epoch // hyperparams.decay_step)
                # Optimizer API differs from the PyTorch original:
                # betas -> beta_1, beta_2 and no weight_decay=1e-5.
                optimizer = tf.optimizers.Adam(learning_rate=lr,
                                               beta_1=0.5,
                                               beta_2=0.999)

            # refresh the latent/data matching
            if epoch % hyperparams.staleness == 0:
                net.eval()  # candidates are generated in eval mode
                z_np = np.empty((num_samples * batch_size, 1, 1, self.z_dim))
                samples_np = np.empty((num_samples * batch_size, ) +
                                      data_np.shape[1:])
                for i in range(num_samples):
                    rows = slice(i * batch_size, (i + 1) * batch_size)
                    z = tf.random.normal([batch_size, 1, 1, self.z_dim])
                    samples = net(z)
                    z_np[rows] = z.numpy()
                    samples_np[rows] = samples.numpy()

                samples_flat_np = np.reshape(
                    samples_np,
                    (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np,
                                num_levels=2,
                                field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np,
                                                       num_neighbours=1,
                                                       field_of_view=20,
                                                       prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                # keep, for every data item, the latent of its nearest
                # candidate, slightly jittered
                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)

                del samples_np, samples_flat_np

            # report the error in eval mode before this epoch's update
            err = 0.
            net.eval()
            for i in range(num_batches):
                cur_z, cur_data = batch_tensors(i)
                cur_samples = net(cur_z)
                _loss = tl.cost.mean_squared_error(cur_data,
                                                   cur_samples,
                                                   is_mean=False)
                err += float(_loss)

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            # gradient step on every mini-batch
            net.train()
            for i in range(num_batches):
                with tf.GradientTape() as tape:
                    cur_z, cur_data = batch_tensors(i)
                    cur_samples = net(cur_z)
                    _loss = tl.cost.mean_squared_error(cur_data,
                                                       cur_samples,
                                                       is_mean=False)
                grad = tape.gradient(_loss, train_weights)
                optimizer.apply_gradients(zip(grad, train_weights))

            # Debug visualisation of random samples.  It slows training down
            # and can be deleted.
            # NOTE(review): the source this was recovered from had lost its
            # indentation; running this once per epoch is an assumption.
            cur_z = tf.random.normal([batch_size, 1, 1, self.z_dim])
            cur_samples = net(cur_z)
            pics = np.reshape(cur_samples.numpy(), (batch_size, 28, 28))
            save_img(2, pics)

        # final batch of random samples after training
        cur_z = tf.random.normal([batch_size, 1, 1, self.z_dim])
        cur_samples = net(cur_z)
        pics = np.reshape(cur_samples.numpy(), (batch_size, 28, 28))
        save_img(batch_size, pics)
class IMLE():
    """IMLE trainer around a class-conditional ``BigGAN.Generator`` kept on the CPU."""

    def __init__(self, z_dim):
        """Create the generator; ``Resolution`` and ``Classes`` are module-level settings."""
        self.z_dim = z_dim
        self.model = BigGAN.Generator(resolution=Resolution,
                                      dim_z=z_dim,
                                      n_classes=Classes).cpu()
        self.dci_db = None

    def train(self, data_np, label_np, hyperparams, shuffle_data=True, path='results'):
        """Train the generator with IMLE and periodically dump weights and a sample grid.

        Args:
            data_np: training images, first axis indexes the items.
            label_np: labels aligned with ``data_np``.  NOTE(review): they are
                shuffled together with the data but not used afterwards; the
                candidates are generated with random labels.
            hyperparams: object providing ``batch_size``, ``num_samples_factor``,
                ``num_epochs``, ``decay_step``, ``decay_rate``, ``base_lr`` and
                ``staleness``.
            shuffle_data: shuffle data and labels once before training.
            path: directory receiving ``net_weights.pth`` and the sample grids;
                it is created if missing.

        Raises:
            ValueError: if ``data_np`` holds fewer than ``batch_size`` items.
        """
        loss_fn = nn.MSELoss().cpu()
        self.model.train()

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        if num_batches == 0:
            # previously surfaced as a ZeroDivisionError when printing the error
            raise ValueError("data_np must contain at least batch_size samples")
        num_samples = num_batches * hyperparams.num_samples_factor

        # torch.save below does not create directories
        os.makedirs(path, exist_ok=True)

        # fixed latents/labels so the saved grids are comparable across epochs
        grid_size = (5, 5)
        grid_z = torch.randn(np.prod(grid_size), self.z_dim).cpu()
        grid_y = torch.randint(low=0,
                               high=Classes,
                               size=(np.prod(grid_size), ),
                               dtype=torch.int64).cpu()

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]
            label_np = label_np[data_ordering]

        data_flat_np = np.reshape(
            data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):

            if epoch % 10 == 0:
                print('Saving net_weights.pth...')
                torch.save(self.model.state_dict(),
                           os.path.join(path, 'net_weights.pth'))
            if epoch % 5 == 0:
                print('Saving grid...')
                save_grid(path=path,
                          index=epoch,
                          count=np.prod(grid_size),
                          imle=self,
                          z=grid_z,
                          y=grid_y,
                          z_dim=self.z_dim)
                print('Saved')

            # rebuild the optimizer with the decayed learning rate
            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(
                    epoch // hyperparams.decay_step)
                optimizer = optim.Adam(self.model.parameters(),
                                       lr=lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=1e-5)

            # refresh the latent/data matching
            if epoch % hyperparams.staleness == 0:
                z_np = np.empty((num_samples * batch_size, self.z_dim))
                y_np = np.empty((num_samples * batch_size, 1), dtype=np.int64)
                samples_np = np.empty((num_samples * batch_size, ) +
                                      data_np.shape[1:])
                for i in range(num_samples):
                    rows = slice(i * batch_size, (i + 1) * batch_size)
                    z = torch.randn(batch_size, self.z_dim)
                    y = torch.randint(low=0,
                                      high=Classes,
                                      size=(batch_size, 1),
                                      dtype=torch.int64)
                    # candidates are only compared against the data, so no
                    # autograd graph is needed for them
                    with torch.no_grad():
                        samples = self.model(z, self.model.shared(y))
                    z_np[rows] = z.numpy()
                    y_np[rows] = y.numpy()
                    samples_np[rows] = samples.numpy()

                samples_flat_np = np.reshape(
                    samples_np,
                    (samples_np.shape[0], np.prod(samples_np.shape[1:]))).copy()

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np,
                                num_levels=2,
                                field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np,
                                                       num_neighbours=1,
                                                       field_of_view=20,
                                                       prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                # keep latent (jittered) and label of each item's nearest candidate
                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)
                y_np = y_np[nearest_indices]

                del samples_np, samples_flat_np

            err = 0.
            for i in range(num_batches):
                rows = slice(i * batch_size, (i + 1) * batch_size)
                self.model.zero_grad()
                cur_z = torch.from_numpy(z_np[rows]).float()
                cur_y = torch.from_numpy(y_np[rows]).long()
                cur_data = torch.from_numpy(data_np[rows]).float()
                cur_samples = self.model(cur_z, self.model.shared(cur_y))
                loss = loss_fn(cur_samples, cur_data)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))
class IMLE():
    """Conditional IMLE trainer (L1 loss) for a generator fed ``[z, Sx]``."""

    def __init__(self, z_dim, Sx_dim):
        """Build the generator on the GPU.

        Args:
            z_dim: number of random latent entries.
            Sx_dim: number of conditioning entries concatenated after ``z``.
        """
        self.z_dim = z_dim
        self.Sx_dim = Sx_dim
        self.model = ConvolutionalImplicitModel(z_dim + Sx_dim).cuda()
        self.dci_db = None

    def train(self, data_np, data_Sx, base_lr=1e-4, batch_size=256, num_epochs=3000,
              decay_step=25, decay_rate=0.95, staleness=100, num_samples_factor=100):
        """Fit the generator to ``data_np`` conditioned on ``data_Sx``.

        Args:
            data_np: training targets, first axis indexes the items.  They are
                handed to the DCI index without flattening.
            data_Sx: conditioning array, one row per item of ``data_np``.
            base_lr: learning rate at epoch 0.
            batch_size: mini-batch size; trailing items that do not fill a
                batch are dropped.
            num_epochs: number of passes over the data.
            decay_step: the optimizer is rebuilt with a decayed learning rate
                every ``decay_step`` epochs.
            decay_rate: multiplicative learning-rate decay per ``decay_step``.
            staleness: the latent/data matching is refreshed every
                ``staleness`` epochs; results are saved on the epoch just
                before each refresh.
            num_samples_factor: number of candidate latent draws per item.

        Raises:
            ValueError: if ``data_np`` holds fewer than ``batch_size`` items.
        """
        # define metric
        loss_fn = nn.L1Loss().cuda()
        self.model.train()

        # train in batch
        num_batches = data_np.shape[0] // batch_size
        if num_batches == 0:
            # previously surfaced as a ZeroDivisionError when printing the error
            raise ValueError("data_np must contain at least batch_size samples")

        # truncate data to fit the batch size
        num_data = num_batches * batch_size
        data_np = data_np[:num_data]
        data_Sx = data_Sx[:num_data]

        # buffers: reconstructions of the whole set, candidates of one item
        samples_predict = np.empty(data_np.shape)
        samples_np = np.empty((num_samples_factor, ) + data_np.shape[1:])

        # make global torch variables
        data_all = torch.from_numpy(data_np).float().cuda()
        Sx = torch.from_numpy(np.repeat(data_Sx, num_samples_factor,
                                        axis=0)).float().cuda()

        # initiate dci
        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        # train through various epochs
        for epoch in range(num_epochs):

            # decay the learning rate
            if epoch % decay_step == 0:
                lr = base_lr * decay_rate**(epoch // decay_step)
                optimizer = optim.Adam(self.model.parameters(),
                                       lr=lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=1e-5)

            # update the closest models routinely
            if epoch % staleness == 0:

                # draw random z and append a trailing singleton axis
                z = torch.randn(num_data * num_samples_factor, self.z_dim).cuda()
                z_Sx_all = torch.cat((z, Sx), dim=1)[:, :, None]

                # find the closest object for individual data
                nearest_indices = np.empty(num_data, dtype="int")
                for i in range(num_data):
                    lo = i * num_samples_factor
                    hi = lo + num_samples_factor

                    # candidates are only compared against the data, so no
                    # autograd graph is needed for them
                    with torch.no_grad():
                        samples = self.model(z_Sx_all[lo:hi])
                    samples_np[:] = samples.detach().cpu().numpy()

                    # find the nearest neighbours
                    self.dci_db.reset()
                    self.dci_db.add(np.copy(samples_np),
                                    num_levels=2, field_of_view=10, prop_to_retrieve=0.002)
                    nearest_indices_temp, _ = self.dci_db.query(
                        data_np[i:i + 1],
                        num_neighbours=1, field_of_view=20, prop_to_retrieve=0.02)
                    # convert the per-item index into an index into z_Sx_all
                    nearest_indices[i] = nearest_indices_temp[0][0] + lo

                # restrict latent parameters to the nearest neighbour
                z_Sx = z_Sx_all[nearest_indices]

            # gradient descent
            err = 0.
            save_now = (epoch + 1) % staleness == 0

            # loop over all batches
            for i in range(num_batches):
                rows = slice(i * batch_size, (i + 1) * batch_size)
                self.model.zero_grad()
                cur_samples = self.model(z_Sx[rows])

                # save the mock sample
                if save_now:
                    samples_predict[rows] = cur_samples.detach().cpu().numpy()

                loss = loss_fn(cur_samples, data_all[rows])
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            if save_now:

                # save closest models
                np.savez("../results_spectra_deconv_256x2_" + str(epoch) + ".npz",
                         data_np=data_np,
                         z_Sx_np=z_Sx.cpu().data.numpy(),
                         samples_np=samples_predict)

                # the key is called mse_err although the metric is the L1 loss
                np.savez("../mse_err_deconv_256x2_" + str(epoch) + ".npz",
                         mse_err=err / num_batches)

                # save network
                torch.save(
                    self.model.state_dict(),
                    '../net_weights_spectra_deconv_256x2_epoch=' + str(epoch) +
                    '.pth')
class IMLE():
    """Unconditional IMLE trainer for a ``ConvolutionalImplicitModel`` on the GPU."""

    def __init__(self, z_dim):
        """Build the generator; the DCI index is created lazily in ``train``."""
        self.z_dim = z_dim
        self.model = ConvolutionalImplicitModel(z_dim).cuda()
        self.dci_db = None

    def train(self, data_np, hyperparams, shuffle_data=True):
        """Train the generator on ``data_np`` with IMLE.

        Every ``hyperparams.staleness`` epochs a pool of samples is generated,
        each data item is matched to its nearest sample, and the matched
        latents (slightly jittered) become the regression inputs for the
        following epochs.

        Args:
            data_np: training data, first axis indexes the items.
            hyperparams: object providing ``batch_size``, ``num_samples_factor``,
                ``num_epochs``, ``decay_step``, ``decay_rate``, ``base_lr`` and
                ``staleness``.
            shuffle_data: shuffle the items once before training.

        Raises:
            ValueError: if ``data_np`` holds fewer than ``batch_size`` items.
        """
        loss_fn = nn.MSELoss().cuda()
        self.model.train()

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        if num_batches == 0:
            # previously surfaced as a ZeroDivisionError when printing the error
            raise ValueError("data_np must contain at least batch_size samples")
        num_samples = num_batches * hyperparams.num_samples_factor

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]

        data_flat_np = np.reshape(
            data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):

            # rebuild the optimizer with the decayed learning rate
            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(
                    epoch // hyperparams.decay_step)
                optimizer = optim.Adam(self.model.parameters(),
                                       lr=lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=1e-5)

            # refresh the latent/data matching
            if epoch % hyperparams.staleness == 0:
                z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1))
                samples_np = np.empty((num_samples * batch_size, ) +
                                      data_np.shape[1:])
                for i in range(num_samples):
                    rows = slice(i * batch_size, (i + 1) * batch_size)
                    z = torch.randn(batch_size, self.z_dim, 1, 1).cuda()
                    # the pool is only searched, never back-propagated
                    # through, so no autograd graph is needed
                    with torch.no_grad():
                        samples = self.model(z)
                    z_np[rows] = z.cpu().numpy()
                    samples_np[rows] = samples.cpu().numpy()

                samples_flat_np = np.reshape(
                    samples_np,
                    (samples_np.shape[0], np.prod(samples_np.shape[1:])))

                self.dci_db.reset()
                self.dci_db.add(samples_flat_np,
                                num_levels=2,
                                field_of_view=10,
                                prop_to_retrieve=0.002)
                nearest_indices, _ = self.dci_db.query(data_flat_np,
                                                       num_neighbours=1,
                                                       field_of_view=20,
                                                       prop_to_retrieve=0.02)
                nearest_indices = np.array(nearest_indices)[:, 0]

                # one latent per data item, slightly jittered
                z_np = z_np[nearest_indices]
                z_np += 0.01 * np.random.randn(*z_np.shape)

                del samples_np, samples_flat_np

            err = 0.
            for i in range(num_batches):
                rows = slice(i * batch_size, (i + 1) * batch_size)
                self.model.zero_grad()
                cur_z = torch.from_numpy(z_np[rows]).float().cuda()
                cur_data = torch.from_numpy(data_np[rows]).float().cuda()
                cur_samples = self.model(cur_z)
                loss = loss_fn(cur_samples, cur_data)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))
class IMLE():
    """Unconditional IMLE trainer, annotated with notes on how the method works."""

    def __init__(self, z_dim):
        """Build the generator on the GPU; the DCI index is created lazily."""
        self.z_dim = z_dim
        self.model = ConvolutionalImplicitModel(z_dim).cuda()
        self.dci_db = None

    def _match_latents(self, data_flat_np, sample_shape, num_samples, batch_size):
        """Return, for every data row, the jittered latent of its nearest generated sample."""
        pool_size = num_samples * batch_size
        z_pool = np.empty((pool_size, self.z_dim, 1, 1))
        sample_pool = np.empty((pool_size, ) + tuple(sample_shape))

        # Generate the pool batch by batch.  The pool is only searched, never
        # back-propagated through, so autograd is switched off.
        with torch.no_grad():
            for k in range(num_samples):
                rows = slice(k * batch_size, (k + 1) * batch_size)
                z = torch.randn(batch_size, self.z_dim, 1, 1).cuda()
                z_pool[rows] = z.cpu().numpy()
                sample_pool[rows] = self.model(z).cpu().numpy()

        flat_pool = np.reshape(
            sample_pool, (sample_pool.shape[0], np.prod(sample_pool.shape[1:])))

        self.dci_db.reset()
        self.dci_db.add(flat_pool,
                        num_levels=2,
                        field_of_view=10,
                        prop_to_retrieve=0.002)
        nearest_indices, _ = self.dci_db.query(data_flat_np,
                                               num_neighbours=1,
                                               field_of_view=20,
                                               prop_to_retrieve=0.02)
        nearest_indices = np.array(nearest_indices)[:, 0]

        matched = z_pool[nearest_indices]
        matched += 0.01 * np.random.randn(*matched.shape)
        return matched

    def train(self, data_np, hyperparams, shuffle_data=True):
        """Train the generator on ``data_np`` with IMLE.

        Args:
            data_np: training data, first axis indexes the items.
            hyperparams: object providing ``batch_size``, ``num_samples_factor``,
                ``num_epochs``, ``decay_step``, ``decay_rate``, ``base_lr`` and
                ``staleness``.
            shuffle_data: shuffle the items once before training.

        Raises:
            ValueError: if ``data_np`` holds fewer than ``batch_size`` items.
        """
        loss_fn = nn.MSELoss().cuda()
        self.model.train()

        batch_size = hyperparams.batch_size
        num_batches = data_np.shape[0] // batch_size
        if num_batches == 0:
            # previously surfaced as a ZeroDivisionError when printing the error
            raise ValueError("data_np must contain at least batch_size samples")
        # number of generated sample batches per refresh
        num_samples = num_batches * hyperparams.num_samples_factor

        if shuffle_data:
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]

        data_flat_np = np.reshape(
            data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2,
                              num_simp_indices=7)

        for epoch in range(hyperparams.num_epochs):

            # rebuild the optimizer with the decayed learning rate
            if epoch % hyperparams.decay_step == 0:
                lr = hyperparams.base_lr * hyperparams.decay_rate**(
                    epoch // hyperparams.decay_step)
                optimizer = optim.Adam(self.model.parameters(),
                                       lr=lr,
                                       betas=(0.5, 0.999),
                                       weight_decay=1e-5)

            # Data generation step.
            # NOTE(review): open question from an earlier reader -- the data
            # is never sub-sampled here; every refresh queries the index with
            # the complete data_np.
            if epoch % hyperparams.staleness == 0:
                # z_np holds, per data point, the latent whose generated
                # sample is closest to that data point.
                #
                # Why do it this way?  We want to compare each real point x_i
                # with the closest generated point, call it
                # hat{x}_{sigma(i)}, which was produced from some noise
                # vector as G_theta(z_i).  Once that point is known, G is
                # moved so that hat{x}_{sigma(i)} gets closer to x_i, i.e. we
                # minimise || G_theta(z_i) - x_i ||.
                #
                # In the conditional case, conditioning on a random variable
                # A, the idea is to generate m samples for every a in A and
                # keep, out of those m, the one closest to the data point.
                z_np = self._match_latents(data_flat_np, data_np.shape[1:],
                                           num_samples, batch_size)

            # batch-gradient descent
            err = 0.
            for i in range(num_batches):
                rows = slice(i * batch_size, (i + 1) * batch_size)
                self.model.zero_grad()
                cur_z = torch.from_numpy(z_np[rows]).float().cuda()
                cur_data = torch.from_numpy(data_np[rows]).float().cuda()
                cur_samples = self.model(cur_z)
                loss = loss_fn(cur_samples, cur_data)
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))
class CoolSystem(pl.LightningModule):
    """Lightning module that trains a generator with per-batch IMLE matching.

    Before every training batch ``regen`` draws candidate latents, matches
    each image of the batch to its nearest generated sample and stores the
    matched latents; ``training_step`` then regresses the generator output
    for those latents onto the batch images.
    """

    def __init__(self, z_dim, hyperparams, shuffle_data=True):
        """Store the configuration and build the generator.

        Args:
            z_dim: number of latent channels.
            hyperparams: object providing ``batch_size``, ``num_samples_factor``,
                ``base_lr``, ``decay_rate`` and ``decay_step``.
            shuffle_data: stored on the instance; not read inside this class.
        """
        super(CoolSystem, self).__init__()
        self.z_dim = z_dim
        self.z_grid = None          # fixed latents for the 'imgGrid' preview
        self.hyperparams = hyperparams
        self.shuffle_data = shuffle_data
        self.dci_db = None          # created on the first regen() call
        self.model = ConvolutionalImplicitModel(z_dim)
        self.squeezeNet = None      # optional feature extractor for a feature loss
        self.loss_fn = nn.MSELoss()
        self._step = 0              # global step used for image logging

    def regen(self, batch):
        """Match every image of ``batch`` to a latent and cache the result.

        Sets ``self.data_np``, ``self.data_flat_np`` and ``self.z_np`` (and
        ``self.z_grid`` on the first call).

        Args:
            batch: ``(imgs, labels)`` as produced by the training loader.

        Returns:
            The matched, jittered latents, one per image of the batch.
        """
        imgs, labels = batch
        data_np = imgs.numpy()
        data_flat_np = np.reshape(data_np, (data_np.shape[0], np.prod(data_np.shape[1:])))
        self.data_np = data_np
        self.data_flat_np = data_flat_np

        batch_size = self.hyperparams.batch_size
        # only the current batch is matched, hence a single "batch" of data
        num_samples = 1 * self.hyperparams.num_samples_factor

        z_np = np.empty((num_samples * batch_size, self.z_dim, 1, 1))
        samples_np = np.empty((num_samples * batch_size,) + data_np.shape[1:])

        # NOTE(review): z always has hyperparams.batch_size rows while labels
        # has one entry per image actually in the batch; a short final batch
        # may therefore mismatch if the model uses the labels -- confirm.
        # Candidates are only searched, so no autograd graph is needed.
        with torch.no_grad():
            for i in range(num_samples):
                rows = slice(i * batch_size, (i + 1) * batch_size)
                z = torch.randn(batch_size, self.z_dim, 1, 1).cpu()
                samples = self.model(z, labels)
                z_np[rows] = z.cpu().numpy()
                samples_np[rows] = samples.cpu().numpy()

        samples_flat_np = np.reshape(
            samples_np, (samples_np.shape[0], np.prod(samples_np.shape[1:]))).copy()

        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]), num_comp_indices=2, num_simp_indices=7)

        self.dci_db.reset()
        self.dci_db.add(samples_flat_np, num_levels=2, field_of_view=10, prop_to_retrieve=0.002)
        nearest_indices, _ = self.dci_db.query(data_flat_np, num_neighbours=1,
                                               field_of_view=20, prop_to_retrieve=0.02)
        nearest_indices = np.array(nearest_indices)[:, 0]

        # one latent per image, slightly jittered
        z_np = z_np[nearest_indices]
        z_np += 0.01 * np.random.randn(*z_np.shape)
        self.z_np = z_np

        # freeze a scaled copy of the first matched latents for previews
        if self.z_grid is None:
            self.z_grid = self.z_np.copy() * 0.7

        return z_np

    def forward(self, z, class_id=None):
        """Run the generator on latents ``z`` with optional ``class_id``."""
        cur_samples = self.model(z, class_id)
        return cur_samples

    def training_step(self, batch, batch_idx):
        """Compute the IMLE regression loss for the batch prepared by ``regen``.

        Args:
            batch: ``(imgs, labels)``; must be the batch last given to ``regen``.
            batch_idx: index of the batch within the epoch (only printed).

        Returns:
            Dict with ``'loss'`` and a ``'log'`` dict for TensorBoard.
        """
        hyperparams = self.hyperparams
        z_grid = self.z_grid

        cur_z = torch.from_numpy(self.z_np).float().cpu()
        imgTarget = torch.from_numpy(self.data_np).float().cpu()
        print(cur_z.shape, batch_idx)

        def nimg(img):
            """Clamp an image tensor into [0, 1] for display."""
            return torch.clamp(img, 0, 1)

        imgInput = batch[0]
        imgLabels = batch[1]
        self.logger.experiment.add_image(
            'imgInput', torchvision.utils.make_grid(nimg(imgInput)), self._step)
        imgOutput = self.forward(cur_z, imgLabels)
        self.logger.experiment.add_image(
            'imgOutput', torchvision.utils.make_grid(nimg(imgOutput)), self._step)

        # periodic previews: fixed-latent grid and a latent interpolation
        if self._step % me.interp_every == 0:
            cur_zgrid = torch.from_numpy(z_grid).float().cpu()
            imgGrid = self.forward(cur_zgrid, imgLabels)
            self.logger.experiment.add_image(
                'imgGrid', torchvision.utils.make_grid(nimg(imgGrid)), self._step)
            imgs = me.circle_interpolation(self.model, imgInput.shape[0])
            self.logger.experiment.add_image(
                'imgInterp', torchvision.utils.make_grid(nimg(imgs)), self._step)

        if self.squeezeNet is not None:
            # feature loss: negative correlation between the activations of
            # target and output (grey images repeated to three channels)
            activationsTarget = self.squeezeNet(imgTarget.repeat(1, 3, 1, 1))
            activationsOutput = self.squeezeNet(imgOutput.repeat(1, 3, 1, 1))
            # summed out-of-place so no tensor in the graph is modified in place
            featLoss = sum(
                -pearsonr(actTarget.view(-1), actOutput.view(-1))
                for actTarget, actOutput in zip(activationsTarget, activationsOutput))
        else:
            featLoss = 0.0

        pixelLoss = self.loss_fn(imgOutput, imgTarget)
        loss = featLoss + pixelLoss

        # lr_fn yields the LambdaLR multiplier; scale by base_lr so the logged
        # value is the learning rate itself
        lr = hyperparams.base_lr * self.lr_fn(self.current_epoch)

        tensorboard_logs = {'train_loss': loss, 'lr': lr, 'epoch': self.current_epoch}
        self._step += 1
        return {'loss': loss, 'log': tensorboard_logs}

    def number_interpolation(self, count, lo, hi):
        """Unfinished stub: defines a helper and returns ``None``.

        NOTE(review): ``gen_func`` is not defined in this class and the
        helper is never called -- complete or remove.
        """
        hyperparams = self.hyperparams
        batch_size = hyperparams.batch_size

        def gen_latent(pos):
            z = gen_func(pos)
            z = z.reshape([1, -1, 1, 1])
            return z

    def configure_optimizers(self):
        """Return Adam (AMSGrad) plus a step-wise exponential ``LambdaLR`` schedule."""
        from torch.optim.lr_scheduler import LambdaLR

        hyperparams = self.hyperparams
        lr = hyperparams.base_lr
        # multiplier applied to base_lr; also read by training_step for logging
        self.lr_fn = lambda epoch: hyperparams.decay_rate ** (epoch // hyperparams.decay_step)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.5, 0.999),
                                     amsgrad=True, weight_decay=1e-5)
        scheduler = LambdaLR(optimizer, lr_lambda=[self.lr_fn])
        return [optimizer], [scheduler]

    def prepare_data(self):
        """Download both CIFAR10 splits into the current working directory."""
        CIFAR10(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        CIFAR10(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

    def on_batch_start(self, batch):
        """Refresh the latent matching for the batch about to be trained on."""
        self.regen(batch)

    def train_dataloader(self):
        """Return a shuffled loader over the CIFAR10 cat and dog training images.

        The yielded label is the position of the class in the selection
        (0 for cat, 1 for dog), not the CIFAR10 class id.
        """
        trainset = CIFAR10(root=os.getcwd(), train=True, download=False)

        classDict = {'plane': 0, 'car': 1, 'bird': 2, 'cat': 3, 'deer': 4,
                     'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}

        x_train = trainset.data
        y_train = trainset.targets

        def get_class_i(x, y, i):
            """Return the list of items of ``x`` whose label in ``y`` equals ``i``."""
            positions = np.flatnonzero(np.array(y) == i)
            return [x[j] for j in positions]

        class DatasetMaker(Dataset):
            """Concatenation of per-class image lists, labelled by list position."""

            def __init__(self, datasets, transformFunc=transforms.ToTensor()):
                """
                datasets: a list of get_class_i outputs, i.e. a list of lists
                    of images, one list per selected class
                transformFunc: applied to every image when truthy
                """
                self.datasets = datasets
                self.lengths = [len(d) for d in self.datasets]
                self.transformFunc = transformFunc

            def __getitem__(self, i):
                class_label, index_wrt_class = self.index_of_which_bin(self.lengths, i)
                img = self.datasets[class_label][index_wrt_class]
                if self.transformFunc:
                    img = self.transformFunc(img)
                return img, class_label

            def __len__(self):
                return sum(self.lengths)

            def index_of_which_bin(self, bin_sizes, absolute_index, verbose=False):
                """
                Given the absolute index, return which bin it falls in and
                which element of that bin it corresponds to.
                """
                # Which class/bin does the index fall into?
                accum = np.add.accumulate(bin_sizes)
                if verbose:
                    print("accum =", accum)
                bin_index = int(np.count_nonzero(accum <= absolute_index))
                if verbose:
                    print("class_label =", bin_index)
                # Which element of that class/bin does it correspond to?
                index_wrt_class = absolute_index - np.insert(accum, 0, 0)[bin_index]
                if verbose:
                    print("index_wrt_class =", index_wrt_class)
                return bin_index, index_wrt_class

        # cats (class 3 of CIFAR) and dogs (class 5 of CIFAR) only
        cat_dog_trainset = DatasetMaker(
            [get_class_i(x_train, y_train, classDict['cat']),
             get_class_i(x_train, y_train, classDict['dog'])]
        )

        kwargs = {'num_workers': 2, 'pin_memory': False}
        batch_size = self.hyperparams.batch_size

        trainsetLoader = DataLoader(cat_dog_trainset, batch_size=batch_size,
                                    shuffle=True, **kwargs)
        return trainsetLoader
class IMLE:
    """IMLE-style trainer for two generators that share one latent vector.

    ``model`` is fitted to ``data_np`` and ``model2`` to ``data_np2``; both are
    evaluated on the same latent ``z``. The latent assigned to each data item
    is chosen by nearest-neighbour search (DCI) between the item and a pool of
    samples drawn from ``model`` only.
    """

    def __init__(self, z_dim):
        """Create both generators on the GPU.

        Args:
            z_dim: number of latent channels fed to each generator.
        """
        self.z_dim = z_dim
        self.model = ConvolutionalImplicitModel(z_dim).cuda()
        self.model2 = ConvolutionalImplicitModel(z_dim).cuda()
        # The DCI databases are built lazily in train(), once the flattened
        # data dimensionality is known.
        self.dci_db = None
        self.dci_db2 = None

    def _draw_matched_latents(self, data_np, num_draws, batch_size):
        """Draw a latent pool and return, per data item, its nearest latent."""
        num_candidates = num_draws * batch_size
        z_pool = np.empty((num_candidates, self.z_dim, 1, 1, 1))
        samples_np = np.empty((num_candidates,) + data_np.shape[1:])

        # Sampling is inference only: no_grad avoids building an autograd
        # graph for every candidate batch. Candidates are generated in
        # batches to keep GPU memory bounded.
        with torch.no_grad():
            for start in range(0, num_candidates, batch_size):
                stop = start + batch_size
                z = torch.randn(batch_size, self.z_dim, 1, 1, 1).cuda()
                z_pool[start:stop] = z.cpu().numpy()
                samples_np[start:stop] = self.model(z).cpu().numpy()

        # DCI works on flat vectors: one row per candidate / data item.
        samples_flat_np = samples_np.reshape(num_candidates, -1)
        data_flat_np = data_np.reshape(data_np.shape[0], -1)

        # Rebuild the index from the fresh candidates and look up the single
        # nearest candidate of every data item.
        self.dci_db.reset()
        self.dci_db.add(np.copy(samples_flat_np), num_levels=2,
                        field_of_view=10, prop_to_retrieve=0.002)
        nearest_indices, _ = self.dci_db.query(data_flat_np, num_neighbours=1,
                                               field_of_view=20,
                                               prop_to_retrieve=0.02)
        nearest_indices = np.array(nearest_indices)[:, 0]

        z_np = z_pool[nearest_indices]
        # Add small random noise to the latent space to facilitate training.
        z_np += 0.01 * np.random.randn(*z_np.shape)
        return z_np

    def train(self, data_np, data_np2, base_lr=1e-3, batch_size=64,
              num_epochs=10000, decay_step=25, decay_rate=1.0, staleness=500,
              num_samples_factor=100, save_path="../results_3D.npz"):
        """Jointly train both generators with an MSE loss on matched latents.

        Args:
            data_np: array of targets for ``model``; the first axis indexes
                items. (Assumes the trailing shape matches the generator
                output -- TODO confirm against ConvolutionalImplicitModel.)
            data_np2: array of targets for ``model2``, paired item-by-item
                with ``data_np``.
            base_lr: initial Adam learning rate.
            batch_size: items per gradient step; also the size of each
                candidate-sampling batch.
            num_epochs: number of passes over the data.
            decay_step: the learning rate is recomputed every this many epochs.
            decay_rate: multiplicative learning-rate factor per ``decay_step``.
            staleness: the latent assignment is refreshed, and results are
                saved, every this many epochs.
            num_samples_factor: candidate pool size is
                ``num_batches * num_samples_factor * batch_size``.
            save_path: destination of the ``.npz`` snapshot.

        Raises:
            ValueError: if there are fewer data items than ``batch_size``.
        """
        # Items beyond num_batches * batch_size are not visited in an epoch;
        # the per-epoch permutation changes which ones are left out.
        num_batches = data_np.shape[0] // batch_size
        if num_batches == 0:
            # Fail early with a clear message instead of an empty DCI
            # database or a ZeroDivisionError further down.
            raise ValueError(
                "need at least batch_size=%d data items, got %d"
                % (batch_size, data_np.shape[0]))

        # define metric
        loss_fn = nn.MSELoss().cuda()

        # make models trainable
        self.model.train()
        self.model2.train()

        # number of candidate batches drawn at every refresh
        num_samples = num_batches * num_samples_factor

        # initiate dci
        if self.dci_db is None:
            self.dci_db = DCI(np.prod(data_np.shape[1:]),
                              num_comp_indices=2, num_simp_indices=7)
        # dci_db2 is not queried by this method; it is still created because
        # it is an attribute of the instance.
        if self.dci_db2 is None:
            self.dci_db2 = DCI(np.prod(data_np2.shape[1:]),
                               num_comp_indices=2, num_simp_indices=7)

        for epoch in range(num_epochs):

            # decay the learning rate
            # NOTE(review): re-creating Adam also resets its moment estimates
            # every decay_step epochs; kept as in the original behaviour.
            if epoch % decay_step == 0:
                lr = base_lr * decay_rate ** (epoch // decay_step)
                optimizer = optim.Adam(
                    list(self.model.parameters())
                    + list(self.model2.parameters()),
                    lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)

            # re-evaluate the closest latents routinely
            if epoch % staleness == 0:
                z_np = self._draw_matched_latents(data_np, num_samples,
                                                  batch_size)

            # permute data, keeping both data sets and latents aligned
            data_ordering = np.random.permutation(data_np.shape[0])
            data_np = data_np[data_ordering]
            data_np2 = data_np2[data_ordering]
            z_np = z_np[data_ordering]

            # gradient descent over all batches
            err = 0.
            for i in range(num_batches):
                batch = slice(i * batch_size, (i + 1) * batch_size)

                # set up backprop
                self.model.zero_grad()
                self.model2.zero_grad()

                # evaluate the models on the same latents
                cur_z = torch.from_numpy(z_np[batch]).float().cuda()
                cur_data = torch.from_numpy(data_np[batch]).float().cuda()
                cur_data2 = torch.from_numpy(data_np2[batch]).float().cuda()
                cur_samples = self.model(cur_z)
                cur_samples2 = self.model2(cur_z)

                # sum of the MSE losses of the two images
                loss = (loss_fn(cur_samples, cur_data)
                        + loss_fn(cur_samples2, cur_data2))
                loss.backward()
                err += loss.item()
                optimizer.step()

            print("Epoch %d: Error: %f" % (epoch, err / num_batches))

            # save the mock sample
            if (epoch + 1) % staleness == 0:
                # Inference only, so skip autograd bookkeeping.
                with torch.no_grad():
                    z_all = torch.from_numpy(z_np).float().cuda()
                    mock_np = self.model(z_all).cpu().numpy()
                    mock_np2 = self.model2(z_all).cpu().numpy()
                np.savez(save_path, data_np=data_np, data_np2=data_np2,
                         samples_np=mock_np, samples_np2=mock_np2)