def generate_samples(self, sess, plot_size, manifold=False, file_id=None):
    """Draw latent codes from the prior and hand them to `viz_samples`.

    Args:
        sess: session forwarded unchanged to `self.viz_samples`.
        plot_size: side length of the sample plot; plot_size * plot_size
            codes are drawn in the non-manifold case.
        manifold: if True, visualize interpolated codes from
            `distribution.interpolate` (Gaussian prior) or one
            `distribution.interpolate_gm` sweep per mixture mode (any other
            prior) instead of random samples.
        file_id: identifier forwarded to `viz_samples`; in the mixture
            manifold case it is suffixed with `_<mode_id>` per mode.
    """
    n_samples = plot_size * plot_size

    label_indices = None
    if self._use_label:
        # Row r gets label r cycling through 0 .. self._n_labels - 1, and
        # each row contributes plot_size copies of its label.
        label_indices = []
        row_label = -1
        for _ in range(plot_size):
            row_label = 0 if row_label >= self._n_labels - 1 else row_label + 1
            label_indices.extend(np.ones(plot_size) * row_label)

    if not manifold:
        if self._dist == 'gaussian':
            random_code = distribution.diagonal_gaussian(
                n_samples, self._g_model.n_code, mean=0, var=1.0)
        else:
            random_code = distribution.gaussian_mixture(
                n_samples,
                n_dim=self._g_model.n_code,
                n_labels=self._n_labels,
                x_var=0.5,
                y_var=0.1,
                label_indices=label_indices)
        self.viz_samples(sess, random_code, plot_size, file_id=file_id)
        return

    if self._dist == 'gaussian':
        grid_code = distribution.interpolate(
            plot_size=plot_size, interpolate_range=[-3, 3, -3, 3])
        self.viz_samples(sess, grid_code, plot_size, file_id=file_id)
        return

    # Mixture prior: one interpolation sweep (and one output) per mode.
    for mode_id in range(self._n_labels):
        mode_code = distribution.interpolate_gm(
            plot_size=plot_size,
            interpolate_range=[-1., 1., -0.2, 0.2],
            mode_id=mode_id,
            n_mode=self._n_labels)
        self.viz_samples(
            sess, mode_code, plot_size,
            file_id='{}_{}'.format(file_id, mode_id))
def train_z_gan_epoch(self, sess, ae_dropout=1.0, summary_writer=None,
                      lr_decay_epochs=(100, 300), lr_decay_factor=10):
    """Run one epoch of adversarial training on the latent code.

    Every mini-batch runs three updates in order:
    1. the autoencoder (`self._train_op`),
    2. the latent discriminator (`self._train_d_op`),
    3. the generator (`self._train_g_op`).

    Prior ("real") samples come from `distribution.gaussian_mixture` when
    `self._dist == 'gmm'` and from `distribution.diagonal_gaussian`
    otherwise. They are fed as `real_distribution` to steps 1 and 2.

    Args:
        sess: session used to run the train/loss ops (presumably a TF1
            `tf.Session` — confirm against caller).
        ae_dropout: value fed as `keep_prob` for the autoencoder step; the
            discriminator and generator steps always feed 1.0.
        summary_writer: optional writer forwarded to `display`.
        lr_decay_epochs: values of `self.epoch_id` (checked before it is
            incremented for this epoch) at which `self._lr` is divided by
            `lr_decay_factor`. The defaults reproduce the original
            hard-coded schedule.
        lr_decay_factor: divisor applied to `self._lr` at each decay epoch.
    """
    self._t_model.set_is_training(True)
    display_name_list = ['loss', 'd_loss', 'g_loss']
    cur_summary = None

    # Step-decay learning-rate schedule.
    if self.epoch_id in lr_decay_epochs:
        self._lr = self._lr / lr_decay_factor

    cur_epoch = self._train_data.epochs_completed
    step = 0
    loss_sum = 0
    d_loss_sum = 0
    g_loss_sum = 0
    self.epoch_id += 1
    # Keep drawing batches until the data source reports one more
    # completed epoch.
    while cur_epoch == self._train_data.epochs_completed:
        self.global_step += 1
        step += 1

        batch_data = self._train_data.next_batch_dict()
        im = batch_data['im']
        label = batch_data['label']

        # One prior sample per image. In the labelled setting the batch
        # labels are forwarded as label_indices (presumably selecting the
        # mixture component per sample — confirm in `distribution`).
        label_indices = label if self._use_label else None
        if self._dist == 'gmm':
            # Use the same label count as generate_samples instead of a
            # hard-coded 10, so training and sampling share one prior.
            real_sample = distribution.gaussian_mixture(
                len(im),
                n_dim=self._t_model.n_code,
                n_labels=self._n_labels,
                x_var=0.5,
                y_var=0.1,
                label_indices=label_indices)
        else:
            real_sample = distribution.diagonal_gaussian(
                len(im), self._t_model.n_code, mean=0, var=1.0)

        # train autoencoder
        _, loss, cur_summary = sess.run(
            [self._train_op, self._loss_op, self._train_summary_op],
            feed_dict={
                self._t_model.image: im,
                self._t_model.lr: self._lr,
                self._t_model.keep_prob: ae_dropout,
                self._t_model.label: label,
                self._t_model.real_distribution: real_sample,
            })

        # train discriminator
        _, d_loss = sess.run(
            [self._train_d_op, self._d_loss_op],
            feed_dict={
                self._t_model.image: im,
                self._t_model.label: label,
                self._t_model.lr: self._lr,
                self._t_model.keep_prob: 1.,
                self._t_model.real_distribution: real_sample,
            })

        # train generator
        # NOTE(review): real_distribution is not fed here; this assumes the
        # generator loss depends only on the encoder output — confirm.
        _, g_loss = sess.run(
            [self._train_g_op, self._g_loss_op],
            feed_dict={
                self._t_model.image: im,
                self._t_model.label: label,
                self._t_model.lr: self._lr,
                self._t_model.keep_prob: 1.,
            })

        # Cumulative sums since the start of the epoch; `display` receives
        # them together with `step` (presumably to report running means).
        loss_sum += loss
        d_loss_sum += d_loss
        g_loss_sum += g_loss

        if step % 100 == 0:
            display(self.global_step,
                    step,
                    [loss_sum, d_loss_sum, g_loss_sum],
                    display_name_list,
                    'train',
                    summary_val=cur_summary,
                    summary_writer=summary_writer)

    print('==== epoch: {}, lr:{} ===='.format(cur_epoch, self._lr))
    display(self.global_step,
            step,
            [loss_sum, d_loss_sum, g_loss_sum],
            display_name_list,
            'train',
            summary_val=cur_summary,
            summary_writer=summary_writer)