def interactive_interp(sess, dcgan, config, sampling='uniform'):
    """Interactively show sample grids and interpolate between two chosen samples in z-space."""
    while True:
        z_samples = dcgan.z_sampler(config)
        has_labels = False
        try:
            if dcgan.has_labels:
                has_labels = True
                label = int(raw_input('Class label for first sample: '))
                sample_labels = np.eye(dcgan.y_dim)[np.full(dcgan.batch_size, label)]
        except Exception:
            pass

        gauss_filter = gauss_kernel_fixed(config.gauss_sigma, config.gauss_trunc)
        if has_labels:
            samples = sess.run(dcgan.sampler,
                               feed_dict={dcgan.z: z_samples,
                                          dcgan.y: sample_labels,
                                          dcgan.gauss_kernel: gauss_filter})
        else:
            samples = sess.run(dcgan.sampler,
                               feed_dict={dcgan.z: z_samples,
                                          dcgan.gauss_kernel: gauss_filter})

        # transform back to the normal image value range and show the batch as a grid
        samples = np.array([((sample + 1) / 2 * 255).astype(np.uint8) for sample in samples])
        grid_size = int(math.ceil(math.sqrt(dcgan.batch_size)))
        scipy.misc.imshow(merge(samples, (grid_size, grid_size)))
        # from IPython import embed; embed()

        start = int(raw_input('First sample number: '))
        if has_labels:
            label2 = raw_input('Class label for second sample [same]: ')
            if label2 == '':
                label2 = label
                same = True
            else:
                label2 = int(label2)
                same = False
            sample_labels2 = np.eye(dcgan.y_dim)[np.full(dcgan.batch_size, label2)]
            if same:
                samples2 = samples
            else:
                samples2 = sess.run(dcgan.sampler,
                                    feed_dict={dcgan.z: z_samples,
                                               dcgan.y: sample_labels2,
                                               dcgan.gauss_kernel: gauss_filter})
            scipy.misc.imshow(merge(samples2, (grid_size, grid_size)))
        stop = int(raw_input('Second sample number: '))

        n_steps = raw_input('Number of steps [62]: ')
        if n_steps == '':
            n_steps = 62
        else:
            n_steps = int(n_steps)

        if has_labels:
            series = interpolate(sess, dcgan,
                                 z_start=z_samples[start - 1],
                                 z_stop=z_samples[stop - 1],
                                 n_steps=n_steps,
                                 y_start=label, y_stop=label2,
                                 transform=True)
        else:
            series = interpolate(sess, dcgan,
                                 z_start=z_samples[start - 1],
                                 z_stop=z_samples[stop - 1],
                                 n_steps=n_steps,
                                 transform=True)
        # show the series with 8 images per row (float division so the row count is rounded up)
        scipy.misc.imshow(merge(series, (int(math.ceil((n_steps + 2) / 8.)), 8)))

        c = raw_input('Continue? [y/n]')
        if c != 'y':
            break
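# `gauss_kernel_fixed` is defined elsewhere in this repo. As a rough reference only, a
# fixed-size 2D Gaussian blur kernel with standard deviation `sigma`, truncated at
# `trunc` pixels on each side (kernel size 2 * trunc + 1), could look like the
# hypothetical sketch below; for sigma == 0 it degenerates to the identity kernel.
# This is an assumption about the helper's behaviour, not the repo's actual code.
import numpy as np  # also imported at module level


def _gauss_kernel_sketch(sigma, trunc):
    size = 2 * trunc + 1
    if sigma == 0:
        # no blur: a delta kernel leaves the image unchanged
        kernel = np.zeros((size, size))
        kernel[trunc, trunc] = 1.0
        return kernel
    xs = np.arange(size) - trunc
    xx, yy = np.meshgrid(xs, xs)
    kernel = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2))
    return kernel / kernel.sum()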
def interpolate(sess, dcgan, z_start, z_stop, n_steps=62,
                y_start=None, y_stop=None, transform=True):
    """Interpolates between two samples in z-space.

    Input parameters:
        sess: TF session
        dcgan: DCGAN object for sampling
        z_start: z-vector of the first sample
        z_stop: z-vector of the second sample
        n_steps: number of intermediate samples to produce [62]
        y_start: label for the first sample (numerical)
        y_stop: label for the second sample (numerical)
        transform: if True, the pixel values will be transformed to their
            normal image range [True]

    RETURNS an array of n_steps + 2 samples
    """
    y_dim = 0
    if y_start is not None:
        y_dim = dcgan.y_dim
        if y_stop is None:
            y_stop = y_start
        if y_start != y_stop:
            # interpolate in the joint (z, one-hot label) space
            z_start = np.concatenate((z_start, np.eye(y_dim)[y_start]))
            z_stop = np.concatenate((z_stop, np.eye(y_dim)[y_stop]))

    # limit to batch size for simplicity
    if n_steps > (dcgan.batch_size - 2):
        n_steps = dcgan.batch_size - 2

    # sample along big circle for all distributions
    steps = np.linspace(0, 1, n_steps + 2)
    z_samples = [slerp(step, z_start, z_stop) for step in steps]
    gauss_filter = gauss_kernel_fixed(dcgan.gauss_sigma, (dcgan.kernel_size - 1) // 2)

    # pad the batch with dummy vectors of the same length as the (possibly label-augmented) z
    if n_steps != (dcgan.batch_size - 2):
        z_samples += [np.zeros(len(z_start))
                      for i in range(dcgan.batch_size - n_steps - 2)]

    if y_dim > 0:
        if y_start != y_stop:
            samples = sess.run(dcgan.sampler,
                               feed_dict={dcgan.z: np.array(z_samples)[:, :dcgan.z_dim],
                                          dcgan.y: np.array(z_samples)[:, dcgan.z_dim:],
                                          dcgan.gauss_kernel: gauss_filter})[:n_steps + 2]
        else:
            samples = sess.run(dcgan.sampler,
                               feed_dict={dcgan.z: np.array(z_samples),
                                          dcgan.y: np.eye(y_dim)[np.full(dcgan.batch_size, y_start)],
                                          dcgan.gauss_kernel: gauss_filter})[:n_steps + 2]
    else:
        samples = sess.run(dcgan.sampler,
                           feed_dict={dcgan.z: np.array(z_samples),
                                      dcgan.gauss_kernel: gauss_filter})[:n_steps + 2]

    if transform:
        samples = np.array([((sample + 1) / 2 * 255).astype(np.uint8) for sample in samples])
    return samples
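# `slerp` is imported from elsewhere in this repo. The sketch below is only a reference
# implementation of standard spherical linear interpolation with the same assumed
# signature slerp(val, low, high), val in [0, 1]; it is not necessarily the exact code
# used here.
import numpy as np  # also imported at module level


def _slerp_sketch(val, low, high):
    """Spherically interpolate between two z-vectors."""
    omega = np.arccos(np.clip(
        np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high)), -1.0, 1.0))
    so = np.sin(omega)
    if so == 0:
        # vectors are (anti)parallel: fall back to linear interpolation
        return (1.0 - val) * low + val * high
    return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high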
def train(self, config):
    """Train DCGAN"""
    train_gen, data_len = hdf5_images.load(batch_size=self.batch_size,
                                           data_file=self.data_dir,
                                           resolution=self.output_height,
                                           label_name=self.label_path)

    def inf_train_gen():
        # loop over the data set indefinitely
        while True:
            for _images, _labels in train_gen():
                yield _images, _labels

    gen = inf_train_gen()

    d_optim = tf.train.AdamOptimizer(self.lr, beta1=config.beta1) \
        .minimize(self.d_loss, var_list=self.d_vars)
    g_optim = tf.train.AdamOptimizer(self.lr, beta1=config.beta1) \
        .minimize(self.g_loss, var_list=self.g_vars)
    try:
        tf.global_variables_initializer().run()
    except:
        # fallback for TensorFlow versions that predate global_variables_initializer
        tf.initialize_all_variables().run()

    self.g_sum = merge_summary(
        [self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])  # self.z_sum, self.d__sum (hist)
    self.d_sum = merge_summary(
        [self.lr_sum, self.z_sum, self.d_loss_real_sum, self.d_loss_sum])  # self.d_sum (hist)
    # label weight histogram
    # if self.has_labels:
    #     self.g_sum = merge_summary([self.g_sum, self.w_y_sum])

    # initialize summary writer: for each run create an (enumerated) sub-directory
    if os.path.exists('./logs/' + self.dataset_name):
        # number of existing immediate child directories of the log folder
        run_var = len(next(os.walk('./logs/' + self.dataset_name))[1]) + 1
    else:
        run_var = 1
    self.writer = SummaryWriter(
        '%s/%s' % ('./logs/' + self.dataset_name, run_var), self.sess.graph)

    sample_z = self.z_sampler(config)
    sample_inputs = gen.next()
    sample_images = sample_inputs[0]
    sample_images = sample_images.transpose((0, 2, 3, 1))
    if self.has_labels:
        # use one of the following two lines to get either random samples
        # or only samples from the first 8 classes
        # sample_y = np.eye(self.y_dim)[sample_inputs[1]]
        sample_y = np.eye(self.y_dim)[[1, 2, 3, 4, 5, 6, 7, 8] * 8]

    counter = 1
    start_time = time.time()
    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    if could_load:
        counter = checkpoint_counter
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")

    batch_idxs = data_len // self.batch_size
    for epoch in xrange(config.epoch):
        if self.blur_fade:
            # sigma_used = self.gauss_sigma * (1 - (epoch / config.epoch))  # tapered version
            full_blur_ep = 0  # int(config.epoch * 0.2)
            no_blur_ep = 6  # int(config.epoch * 0.2)
            fade_ep = config.epoch - full_blur_ep - no_blur_ep
            if epoch < full_blur_ep:
                sigma_used = self.gauss_sigma
            elif epoch < full_blur_ep + fade_ep:
                # float division so the blur fades linearly instead of in integer steps
                sigma_used = self.gauss_sigma * (1 - (float(epoch + 1 - full_blur_ep) / fade_ep))
            else:
                sigma_used = 0
        else:
            sigma_used = self.gauss_sigma
        kernel_used = gauss_kernel_fixed(sigma_used, (self.kernel_size - 1) // 2)

        for idx in xrange(0, batch_idxs):
            batch_z = self.z_sampler(config)
            # when using random labels (commented-out line below), the fake label
            # distribution might not match the data and training can fail!
            # batch_y = tf.one_hot(np.random.random_integers(0, self.y_dim - 1, config.batch_size), self.y_dim)
            # better to use the real images and labels:
            batch_images, batch_labels_num = gen.next()
            # if images are stored in BHWC format, the following line should be commented out
            batch_images = batch_images.transpose((0, 2, 3, 1))  # dimensions are transposed here (BCHW -> BHWC)!
            if self.has_labels:
                batch_labels = np.eye(self.y_dim)[batch_labels_num]
            if config.blur_input is not None:
                batch_images = gaussian_filter(
                    batch_images, [0, config.blur_input, config.blur_input, 0])

            # Update D network
            _, summary_str = self.sess.run(
                [d_optim, self.d_sum],
                feed_dict={
                    self.inputs: batch_images,
                    self.z: batch_z,
                    self.lr: config.learning_rate,
                    self.gauss_kernel: kernel_used
                } if not self.has_labels else {
                    self.inputs: batch_images,
                    self.y: batch_labels,
                    self.lr: config.learning_rate,
                    self.z: batch_z,
                    self.gauss_kernel: kernel_used
                })
            self.writer.add_summary(summary_str, counter)

            # Update G network the specified number of times
            for i in range(self.num_g_updates):
                _, summary_str = self.sess.run(
                    [g_optim, self.g_sum],
                    feed_dict={
                        self.inputs: batch_images,
                        self.lr: config.learning_rate,
                        self.z: batch_z,
                        self.gauss_kernel: kernel_used
                    } if not self.has_labels else {
                        self.inputs: batch_images,
                        self.y: batch_labels,
                        self.lr: config.learning_rate,
                        self.z: batch_z,
                        self.gauss_kernel: kernel_used
                    })
                self.writer.add_summary(summary_str, counter)

            # Get losses for the current batch
            errD_fake = self.d_loss_fake.eval(
                {
                    self.z: batch_z,
                    self.gauss_kernel: kernel_used
                } if not self.has_labels else {
                    self.z: batch_z,
                    self.y: batch_labels,
                    self.gauss_kernel: kernel_used
                })
            errD_real = self.d_loss_real.eval(
                {
                    self.inputs: batch_images,
                    self.gauss_kernel: kernel_used
                } if not self.has_labels else {
                    self.inputs: batch_images,
                    self.y: batch_labels,
                    self.gauss_kernel: kernel_used
                })
            errG = self.g_loss.eval(
                {
                    self.z: batch_z,
                    self.gauss_kernel: kernel_used
                } if not self.has_labels else {
                    self.z: batch_z,
                    self.y: batch_labels,
                    self.gauss_kernel: kernel_used
                })

            # add sigma to summary
            sigma_sum = tf.Summary(value=[
                tf.Summary.Value(tag='gauss_sigma', simple_value=sigma_used)
            ])
            self.writer.add_summary(sigma_sum, counter)

            counter += 1
            print(
                "Epoch: [%2d/%2d] [%4d/%4d] [%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f, sigma: %4f"
                % (epoch, config.epoch, idx, batch_idxs, counter,
                   time.time() - start_time, errD_fake + errD_real, errG, sigma_used))

            # print losses for the constant sample and save an image with the corresponding output
            if (counter % 500) == 1:
                try:
                    samples, d_loss, g_loss = self.sess.run(
                        [self.sampler, self.d_loss, self.g_loss],
                        feed_dict={
                            self.z: sample_z,
                            self.inputs: sample_images,
                            self.y: sample_y,
                            self.gauss_kernel: kernel_used
                        } if self.has_labels else {
                            self.z: sample_z,
                            self.inputs: sample_images,
                            self.gauss_kernel: kernel_used
                        })
                    manifold_h = int(np.ceil(np.sqrt(samples.shape[0])))
                    manifold_w = int(np.floor(np.sqrt(samples.shape[0])))
                    save_images(
                        samples, [manifold_h, manifold_w],
                        './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
                    print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
                except Exception as e:
                    print("Error saving sample image:")
                    print(e)

            # save checkpoint
            if np.mod(counter, 3000) == 2:
                self.save(config.checkpoint_dir, counter)
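# The `config.blur_input` branch in train() passes a per-axis sigma list to
# `gaussian_filter`, so only the spatial axes of the BHWC batch are blurred. Assuming
# `gaussian_filter` is scipy.ndimage.gaussian_filter (an assumption; the import is not
# shown in this section), a minimal standalone demonstration would be:
#
#     import numpy as np
#     from scipy.ndimage import gaussian_filter
#
#     batch = np.random.rand(4, 64, 64, 3)                 # BHWC batch
#     blurred = gaussian_filter(batch, [0, 1.5, 1.5, 0])   # blur H and W only
#     assert blurred.shape == batch.shape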
def visualize(sess, dcgan, config, option):
    image_frame_dim = int(math.ceil(config.batch_size ** .5))

    # produce samples from the uniform distribution and find nearest neighbours in the training data
    # option 0: additionally sort according to distance
    if (option == 1) or (option == 0):
        n_images = 20
        has_labels = False
        try:
            if dcgan.has_labels:
                # generate one image for each cluster / category
                has_labels = True
                if option == 0:
                    n_images = dcgan.y_dim
        except Exception:
            pass

        # sample DCGAN from uniform distribution in z
        print('sampling...')
        z_samples = dcgan.z_sampler(config)
        if has_labels:
            y_samples = np.eye(dcgan.y_dim)[np.random.choice(
                dcgan.y_dim, [n_images, config.batch_size])]
            # pair each z-batch with its label batch
            samples = zip(z_samples, y_samples)
            samples = np.array([
                sess.run(dcgan.sampler, {dcgan.z: batch, dcgan.y: batch_y})
                for batch, batch_y in samples
            ])
        else:
            samples = np.array([
                sess.run(dcgan.sampler, feed_dict={dcgan.z: batch})
                for batch in z_samples
            ])

        # transform back to the normal image value range and reshape to one array instead of batches
        print('transforming...')
        samples = np.array([((sample + 1) / 2 * 255).astype(np.uint8) for sample in samples]) \
            .reshape((samples.shape[0] * samples.shape[1],) + samples.shape[2:])

        # load and rescale training data to the same size as the samples
        print('loading and transforming orig data...')
        orig_data, _ = fh.load_icon_data(config.data_dir)
        orig_data = np.array([
            scipy.misc.imresize(icon, (config.output_height, config.output_height))
            for icon in orig_data
        ])

        # get nearest neighbour indices from the training set
        if option == 1:
            print('getting nearest neighbours...')
            nearest_idxs = metrics.nearest_icons(samples, orig_data)
        else:
            print('getting nearest neighbours...')
            nearest_idxs, distances = metrics.nearest_icons(samples, orig_data, get_dist=True)
            print('sorting...')
            # normalize the distance over the whole image content to prevent
            # predominantly white images from having a low distance
            norms = np.sqrt(np.sum(np.power(samples, 2), axis=(1, 2, 3)))
            distances = np.array([distance / n for distance, n in zip(distances, norms)])
            sorting = np.argsort(distances)
            # import ipdb; ipdb.set_trace()
            samples = samples[sorting]
            nearest_idxs = np.array(nearest_idxs)[sorting]

        bs = config.batch_size
        for idx in xrange(n_images):
            print(" [*] %d" % idx)
            combined = []
            # combine samples and nearest neighbours for each batch and save as png
            for sample, orig in zip(samples[idx * bs:(idx + 1) * bs],
                                    orig_data[nearest_idxs[idx * bs:(idx + 1) * bs]]):
                combined += [sample, orig]
            scipy.misc.imsave(
                os.path.join(config.sample_dir, 'test_uniform_nearest_%s.png' % (idx)),
                merge(np.array(combined), [image_frame_dim, image_frame_dim * 2]))

    # sample with uniform distribution
    if option == 2:
        n_images = 20
        has_labels = False
        try:
            if dcgan.has_labels:
                # generate one image for each cluster / category
                n_images = dcgan.y_dim
                has_labels = True
        except Exception:
            pass
        for idx in xrange(n_images):
            print(" [*] %d" % idx)
            z_sample = dcgan.z_sampler(config)
            # create the gaussian convolution kernel as defined in the run parameters
            kernel = gauss_kernel_fixed(config.gauss_sigma, config.gauss_trunc)
            if has_labels:
                # y = np.random.choice(dcgan.y_dim, config.batch_size)
                # y_one_hot = np.zeros((config.batch_size, dcgan.y_dim))
                # y_one_hot[np.arange(config.batch_size), y] = 1
                y_one_hot = np.eye(dcgan.y_dim)[np.full(config.batch_size, idx)]
                # print(y_one_hot)
                samples = sess.run(dcgan.sampler,
                                   feed_dict={dcgan.z: z_sample,
                                              dcgan.y: y_one_hot,
                                              dcgan.gauss_kernel: kernel})
            else:
                samples = sess.run(dcgan.sampler,
                                   feed_dict={dcgan.z: z_sample,
                                              dcgan.gauss_kernel: kernel})
            save_images(samples, [image_frame_dim, image_frame_dim],
                        os.path.join(config.sample_dir, 'test_uniform_%s.png' % (idx)))

    # sample with normal distribution
    if option == 3:
        for idx in xrange(100):
            print(" [*] %d" % idx)
            z_sample = np.random.normal(size=(config.batch_size, dcgan.z_dim))
            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
            save_images(samples, [image_frame_dim, image_frame_dim],
                        './samples/test_normal_%s.png' % (idx))

    # single sample with uniform distribution
    if option == 4:
        z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
        save_images(samples, [image_frame_dim, image_frame_dim],
                    os.path.join(config.sample_dir,
                                 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime())))

    # vary a single z-component only
    if option == 5:
        values = np.arange(0, 1, 1. / config.batch_size)
        for idx in xrange(100):
            print(" [*] %d" % idx)
            z_sample = np.zeros([config.batch_size, dcgan.z_dim])
            for kdx, z in enumerate(z_sample):
                z[idx] = values[kdx]
            # generate the samples for the varied z-vectors before saving them
            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
            save_images(samples, [image_frame_dim, image_frame_dim],
                        os.path.join(config.sample_dir, 'test_arange_%s.png' % (idx)))
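# Hypothetical usage sketch (not this repo's actual entry point): once a DCGAN object
# has been built and its checkpoint restored into `sess`, the helpers above can be
# driven with the same config/FLAGS object used for training. The constructor arguments
# and the load() call are assumptions based on their use in train():
#
#     with tf.Session() as sess:
#         dcgan = DCGAN(sess, ...)               # construction args omitted
#         dcgan.load(config.checkpoint_dir)      # restore a trained checkpoint
#         visualize(sess, dcgan, config, option=2)    # per-class uniform samples
#         interactive_interp(sess, dcgan, config)     # interactive z-space interpolation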