Пример #1
0
    def train(self, train_val_data_iterator):
        """Run the training loop, evaluating and checkpointing along the way.

        Training resumes from ``self.counter``, ``self.start_batch_id`` and
        ``self.start_epoch``. An evaluation pass runs once before training and
        then every ``self.eval_interval`` batches. The merged training summary
        is written on every step. A checkpoint is saved after every epoch and
        once more at the end.

        Args:
            train_val_data_iterator: Data iterator exposing ``get_num_samples``,
                ``get_next_batch`` and ``reset_counter``. It is used for the
                ``"train"`` split here and passed to ``evaluate`` for ``"val"``.
        """
        counter = self.counter
        start_batch_id = self.start_batch_id
        start_epoch = self.start_epoch

        # Baseline evaluation before any (further) training step.
        self.evaluate(start_epoch, start_batch_id, 0,
                      val_data_iterator=train_val_data_iterator)
        num_batches_train = (
            train_val_data_iterator.get_num_samples("train") // self.batch_size)

        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):
            for idx in range(start_batch_id, num_batches_train):
                # Columns 0-9 of manual_labels are the one-hot labels; column 10
                # is the annotation confidence, currently binary (0/1).
                # TODO: make the confidence continuous.
                batch_images, _, manual_labels = \
                    train_val_data_iterator.get_next_batch("train")
                batch_z = prior.gaussian(self.batch_size, self.z_dim)

                # One optimisation step on the autoencoder.
                _, summary_str, loss, nll_loss, kl_loss, supervised_loss = \
                    self.sess.run(
                        [self.optim, self.merged_summary_op,
                         self.loss, self.neg_loglikelihood,
                         self.KL_divergence, self.supervised_loss],
                        feed_dict={
                            self.inputs: batch_images,
                            self.labels: manual_labels[:, :10],
                            self.is_manual_annotated: manual_labels[:, 10],
                            self.standard_normal: batch_z,
                        })

                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f, nll: %.8f, kl: %.8f, supervised_loss: %.4f"
                      % (epoch, idx, num_batches_train, time.time() - start_time, loss, nll_loss, kl_loss,
                         supervised_loss))
                counter += 1

                if np.mod(idx, self.eval_interval) == 0:
                    # NOTE(review): epoch + 1 / idx - 1 are the values used to
                    # name the evaluation output directory; confirm the offset
                    # is intended.
                    self.evaluate(epoch + 1, idx - 1, counter - 1,
                                  val_data_iterator=train_val_data_iterator)
                # The training summary is logged on every step, after the
                # optional evaluation.
                self.writer.add_summary(summary_str, counter - 1)

            # A non-zero start_batch_id only applies to the first epoch after
            # loading a pre-trained model; later epochs start at batch 0.
            start_batch_id = 0

            print("Saving check point", self.checkpoint_dir)
            self.save(self.checkpoint_dir, counter)
            train_val_data_iterator.reset_counter("train")

        # Save once more after the final epoch.
        self.save(self.checkpoint_dir, counter)
Пример #2
0
    def evaluate(self, epoch, step, counter, val_data_iterator):
        """Run one pass over the validation split and save reconstructions.

        For every validation batch, this method:

        * writes the batch labels to ``label_test_<batch>.csv`` under
          ``self.result_dir``, as the one-hot columns plus an integer
          ``label`` column;
        * runs ``self.out`` and the merged summary;
        * logs the summary through ``self.writer_v`` at step ``counter``.

        Afterwards it saves one image grid per batch under
        ``get_eval_result_dir(self.result_dir, epoch, step)`` and resets the
        iterator's ``"val"`` counter.

        Args:
            epoch: Epoch number used in log messages and the output directory.
            step: Step number used in log messages and the output directory.
            counter: Global step at which validation summaries are written.
            val_data_iterator: Iterator providing the ``"val"`` split.
        """
        print("Running evaluation after epoch:{:02d} and step:{:04d} ".format(
            epoch, step))
        num_eval_batches = val_data_iterator.get_num_samples(
            "val") // self.batch_size

        # Layout of the saved image grid; the same for every batch.
        manifold_w = 4
        tot_num_samples = min(self.sample_num, self.batch_size)
        manifold_h = tot_num_samples // manifold_w

        reconstructed_images = []
        for _idx in range(num_eval_batches):
            batch_eval_images, batch_eval_labels, manual_labels = \
                val_data_iterator.get_next_batch("val")

            # Append the integer class (position of the 1 in each one-hot row)
            # as an extra column. reshape(-1, 1) works for any batch size
            # instead of assuming 64 rows.
            integer_label = np.asarray([
                np.where(r == 1)[0][0] for r in batch_eval_labels
            ]).reshape(-1, 1)
            batch_eval_labels = np.concatenate(
                [batch_eval_labels, integer_label], axis=1)
            # One column name per one-hot class, then the integer label.
            columns = [str(i) for i in range(batch_eval_labels.shape[1] - 1)]
            columns.append("label")
            pd.DataFrame(batch_eval_labels, columns=columns).to_csv(
                self.result_dir + "label_test_{:02d}.csv".format(_idx),
                index=False)

            batch_z = prior.gaussian(self.batch_size, self.z_dim)
            reconstructed_image, summary = self.sess.run(
                [self.out, self.merged_summary_op],
                feed_dict={
                    self.inputs: batch_eval_images,
                    self.labels: manual_labels[:, :10],
                    self.is_manual_annotated: manual_labels[:, 10],
                    self.standard_normal: batch_z
                })

            self.writer_v.add_summary(summary, counter)

            # Keep only as many images as fit the grid.
            reconstructed_images.append(
                reconstructed_image[:manifold_h * manifold_w, :, :, :])

        print("epoch:{} step:{}".format(epoch, step))
        reconstructed_dir = get_eval_result_dir(self.result_dir, epoch, step)
        print(reconstructed_dir)

        for _idx, image_grid in enumerate(reconstructed_images):
            file = "im_" + str(_idx) + ".png"
            save_image(image_grid, [manifold_h, manifold_w],
                       reconstructed_dir + file)
        val_data_iterator.reset_counter("val")

        print("Evaluation completed")
