def train(self, train_val_data_iterator):
    """Train the model, resuming from the counters stored on the instance.

    Runs one evaluation pass before training starts, then for every epoch in
    ``range(self.start_epoch, self.epoch)`` performs one optimizer step per
    "train" batch (starting at ``self.start_batch_id`` in the first epoch
    only). Every ``self.eval_interval`` batches an evaluation pass is run.
    A checkpoint is written at the end of every epoch and once more after
    the loop finishes.

    Args:
        train_val_data_iterator: Data source used for both the "train"
            split (optimisation) and the "val" split (evaluation). Must
            provide ``get_num_samples``, ``get_next_batch`` and
            ``reset_counter``.
    """
    counter = self.counter
    start_batch_id = self.start_batch_id
    start_epoch = self.start_epoch
    self.evaluate(start_epoch, start_batch_id, 0,
                  val_data_iterator=train_val_data_iterator)

    num_batches_train = (train_val_data_iterator.get_num_samples("train")
                         // self.batch_size)
    label_dim = self.label_dim

    start_time = time.time()
    for epoch in range(start_epoch, self.epoch):
        for idx in range(start_batch_id, num_batches_train):
            # The first ``label_dim`` columns of manual_labels are the actual
            # one-hot encoded labels; the next column is a confidence value,
            # currently binary (0/1), fed as ``is_manual_annotated``.
            # TODO: change this confidence to a continuous value.
            batch_images, _, manual_labels = train_val_data_iterator.get_next_batch("train")
            batch_z = prior.gaussian(self.batch_size, self.z_dim)

            # One optimizer step on the autoencoder.
            _, summary_str, loss, nll_loss, kl_loss, supervised_loss = self.sess.run(
                [self.optim, self.merged_summary_op, self.loss,
                 self.neg_loglikelihood, self.KL_divergence,
                 self.supervised_loss],
                feed_dict={
                    self.inputs: batch_images,
                    self.labels: manual_labels[:, :label_dim],
                    self.is_manual_annotated: manual_labels[:, label_dim],
                    self.standard_normal: batch_z,
                })
            # Elapsed time is measured from the start of the whole run.
            print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f, nll: %.8f, kl: %.8f, supervised_loss: %.4f"
                  % (epoch, idx, num_batches_train, time.time() - start_time,
                     loss, nll_loss, kl_loss, supervised_loss))
            counter += 1

            if idx % self.eval_interval == 0:
                # NOTE(review): evaluation is tagged with epoch + 1 and
                # idx - 1; confirm this offset is intended.
                self.evaluate(epoch + 1, idx - 1, counter - 1,
                              val_data_iterator=train_val_data_iterator)
            # The training summary is recorded on every step.
            self.writer.add_summary(summary_str, counter - 1)

        # A non-zero start_batch_id only applies to the first epoch after
        # loading a pre-trained model; later epochs start from batch 0.
        start_batch_id = 0

        # Save a checkpoint at the end of every epoch.
        print("Saving check point", self.checkpoint_dir)
        self.save(self.checkpoint_dir, counter)
        train_val_data_iterator.reset_counter("train")

    # Save the model for the final step.
    self.save(self.checkpoint_dir, counter)
def evaluate(self, epoch, step, counter, val_data_iterator):
    """Run one pass over the "val" split and save reconstructions.

    For every validation batch this:
      * writes the batch's labels plus an integer "label" column (index of
        the 1 in each one-hot row) to ``label_test_XX.csv`` under
        ``self.result_dir``;
      * runs ``self.out`` and the merged summary op, recording the summary
        to ``self.writer_v`` at ``counter``.
    Afterwards, the first ``manifold_h * manifold_w`` reconstructions of
    each batch are saved as ``im_<batch>.png`` in the directory returned by
    ``get_eval_result_dir``. The "val" counter is reset at the end.

    Args:
        epoch: Epoch number used for logging and the output directory.
        step: Step number used for logging and the output directory.
        counter: Global step at which validation summaries are recorded.
        val_data_iterator: Data source providing the "val" split.
    """
    print("Running evaluation after epoch:{:02d} and step:{:04d} ".format(
        epoch, step))

    label_dim = self.label_dim

    # Grid layout for the saved reconstruction images (loop-invariant).
    manifold_w = self.num_images_per_row
    tot_num_samples = min(self.sample_num, self.batch_size)
    manifold_h = tot_num_samples // manifold_w

    num_eval_batches = val_data_iterator.get_num_samples("val") // self.batch_size
    reconstructed_images = []
    for batch_idx in range(num_eval_batches):
        batch_eval_images, batch_eval_labels, manual_labels = val_data_iterator.get_next_batch(
            "val")

        # Integer class index = position of the 1 in each one-hot row.
        # reshape([-1, 1]) instead of a hard-coded 64 so any batch size works.
        integer_label = np.asarray(
            [np.where(row == 1)[0][0] for row in batch_eval_labels]
        ).reshape([-1, 1])
        labels_with_index = np.concatenate(
            [batch_eval_labels, integer_label], axis=1)
        # One column per one-hot entry, plus the integer label column.
        columns = [str(i) for i in range(labels_with_index.shape[1] - 1)]
        columns.append("label")
        # NOTE(review): assumes result_dir already ends with a path separator.
        pd.DataFrame(labels_with_index, columns=columns).to_csv(
            self.result_dir + "label_test_{:02d}.csv".format(batch_idx),
            index=False)

        batch_z = prior.gaussian(self.batch_size, self.z_dim)
        reconstructed_image, summary = self.sess.run(
            [self.out, self.merged_summary_op],
            feed_dict={
                self.inputs: batch_eval_images,
                self.labels: manual_labels[:, :label_dim],
                self.is_manual_annotated: manual_labels[:, label_dim],
                self.standard_normal: batch_z,
            })
        self.writer_v.add_summary(summary, counter)
        # Keep only as many images as fit in the save grid.
        reconstructed_images.append(
            reconstructed_image[:manifold_h * manifold_w, :, :, :])

    print("epoch:{} step:{}".format(epoch, step))
    reconstructed_dir = get_eval_result_dir(self.result_dir, epoch, step)
    print(reconstructed_dir)
    for batch_idx, images in enumerate(reconstructed_images):
        file_name = "im_" + str(batch_idx) + ".png"
        save_image(images, [manifold_h, manifold_w],
                   reconstructed_dir + file_name)

    val_data_iterator.reset_counter("val")
    print("Evaluation completed")
