def generate_samples(self, sess, plot_size, manifold=False, file_id=None):
    n_samples = plot_size * plot_size

    # When labels are used, assign one label per row of the sample grid,
    # cycling through the label set.
    label_indices = None
    if self._use_label:
        cur_r = 0
        label_indices = []
        cur_label = -1
        while cur_r < plot_size:
            cur_label = cur_label + 1 if cur_label < self._n_labels - 1 else 0
            row_label = np.ones(plot_size) * cur_label
            label_indices.extend(row_label)
            cur_r += 1

    if manifold:
        # Visualize the learned manifold by decoding an evenly spaced grid
        # of latent codes instead of random draws from the prior.
        if self._dist == 'gaussian':
            random_code = distribution.interpolate(
                plot_size=plot_size,
                interpolate_range=[-3, 3, -3, 3])
            self.viz_samples(sess, random_code, plot_size, file_id=file_id)
        else:
            # One interpolation grid per mixture mode.
            for mode_id in range(self._n_labels):
                random_code = distribution.interpolate_gm(
                    plot_size=plot_size,
                    interpolate_range=[-1., 1., -0.2, 0.2],
                    mode_id=mode_id,
                    n_mode=self._n_labels)
                self.viz_samples(
                    sess, random_code, plot_size,
                    file_id='{}_{}'.format(file_id, mode_id))
    else:
        # Decode random samples drawn from the prior.
        if self._dist == 'gaussian':
            random_code = distribution.diagonal_gaussian(
                n_samples, self._g_model.n_code, mean=0, var=1.0)
        else:
            random_code = distribution.gaussian_mixture(
                n_samples,
                n_dim=self._g_model.n_code,
                n_labels=self._n_labels,
                x_var=0.5,
                y_var=0.1,
                label_indices=label_indices)
        self.viz_samples(sess, random_code, plot_size, file_id=file_id)
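
# A minimal standalone sketch of the interpolation grid consumed above,
# assuming `distribution.interpolate` walks an evenly spaced
# plot_size x plot_size grid of 2-D codes over
# [x_min, x_max] x [y_min, y_max]. The actual helper lives in the
# `distribution` module and may differ in ordering or signature; the
# function name below is hypothetical and illustrative only.
#
# import numpy as np
#
# def interpolate_grid_sketch(plot_size, interpolate_range):
#     x_min, x_max, y_min, y_max = interpolate_range
#     xs = np.linspace(x_min, x_max, plot_size)
#     ys = np.linspace(y_min, y_max, plot_size)
#     # Row-major grid of 2-D codes, shape [plot_size * plot_size, 2],
#     # matching n_samples = plot_size * plot_size above.
#     grid_x, grid_y = np.meshgrid(xs, ys)
#     return np.stack([grid_x, grid_y], axis=-1).reshape(-1, 2)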
def train_z_gan_epoch(self, sess, ae_dropout=1.0, summary_writer=None):
    self._t_model.set_is_training(True)
    display_name_list = ['loss', 'd_loss', 'g_loss']
    cur_summary = None

    # Step-wise learning rate decay.
    if self.epoch_id == 100:
        self._lr = self._lr / 10
    if self.epoch_id == 300:
        self._lr = self._lr / 10

    cur_epoch = self._train_data.epochs_completed
    step = 0
    loss_sum = 0
    d_loss_sum = 0
    g_loss_sum = 0
    self.epoch_id += 1
    while cur_epoch == self._train_data.epochs_completed:
        self.global_step += 1
        step += 1

        batch_data = self._train_data.next_batch_dict()
        im = batch_data['im']
        label = batch_data['label']
        if self._use_label:
            label_indices = label
        else:
            label_indices = None

        # Draw a batch from the target prior for the discriminator.
        # NOTE: the number of mixture modes is hard-coded to 10 here,
        # while generate_samples uses self._n_labels.
        if self._dist == 'gmm':
            real_sample = distribution.gaussian_mixture(
                len(im),
                n_dim=self._t_model.n_code,
                n_labels=10,
                x_var=0.5,
                y_var=0.1,
                label_indices=label_indices)
        else:
            real_sample = distribution.diagonal_gaussian(
                len(im), self._t_model.n_code, mean=0, var=1.0)

        # Reconstruction phase: train the autoencoder.
        _, loss, cur_summary = sess.run(
            [self._train_op, self._loss_op, self._train_summary_op],
            feed_dict={
                self._t_model.image: im,
                self._t_model.lr: self._lr,
                self._t_model.keep_prob: ae_dropout,
                self._t_model.label: label,
                self._t_model.real_distribution: real_sample
            })

        # Regularization phase: train the discriminator.
        _, d_loss = sess.run(
            [self._train_d_op, self._d_loss_op],
            feed_dict={
                self._t_model.image: im,
                self._t_model.label: label,
                self._t_model.lr: self._lr,
                self._t_model.keep_prob: 1.,
                self._t_model.real_distribution: real_sample
            })

        # Regularization phase: train the generator (the encoder).
        _, g_loss = sess.run(
            [self._train_g_op, self._g_loss_op],
            feed_dict={
                self._t_model.image: im,
                self._t_model.label: label,
                self._t_model.lr: self._lr,
                self._t_model.keep_prob: 1.
            })

        loss_sum += loss
        d_loss_sum += d_loss
        g_loss_sum += g_loss

        if step % 100 == 0:
            display(self.global_step,
                    step,
                    [loss_sum, d_loss_sum, g_loss_sum],
                    display_name_list,
                    'train',
                    summary_val=cur_summary,
                    summary_writer=summary_writer)

    print('==== epoch: {}, lr:{} ===='.format(cur_epoch, self._lr))
    display(self.global_step,
            step,
            [loss_sum, d_loss_sum, g_loss_sum],
            display_name_list,
            'train',
            summary_val=cur_summary,
            summary_writer=summary_writer)
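
# A minimal standalone sketch of the 2-D Gaussian-mixture prior sampled
# above via `distribution.gaussian_mixture`, assuming the usual adversarial
# autoencoder construction: n_labels Gaussian modes with per-axis spreads
# (x_var, y_var), shifted off the origin and rotated evenly around a
# circle. The `shift` constant and whether x_var/y_var are variances or
# standard deviations are assumptions; the real helper may differ.
#
# import numpy as np
#
# def gaussian_mixture_sketch(batch_size, n_labels=10, x_var=0.5,
#                             y_var=0.1, label_indices=None, shift=1.4):
#     if label_indices is None:
#         label_indices = np.random.randint(0, n_labels, batch_size)
#     # Sample within a single mode, then rotate each point to the slot
#     # selected by its label index.
#     x = np.random.normal(0., np.sqrt(x_var), batch_size) + shift
#     y = np.random.normal(0., np.sqrt(y_var), batch_size)
#     angle = 2. * np.pi * np.asarray(label_indices) / n_labels
#     rx = x * np.cos(angle) - y * np.sin(angle)
#     ry = x * np.sin(angle) + y * np.cos(angle)
#     return np.stack([rx, ry], axis=1)  # shape [batch_size, 2]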
def train_semisupervised_epoch(self, sess, ae_dropout=1.0,
                               summary_writer=None):
    label_data = self._train_data['labeled']
    unlabel_data = self._train_data['unlabeled']
    display_name_list = [
        'loss', 'z_d_loss', 'z_g_loss', 'y_d_loss', 'y_g_loss',
        'cls_loss', 'cls_accuracy'
    ]
    cur_summary = None
    cur_epoch = unlabel_data.epochs_completed
    self.epoch_id += 1

    # Step-wise learning rate decay.
    if self.epoch_id == 150:
        self._lr = self._lr / 10
    if self.epoch_id == 200:
        self._lr = self._lr / 10

    step = 0
    loss_sum = 0
    z_d_loss_sum = 0
    z_g_loss_sum = 0
    y_d_loss_sum = 0
    y_g_loss_sum = 0
    cls_loss_sum = 0
    cls_accuracy_sum = 0
    while cur_epoch == unlabel_data.epochs_completed:
        self.global_step += 1
        step += 1

        batch_data = unlabel_data.next_batch_dict()
        im = batch_data['im']
        label = batch_data['label']

        # Priors: diagonal Gaussian for the style code z, uniform
        # categorical for the label code y.
        z_real_sample = distribution.diagonal_gaussian(
            len(im), self._t_model.n_code, mean=0, var=1.0)
        y_real_sample = np.random.choice(self._t_model.n_class, len(im))

        # Reconstruction phase: train the autoencoder.
        _, loss, cur_summary = sess.run(
            [self._train_op, self._loss_op, self._train_summary_op],
            feed_dict={
                self._t_model.image: im,
                self._t_model.lr: self._lr,
                self._t_model.keep_prob: ae_dropout,
                self._t_model.label: label,
                self._t_model.real_distribution: z_real_sample,
                self._t_model.real_y: y_real_sample
            })

        # z discriminator
        _, z_d_loss = sess.run(
            [self._train_d_op, self._d_loss_op],
            feed_dict={
                self._t_model.image: im,
                self._t_model.lr: self._lr,
                self._t_model.keep_prob: 1.,
                self._t_model.real_distribution: z_real_sample
            })

        # z generator
        _, z_g_loss = sess.run(
            [self._train_g_op, self._g_loss_op],
            feed_dict={
                self._t_model.image: im,
                self._t_model.lr: self._lr,
                self._t_model.keep_prob: 1.
            })

        # y discriminator
        _, y_d_loss = sess.run(
            [self._train_cat_d_op, self._cat_d_loss_op],
            feed_dict={
                self._t_model.image: im,
                self._t_model.lr: self._lr,
                self._t_model.keep_prob: 1.,
                self._t_model.real_y: y_real_sample
            })

        # y generator
        _, y_g_loss = sess.run(
            [self._train_cat_g_op, self._cat_g_loss_op],
            feed_dict={
                self._t_model.image: im,
                self._t_model.lr: self._lr,
                self._t_model.keep_prob: 1.
            })

        # Semi-supervised classification phase: update the classifier on
        # a labeled batch once every 10 global steps.
        batch_data = label_data.next_batch_dict()
        im = batch_data['im']
        label = batch_data['label']
        if self.global_step % 10 == 0:
            _, cls_loss, cls_accuracy = sess.run(
                [self._cls_train_op, self._cls_loss_op,
                 self._cls_accuracy_op],
                feed_dict={
                    self._t_model.image: im,
                    self._t_model.label: label,
                    self._t_model.lr: self._lr,
                    self._t_model.keep_prob: 1.
                })
            cls_loss_sum += cls_loss
            cls_accuracy_sum += cls_accuracy

        loss_sum += loss
        z_d_loss_sum += z_d_loss
        z_g_loss_sum += z_g_loss
        y_d_loss_sum += y_d_loss
        y_g_loss_sum += y_g_loss

        if step % 100 == 0:
            display(
                self.global_step,
                step,
                [loss_sum, z_d_loss_sum, z_g_loss_sum, y_d_loss_sum,
                 y_g_loss_sum, cls_loss_sum * 10, cls_accuracy_sum * 10],
                display_name_list,
                'train',
                summary_val=cur_summary,
                summary_writer=summary_writer)

    print('==== epoch: {}, lr:{} ===='.format(cur_epoch, self._lr))
    display(
        self.global_step,
        step,
        [loss_sum, z_d_loss_sum, z_g_loss_sum, y_d_loss_sum,
         y_g_loss_sum, cls_loss_sum * 10, cls_accuracy_sum * 10],
        display_name_list,
        'train',
        summary_val=cur_summary,
        summary_writer=summary_writer)
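
# The categorical prior y_real_sample above is fed as integer class ids.
# If the model's `real_y` placeholder instead expects one-hot vectors,
# the draw can be one-hot encoded as sketched below; `one_hot_sketch` is
# a hypothetical helper shown for illustration.
#
# import numpy as np
#
# def one_hot_sketch(y_indices, n_class):
#     # y_indices: integer class ids, shape [batch_size].
#     onehot = np.zeros((len(y_indices), n_class))
#     onehot[np.arange(len(y_indices)), y_indices] = 1.
#     return onehot  # shape [batch_size, n_class]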