Пример #3
0
    def load_from_checkpoint(self):
        """Re-initialise all variables, then try to restore a checkpoint.

        Also resets ``self.sample_z`` to a fresh Gaussian latent batch.

        Returns:
            The counter returned by ``self.load(self.checkpoint_dir)``. It is
            returned whether or not the restore succeeded.
        """
        # Start from freshly initialised variables before attempting restore.
        init_op = tf.global_variables_initializer()
        init_op.run()

        # Latent batch used for visualising training results.
        self.sample_z = prior.gaussian(self.batch_size, self.z_dim)

        loaded, restored_counter = self.load(self.checkpoint_dir)
        print(" [*] Load SUCCESS" if loaded else " [!] Load failed...")
        return restored_counter
Пример #4
0
    def __init__(self,
                 sess,
                 epoch,
                 batch_size,
                 z_dim,
                 dataset_name,
                 beta=5,
                 num_units_in_layer=None,
                 log_dir=None,
                 checkpoint_dir=None,
                 result_dir=None,
                 train_val_data_iterator=None,
                 read_from_existing_checkpoint=True,
                 check_point_epochs=None,
                 supervise_weight=0):
        """Configure the model, build its graph and initialise training state.

        Args:
            sess: Session used to run the graph (stored as ``self.sess``).
            epoch: Total number of epochs; used as the exclusive upper bound
                of the epoch loop when training.
            batch_size: Batch size. It also fixes the first dimension of the
                ``mu`` / ``sigma`` placeholders.
            z_dim: Dimension of the latent (noise) vector.
            dataset_name: Only ``'mnist'`` and ``'fashion-mnist'`` are
                supported.
            beta: Stored as ``self.beta``. Presumably the weight of the KL
                term; confirm in ``_build_model``.
            num_units_in_layer: Layer widths. ``None`` or an empty sequence
                selects the default ``[64, 128, 32, z_dim * 2]``.
            log_dir: Stored as ``self.log_dir``.
            checkpoint_dir: Stored as ``self.checkpoint_dir``.
            result_dir: Stored as ``self.result_dir``.
            train_val_data_iterator: Forwarded to ``self.initialize``.
            read_from_existing_checkpoint: Forwarded to ``self.initialize``.
            check_point_epochs: Forwarded to ``self.initialize``.
            supervise_weight: Stored as ``self.supervise_weight``. Presumably
                the weight of the supervised loss; confirm in
                ``_build_model``.

        Raises:
            NotImplementedError: If ``dataset_name`` is not supported.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.epoch = epoch
        self.batch_size = batch_size
        self.num_val_samples = 128
        self.log_dir = log_dir
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.beta = beta
        self.supervise_weight = supervise_weight
        if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
            # parameters
            self.label_dim = 10  # one hot encoding for 10 classes
            self.input_height = 28
            self.input_width = 28
            self.output_height = 28
            self.output_width = 28

            self.z_dim = z_dim  # dimension of noise-vector
            self.c_dim = 1  # single-channel images
            # Fall back to the default layer widths when none are given.
            if num_units_in_layer is None or len(num_units_in_layer) == 0:
                self.n = [64, 128, 32, z_dim * 2]
            else:
                self.n = num_units_in_layer
            # train
            self.learning_rate = 0.0002
            self.beta1 = 0.5  # presumably Adam's beta1; confirm in _build_model

            # test
            self.sample_num = 64  # number of generated images to be saved
            self.num_images_per_row = 4  # should be a factor of sample_num
            self.eval_interval = 300  # run evaluation every N training batches
            # self.num_eval_batches = 10
        else:
            raise NotImplementedError(
                "Dataset {} not implemented".format(dataset_name))
        # Placeholders of shape (batch_size, z_dim) named 'mu' and 'sigma'.
        self.mu = tf.placeholder(tf.float32, [self.batch_size, self.z_dim],
                                 name='mu')
        self.sigma = tf.placeholder(tf.float32, [self.batch_size, self.z_dim],
                                    name='sigma')
        self.images = None
        # The graph must exist before its variables can be initialised below.
        self._build_model()
        # initialize all variables
        # NOTE(review): .run() uses the default session; presumably `sess` is
        # the default here. Confirm.
        tf.global_variables_initializer().run()
        # graph inputs for visualize training results
        self.sample_z = prior.gaussian(self.batch_size, self.z_dim)
        self.max_to_keep = 20  # presumably max checkpoints kept by the saver

        # Starting global step, batch index and epoch for train().
        self.counter, self.start_batch_id, self.start_epoch = self.initialize(
            train_val_data_iterator, read_from_existing_checkpoint,
            check_point_epochs)