def load_from_checkpoint(self):
    """Initialise all graph variables, then try to restore a checkpoint.

    Also draws a fresh ``self.sample_z`` from ``prior.gaussian``. Prints
    whether the restore from ``self.checkpoint_dir`` succeeded.

    Returns:
        The checkpoint counter reported by ``self.load``; it is returned
        whether or not the restore succeeded.
    """
    # Run the initializer in the model's own session. When self.sess is None,
    # TensorFlow falls back to the default session, matching the old behaviour.
    tf.global_variables_initializer().run(session=self.sess)
    # Graph inputs for visualizing training results.
    self.sample_z = prior.gaussian(self.batch_size, self.z_dim)
    # Restore the checkpoint if it exists.
    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    if could_load:
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")
    return checkpoint_counter
def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, beta=5,
             num_units_in_layer=None, log_dir=None, checkpoint_dir=None,
             result_dir=None, train_val_data_iterator=None,
             read_from_existing_checkpoint=True, check_point_epochs=None,
             supervise_weight=0):
    """Configure hyper-parameters, build the graph and set up training state.

    Args:
        sess: TensorFlow session used to run the graph.
        epoch: Exclusive upper bound of the training epoch loop.
        batch_size: Batch size; also the fixed leading dimension of the
            ``mu`` / ``sigma`` placeholders.
        z_dim: Dimension of the latent (noise) vector.
        dataset_name: Only 'mnist' and 'fashion-mnist' are supported.
        beta: Stored as ``self.beta`` (presumably the KL weight — confirm
            in ``_build_model``).
        num_units_in_layer: Layer widths. If None or empty, defaults to
            ``[64, 128, 32, z_dim * 2]``.
        log_dir, checkpoint_dir, result_dir: Output directories.
        train_val_data_iterator, read_from_existing_checkpoint,
        check_point_epochs: Forwarded to ``self.initialize``, which returns
            the starting counter, batch id and epoch.
        supervise_weight: Stored as ``self.supervise_weight``.

    Raises:
        NotImplementedError: If ``dataset_name`` is not supported.
    """
    self.sess = sess
    self.dataset_name = dataset_name
    self.epoch = epoch
    self.batch_size = batch_size
    self.num_val_samples = 128
    self.log_dir = log_dir
    self.checkpoint_dir = checkpoint_dir
    self.result_dir = result_dir
    self.beta = beta
    self.supervise_weight = supervise_weight
    if dataset_name in ('mnist', 'fashion-mnist'):
        # parameters
        self.label_dim = 10  # one hot encoding for 10 classes
        self.input_height = 28
        self.input_width = 28
        self.output_height = 28
        self.output_width = 28
        self.z_dim = z_dim  # dimension of noise-vector
        self.c_dim = 1
        if num_units_in_layer is None or len(num_units_in_layer) == 0:
            self.n = [64, 128, 32, z_dim * 2]
        else:
            self.n = num_units_in_layer
        # train
        self.learning_rate = 0.0002
        self.beta1 = 0.5
        # test
        self.sample_num = 64  # number of generated images to be saved
        self.num_images_per_row = 4  # should be a factor of sample_num
        self.eval_interval = 300
    else:
        raise NotImplementedError(
            "Dataset {} not implemented".format(dataset_name))

    self.mu = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='mu')
    self.sigma = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='sigma')
    self.images = None
    self._build_model()

    # Initialize all variables in the model's own session. When sess is None,
    # TensorFlow falls back to the default session, matching the old behaviour.
    tf.global_variables_initializer().run(session=self.sess)
    # Graph inputs for visualizing training results.
    self.sample_z = prior.gaussian(self.batch_size, self.z_dim)
    # NOTE(review): set after _build_model(); presumably consumed by
    # initialize() (e.g. for a Saver) — confirm before reordering.
    self.max_to_keep = 20
    self.counter, self.start_batch_id, self.start_epoch = self.initialize(
        train_val_data_iterator, read_from_existing_checkpoint,
        check_point_epochs)