Example #1
def _main_inspect_checkpoint():
    from Training.Saver import Saver
    save_dir = os.path.join(os.getcwd(), "weight")
    saver = Saver(save_dir)
    _, filename, _ = saver._findfilename()
    fn_inspect_checkpoint(filename)
    fn_inspect_checkpoint(filename, tensor_name='conv2d/kernel')
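
fn_inspect_checkpoint is not defined in this snippet. A minimal sketch of how such a helper might wrap TensorFlow 1.x's checkpoint inspection utility (the wrapper name and defaults here are assumptions):

# Hypothetical sketch: wraps the checkpoint inspection tool shipped with TF 1.x.
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

def fn_inspect_checkpoint(filename, tensor_name=None):
    # With no tensor_name, list every tensor stored in the checkpoint;
    # otherwise print only the requested tensor.
    print_tensors_in_checkpoint_file(file_name=filename,
                                     tensor_name=tensor_name,
                                     all_tensors=(tensor_name is None))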
Example #2
    def test(self, Model, DataSet, filename_list):
        # Reset tf graph.
        tf.reset_default_graph()
        # Create input node
        image_batch, label_batch, init_op = self._input_fn_w_label(
            DataSet, filename_list)

        # Build up the graph and loss
        with tf.device('/gpu:0'):
            # Sample the generated image
            model = Model(self.config)
            logits = model.forward_pass(image_batch)
            accuracy, update_op, reset_op = self._metric(logits, label_batch)

        # Add saver
        saver = Saver(self.save_dir)
        # Create Session
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        # allow_soft_placement lets ops that cannot run on the GPU fall back to the CPU
        with tf.Session(config=sess_config) as sess:
            assert self.config.RESTORE, "RESTORE must be true for the sampler mode!"

            _ = saver.restore(sess, self.config.RUN, self.config.RESTORE_EPOCH)

            # Start Sampling
            tf.logging.info("Start inference!")
            sess.run([init_op, reset_op])
            while True:
                try:
                    accuracy_o, _ = sess.run([accuracy] + update_op)
                except (tf.errors.InvalidArgumentError,
                        tf.errors.OutOfRangeError):
                    break
        print("Validation accuracy: %.8f" % (accuracy_o))
        return
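
self._metric is defined elsewhere in the project. A minimal sketch of a streaming-accuracy helper with the same (accuracy, update_op, reset_op) contract, assuming integer class labels and `import tensorflow as tf` (names and scoping are assumptions):

    # Hypothetical sketch of a _metric helper built on tf.metrics.accuracy.
    def _metric(self, logits, labels):
        with tf.variable_scope("metrics") as scope:
            predictions = tf.argmax(logits, axis=-1)
            accuracy, update_op = tf.metrics.accuracy(labels=labels,
                                                      predictions=predictions)
            # tf.metrics keeps its counters in local variables; re-initializing
            # them resets the running accuracy between evaluation runs.
            metric_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                            scope=scope.name)
            reset_op = tf.variables_initializer(metric_vars)
        # update_op is returned as a list so callers can run [accuracy] + update_op.
        return accuracy, [update_op], reset_op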
Example #3
    def deploy_model(self, model, dir_names=None, epoch=None):
        # Reset the tensorflow graph
        tf.reset_default_graph()

        # Build up the graph
        with tf.device('/gpu:0'):
            input_feature, logits, preds, probs, main_graph \
                = self._build_inference_graph(model)

        # Add saver
        saver = Saver(self.save_dir)

        # Prepare the input array
        self._input_processing()

        # Create a session
        with tf.Session() as sess:
            # restore the weights
            _ = saver.restore(sess, dir_names=dir_names, epoch=epoch)
            # initialize the uninitialized variables
            initialize_uninitialized_vars(sess)
            logits_o, preds_o, probs_o = sess.run([logits, preds, probs], \
                                                  feed_dict={input_feature: self.input_array,
                                                             main_graph.is_training: False})
            saver.save(sess, 'model_final' + '.ckpt')
            # Note: this saves the model in the original folder; you need to manually move
            # the file to Deploy/FinalModel
            # TODO: save the model directly in Deploy/FinalModel
        return logits_o, probs_o, preds_o
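
initialize_uninitialized_vars comes from elsewhere in the project. A minimal sketch of the usual pattern it implements, initializing only the variables a restored checkpoint did not cover (assumes `import tensorflow as tf`):

# Hypothetical sketch: initialize only variables that are still uninitialized.
def initialize_uninitialized_vars(sess):
    global_vars = tf.global_variables()
    # Ask the session which variables already hold values (e.g. after saver.restore).
    init_flags = sess.run([tf.is_variable_initialized(v) for v in global_vars])
    not_initialized = [v for v, flag in zip(global_vars, init_flags) if not flag]
    if not_initialized:
        sess.run(tf.variables_initializer(not_initialized))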
Example #4
    def evaler(self, model, dir_names=None, epoch=None):
        # Reset the tensorflow graph
        tf.reset_default_graph()
        # Input node
        init_val_op, val_input, val_lab = self._input_fn_eval()
        val_lab = tf.argmax(val_lab, axis=-1)
        # Build up the graph
        with tf.device('/gpu:0'):
            logits, preds, probs, accuracy, update_op_a, roc_auc, update_op_roc,\
            pr_auc, update_op_pr, main_graph\
              = self._build_test_graph(val_input, val_lab, model)

        # Add saver
        saver = Saver(self.save_dir)

        # List to store the results
        Val_lab = []
        Preds = []
        Logits = []
        Probs = []

        # Create a session
        with tf.Session() as sess:
            # restore the weights
            _ = saver.restore(sess, dir_names=dir_names, epoch=epoch)
            # initialize the uninitialized variables
            initialize_uninitialized_vars(sess)
            # initialize the dataset iterator
            sess.run(init_val_op)
            # start evaluation
            count = 1
            while True:
                try:
                    val_lab_o, logits_o, preds_o, probs_o, accuracy_o, roc_auc_o, pr_auc_o,\
                    _, _, _ = \
                      sess.run([val_lab, logits, preds, probs, accuracy, \
                                roc_auc, pr_auc, update_op_a, update_op_roc, update_op_pr], \
                                feed_dict={main_graph.is_training: False})
                    # store results
                    Val_lab.append(val_lab_o)
                    Preds.append(preds_o)
                    Logits.append(logits_o)
                    Probs.append(probs_o[:, -1])
                    tf.logging.debug("The current validation sample batch num is {}."\
                                        .format(count))
                    count += 1
                except (tf.errors.InvalidArgumentError,
                        tf.errors.OutOfRangeError):
                    # print out the evaluation results
                    tf.logging.info("The validation results are: accuracy {:.2f}; roc_auc {:.2f}; pr_auc {:.2f}."\
                                        .format(accuracy_o, roc_auc_o, pr_auc_o))
                    break
        return Val_lab, Preds, Logits, Probs, accuracy_o, roc_auc_o, pr_auc_o
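
A brief usage note, not part of the example: the per-batch lists returned by evaler can be stacked with NumPy before further analysis. The caller object below is hypothetical.

import numpy as np

# `tester` stands in for whatever object defines evaler.
val_lab, preds, logits, probs, acc, roc_auc, pr_auc = tester.evaler(model)
val_lab = np.concatenate(val_lab, axis=0)   # flatten per-batch labels
preds = np.concatenate(preds, axis=0)       # flatten per-batch predictions
probs = np.concatenate(probs, axis=0)       # flatten positive-class probabilities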
Example #5
    def evaler(self, model, dir_names=None, epoch=None):
        # Reset the tensorflow graph
        tf.reset_default_graph()
        # Input node
        init_val_op, val_input, val_lab = self._input_fn_eval()
        val_lab = tf.argmax(val_lab, axis=-1)
        # Build up the graph
        with tf.device('/gpu:0'):
            output_list, main_graph = self._build_test_graph(val_input, val_lab, model)

        # Add saver
        saver = Saver(self.save_dir)
        # List to store the results
        Out = []

        # Create a session
        with tf.Session() as sess:
            # restore the weights
            _ = saver.restore(sess, dir_names=dir_names, epoch=epoch)
            # initialize the uninitialized variables
            initialize_uninitialized_vars(sess)
            # initialize the dataset iterator
            sess.run(init_val_op)
            # start evaluation
            count = 1
            while True:
                try:
                    out = \
                        sess.run(output_list,\
                                 feed_dict={main_graph.is_training: False})
                    # store results
                    Out.append(out)
                    tf.logging.debug("The current validation sample batch num is {}." \
                                     .format(count))
                    count += 1
                except (tf.errors.InvalidArgumentError, tf.errors.OutOfRangeError):
                    # print out the evaluation results
                    tf.logging.info("The validation results are: accuracy {:.2f}; roc_auc {:.2f}; pr_auc {:.2f}." \
                                    .format(out))
                    break
        return Out
Example #6
    def evaler(self, Model, dir_names=None, epoch=None):
        # Reset the tensorflow graph
        tf.reset_default_graph()
        # Input node
        im_batch, lm_batch, init_op = self._input_fn()

        # Build up the train graph
        with tf.name_scope("Test"):
            model = Model(self.config)
            output_list = model.forward(im_batch, lm_batch)

        # Add saver
        saver = Saver(self.save_dir)
        # List to store the results
        Out = []

        # Create a session
        with tf.Session() as sess:
            # restore the weights
            _ = saver.restore(sess, dir_names=dir_names, epoch=epoch)
            # initialize the uninitialized variables
            initialize_uninitialized_vars(sess)
            # initialize the dataset iterator
            sess.run(init_op)
            # start evaluation
            count = 1
            while True:
                try:
                    out = \
                        sess.run(output_list)
                    # store results
                    Out.append(out)
                    tf.logging.debug("The current validation sample batch num is {}." \
                                     .format(count))
                    count += 1
                except (tf.errors.InvalidArgumentError, tf.errors.OutOfRangeError):
                    break
        return Out
Example #7
    def interpol(self,
                 Model,
                 im_batch_i,
                 lm_batch_i,
                 dir_names=None,
                 epoch=None):
        # Reset the tensorflow graph
        tf.reset_default_graph()
        # Input node
        im_latent_batch = tf.placeholder(tf.float32,
                                         shape=(None, 50),
                                         name="apperance_latent")
        lm_latent_batch = tf.placeholder(tf.float32,
                                         shape=(None, 10),
                                         name="landmark_latent")

        # Build up the train graph
        with tf.name_scope("Test"):
            model = Model(self.config)
            output_list = model.forward_latent(im_latent_batch,
                                               lm_latent_batch)

        # Add saver
        saver = Saver(self.save_dir)
        # List to store the results

        # Create a session
        with tf.Session() as sess:
            # restore the weights
            _ = saver.restore(sess, dir_names=dir_names, epoch=epoch)
            # initialize the uninitialized variables
            initialize_uninitialized_vars(sess)
            out = sess.run(output_list,
                           feed_dict={
                               im_latent_batch: im_batch_i,
                               lm_latent_batch: lm_batch_i
                           })
        return out
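
A hedged usage sketch for interpol: linearly interpolate between two latent codes and decode the intermediate points with the restored model. The endpoint arrays, caller object, run name, and epoch below are assumptions.

import numpy as np

steps = 10
alphas = np.linspace(0.0, 1.0, steps)[:, None]
# im_latent_a/b have shape (50,), lm_latent_a/b have shape (10,) -- hypothetical endpoints.
im_latents = (1.0 - alphas) * im_latent_a + alphas * im_latent_b   # (steps, 50)
lm_latents = (1.0 - alphas) * lm_latent_a + alphas * lm_latent_b   # (steps, 10)
out = tester.interpol(Model, im_latents, lm_latents, dir_names="run_0", epoch=20)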
Example #8
    def train(self, Model):
        tf.reset_default_graph()
        # Input node
        im_batch, lm_batch, init_op = self._input_fn()

        # Build up the train graph
        with tf.name_scope("Train"):
            model = Model(self.config)
            out_im_batch, _, out_lm_batch, _ = model.forward(im_batch, lm_batch)
            im_loss, lm_loss = self._loss(im_batch, out_im_batch, lm_batch, out_lm_batch)
            optimizer = self._Adam_optimizer()
            im_solver, im_grads = self._train_op_w_grads(optimizer, im_loss)
            lm_solver, lm_grads = self._train_op_w_grads(optimizer, lm_loss)

        # Add summary
        if self.config.SUMMARY:
            summary_dict_train = {}
            if self.config.SUMMARY_SCALAR:
                scalar_train = {'appearance_loss': im_loss, \
                                'landmarks_loss': lm_loss}
                summary_dict_train['scalar'] = scalar_train

            # Merge summary
            merged_summary_train = \
                self.summary_train.add_summary(summary_dict_train)

        # Add saver
        saver = Saver(self.save_dir)

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        # allow_soft_placement lets ops that cannot run on the GPU fall back to the CPU
        # Create Session
        with tf.Session(config=sess_config) as sess:
            # Add graph to tensorboard
            if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
                self.summary_train._graph_summary(sess.graph)

            # Restore the weights from the previous training
            if self.config.RESTORE:
                start_epoch = saver.restore(sess)
            else:
                # Create a new folder for saving model
                saver.set_save_path(comments=self.comments)
                start_epoch = 0

            # initialize the variables
            init_var = tf.group(tf.global_variables_initializer(), \
                                tf.local_variables_initializer())
            sess.run(init_var)

            # Start Training
            tf.logging.info("Start training!")
            for epoch in range(1, self.config.EPOCHS + 1):
                tf.logging.info("Training for epoch {}.".format(epoch))
                train_pr_bar = tf.contrib.keras.utils.Progbar(target= \
                                                                  int(800 / self.config.BATCH_SIZE))
                sess.run(init_op)
                batch = 0
                while True:
                    try:
                        im_loss_o, lm_loss_o, summary_o, _, _ = sess.run([im_loss, lm_loss, merged_summary_train, \
                                                                          im_solver, lm_solver])
                        batch += 1
                        train_pr_bar.update(batch)

                        if self.config.SUMMARY:
                            # Add summary
                            self.summary_train.summary_writer.add_summary(summary_o, epoch + start_epoch)

                    except (tf.errors.InvalidArgumentError, tf.errors.OutOfRangeError):
                        break

                tf.logging.info(
                    "\nThe current epoch {}, appearance loss is {:.2f}, landmark loss is {:.2f}.\n" \
                    .format(epoch, im_loss_o, lm_loss_o))
                # Save the model per SAVE_PER_EPOCH
                if epoch % self.config.SAVE_PER_EPOCH == 0:
                    save_name = str(epoch + start_epoch)
                    saver.save(sess, 'model_' + save_name.zfill(4) \
                               + '.ckpt')

            if self.config.SUMMARY:
                self.summary_train.summary_writer.flush()
                self.summary_train.summary_writer.close()

            # Save the model after all epochs
            save_name = str(epoch + start_epoch)
            saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')
        return
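
The optimizer helpers are defined elsewhere. A minimal sketch of what self._Adam_optimizer and self._train_op_w_grads might look like with the TF 1.x API; the hyperparameters are assumptions and would normally come from self.config.

    # Hypothetical sketches of the optimizer helpers used above.
    def _Adam_optimizer(self):
        return tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5)

    def _train_op_w_grads(self, optimizer, loss, var_list=None):
        # Return both the train op and the gradients so they can be logged or summarized.
        grads_and_vars = optimizer.compute_gradients(loss, var_list=var_list)
        train_op = optimizer.apply_gradients(grads_and_vars)
        return train_op, grads_and_vars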
Example #9
    def train(self, model):
        # Reset tf graph.
        tf.reset_default_graph()

        # Create input node
        init_op_train, init_op_val, real_lab_input,\
               real_lab, real_unl_input, dataset_train = self._input_fn_train_val()

        # Build up the graph
        with tf.device('/gpu:0'):
            d_loss, g_loss, accuracy, roc_auc, pr_auc, \
            update_op, reset_op, preds, probs, main_graph, scalar_train_sum_dict\
                                            = self._build_train_graph(real_lab_input, \
                                            real_unl_input, real_lab, model)
        # Create optimizer
        with tf.name_scope('Train'):
            theta_G = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        scope='Generator')
            theta_D = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        scope='Discriminator')
            optimizer = self._Adam_Optimizer()
            d_solver, d_grads = self._train_op_w_grads(optimizer, d_loss,\
                                                  var_list = theta_D)
            g_solver, g_grads = self._train_op_w_grads(optimizer, g_loss,\
                                                  var_list = theta_G)

        # Print out the variable name in debug mode
        tf.logging.debug(utils.variable_name_string())

        # Add summary
        if self.config.SUMMARY:
            if self.config.SUMMARY_TRAIN_VAL:
                summary_dict_train = {}
                summary_dict_val = {}
                if self.config.SUMMARY_SCALAR:
                    scalar_train = {'generator_loss': g_loss, \
                                    'discriminator_loss': d_loss}
                    scalar_train.update(scalar_train_sum_dict)
                    scalar_val = {'val_accuracy': accuracy, \
                                  'val_pr_auc': pr_auc, \
                                  'val_roc_auc': roc_auc}

                    summary_dict_train['scalar'] = scalar_train
                    summary_dict_val['scalar'] = scalar_val

                if self.config.SUMMARY_HISTOGRAM:
                    ## TODO: add any vectors that you want to visualize.
                    pass
                # Merge summary
                merged_summary_train = \
                  self.summary_train.add_summary(summary_dict_train)
                merged_summary_val = \
                  self.summary_val.add_summary(summary_dict_val)

        # Add saver
        saver = Saver(self.save_dir)

        # Whether to use a pre-trained weights
        if self.config.PRE_TRAIN:
            pre_train_saver = tf.train.Saver(theta_D)

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        # allow_soft_placement lets ops that cannot run on the GPU fall back to the CPU

        # Create Session
        with tf.Session(config=sess_config) as sess:
            # Add graph to tensorboard
            if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
                if self.config.SUMMARY_TRAIN_VAL:
                    self.summary_train._graph_summary(sess.graph)

            # Restore the weights from the previous training
            if self.config.RESTORE:
                start_epoch = saver.restore(sess)
            else:
                # Create a new folder for saving model
                saver.set_save_path(comments=self.comments)
                start_epoch = 0
                if self.config.PRE_TRAIN:
                    # Restore the pre-trained weights for D
                    pre_train_saver.restore(sess,
                                            self.config.PRE_TRAIN_FILE_PATH)
                    initialize_uninitialized_vars(sess)
                else:
                    # initialize the variables
                    init_var = tf.group(tf.global_variables_initializer(), \
                                        tf.local_variables_initializer())
                    sess.run(init_var)

            # Start Training
            tf.logging.info("Start training!")
            for epoch in range(1, self.config.EPOCHS + 1):
                tf.logging.info("Training for epoch {}.".format(epoch))
                train_pr_bar = tf.contrib.keras.utils.Progbar(target = \
                                  int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE))
                sess.run(init_op_train,
                         feed_dict={dataset_train._is_training_input: True})

                for i in range(
                        int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE)):
                    # Update discriminator
                    d_loss_o, _, summary_out  = sess.run([d_loss, d_solver, \
                                             merged_summary_train], \
                                             feed_dict={main_graph.is_training: True, \
                                             dataset_train._is_training_input: True})
                    # Update generator
                    g_loss_o, _, summary_out = sess.run([g_loss, g_solver, \
                                              merged_summary_train], \
                                              feed_dict={main_graph.is_training: True, \
                                              dataset_train._is_training_input: True})

                    tf.logging.debug("Training for batch {}.".format(i))
                    # Update progress bar
                    train_pr_bar.update(i)

                if self.config.SUMMARY_TRAIN_VAL:
                    # Add summary
                    self.summary_train.summary_writer.add_summary(
                        summary_out, epoch + start_epoch)

                # Perform validation
                tf.logging.info("\nValidate for epoch {}.".format(epoch))
                sess.run(init_op_val + [reset_op], \
                         feed_dict = {dataset_train._is_training_input: False})
                count = 1
                val_pr_bar = tf.contrib.keras.utils.Progbar(target = \
                                              int(self.config.VAL_SIZE / self.config.BATCH_SIZE))

                for i in range(
                        int(self.config.VAL_SIZE / self.config.BATCH_SIZE)):
                    try:
                        accuracy_o, summary_out, roc_auc_o, \
                        pr_auc_o, val_lab_o, preds_o, \
                        probs_o, _, _, _ = sess.run([accuracy, merged_summary_val,\
                                                     roc_auc, pr_auc, real_lab, \
                                                     preds, probs] + update_op,\
                                                     feed_dict={main_graph.is_training: False,\
                                                     dataset_train._is_training_input: False})

                        tf.logging.debug(
                            "Validate for batch {}.".format(count))
                        # Update progress bar
                        val_pr_bar.update(count)
                        count += 1
                    except (tf.errors.InvalidArgumentError,
                            tf.errors.OutOfRangeError):
                        break

                tf.logging.info("\nThe current validation accuracy for epoch {} is {:.2f}, roc_auc is {:.2f}, pr_auc is {:.2f}.\n"\
                                  .format(epoch, accuracy_o, roc_auc_o, pr_auc_o))

                # Add summary to tensorboard
                if self.config.SUMMARY_TRAIN_VAL:
                    self.summary_val.summary_writer.add_summary(
                        summary_out, epoch + start_epoch)

                # Save the model per SAVE_PER_EPOCH
                if epoch % self.config.SAVE_PER_EPOCH == 0:
                    save_name = str(epoch + start_epoch)
                    saver.save(sess, 'model_' + save_name.zfill(4) \
                                         + '.ckpt')

            if self.config.SUMMARY_TRAIN_VAL:
                self.summary_train.summary_writer.flush()
                self.summary_train.summary_writer.close()
                self.summary_val.summary_writer.flush()
                self.summary_val.summary_writer.close()

            # Save the model after all epochs
            save_name = str(epoch + start_epoch)
            saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

        return
Example #10
def _main_inspect_checkpoint(save_dir):
    from Training.Saver import Saver
    saver = Saver(save_dir)
    _, filename, _ = saver._findfilename()
    fn_inspect_checkpoint(filename)
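
A hedged usage example: point the helper at the directory holding the saved weights (the "weight" folder mirrors Example #1).

if __name__ == "__main__":
    import os
    _main_inspect_checkpoint(os.path.join(os.getcwd(), "weight"))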
Example #11
    def train(self, Model, DataSet, SAMPLE_X=None, SAMPLE_Y=None):
        # Reset tf graph.
        tf.reset_default_graph()

        # Create input node
        if not self.config.Y_LABLE:
            image_batch, init_op, dataset = self._input_fn(DataSet)
        else:
            image_batch, label_batch, init_op, dataset = self._input_fn_w_label(
                DataSet)
            # image, label = self._input_fn_NP()  # using numpy array as feed dict

        # Build up the graph and loss
        with tf.device('/gpu:0'):
            # Create placeholder
            if self.config.Y_LABLE:
                y = tf.placeholder(
                    tf.float32,
                    [self.config.BATCH_SIZE, self.config.NUM_CLASSES],
                    name='y')  # label batch
                x = tf.placeholder(tf.float32, [self.config.BATCH_SIZE] +
                                   self.config.IMAGE_DIM,
                                   name='real_images')  # real image
            else:
                y = None
                x = tf.placeholder(tf.float32, [self.config.BATCH_SIZE] + [
                    self.config.IMAGE_HEIGHT_O, self.config.IMAGE_WIDTH_O,
                    self.config.CHANNEL
                ],
                                   name='real_images')  # real image

            z = tf.placeholder(
                tf.float32,
                [self.config.BATCH_SIZE, self.config.Z_DIM])  # latent variable

            # Build up the graph

            G, D, D_logits, D_, D_logits_, fm, fm_, model = self._build_train_graph(
                x, y, z, Model)
            # # Create the loss:
            # d_loss, g_loss = self._loss(D, D_logits, D_, D_logits_, x, G, fm, fm_, model.discriminator, y)

            samples = model.sampler(z, y)

        # Create optimizer
        with tf.name_scope('Train'):
            t_vars = tf.trainable_variables()
            # theta_G = [var for var in t_vars if 'g_' in var.name]
            # theta_D = [var for var in t_vars if 'd_' in var.name]
            theta_G = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        scope="generator")
            theta_D = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        scope="discriminator")
            optimizer = self._Adam_optimizer()

            # Discriminator loss
            if self.config.LABEL_SMOOTH:
                d_loss_real = self._sigmoid_cross_entopy_w_logits(
                    0.9 * tf.ones_like(D), D_logits)
            else:
                d_loss_real = self._sigmoid_cross_entopy_w_logits(
                    tf.ones_like(D), D_logits)
            d_loss_fake = self._sigmoid_cross_entopy_w_logits(
                tf.zeros_like(D_), D_logits_)
            d_loss = d_loss_fake + d_loss_real

            d_updates = tf.keras.optimizers.Adam(
                lr=1e-4,
                beta_1=self.config.BETA1).get_updates(d_loss, theta_D)
            d_optim = tf.group(*d_updates, name="d_train_op")

            if self.config.UNROLLED_STEP > 0:
                update_dict = self._extract_update_dict(d_updates)
                cur_update_dict = update_dict
                for i in range(self.config.UNROLLED_STEP - 1):
                    cur_update_dict = self._graph_replace(
                        update_dict, cur_update_dict)
                g_loss = -self._graph_replace(d_loss, cur_update_dict)
            else:
                g_loss = -d_loss

            g_optim = self._train_op(optimizer, g_loss, theta_G)

        # Add summary
        if self.config.SUMMARY:
            summary_dict = {}
            if self.config.SUMMARY_SCALAR:
                scaler = {
                    'generator_loss': g_loss,
                    'discriminator_loss': d_loss
                }
                summary_dict['scalar'] = scaler

            merged_summary = self.summary.add_summary(summary_dict)

        # Add saver
        saver = Saver(self.save_dir)

        # Create Session
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        # allow_soft_placement lets ops that cannot run on the GPU fall back to the CPU
        with tf.Session(config=sess_config) as sess:
            # if self.config.DEBUG:
            #     sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
                self.summary._graph_summary(sess.graph)

            if self.config.RESTORE:
                start_epoch = saver.restore(sess)
            else:
                saver.set_save_path(comments=self.comments)
                start_epoch = 0
                # initialize the variables
                init_var = tf.group(tf.global_variables_initializer(), \
                                    tf.local_variables_initializer())
                sess.run(init_var)

            sample_z = np.random.normal(size=(self.config.BATCH_SIZE, 100))
            if not self.config.Y_LABLE:
                ## TODO: support PacGAN for conditional case
                sess.run(init_op)
                sample_x = sess.run(image_batch)
            else:
                # sample_x, sample_y = sess.run([image_batch, label_batch])
                sample_x, sample_y = SAMPLE_X, SAMPLE_Y
                # sample_x, sample_y = image[:64, ...], label[:64, ...] # for numpy input

            # Start Training
            tf.logging.info("Start unrolledGAN traininig!")
            for epoch in range(start_epoch + 1,
                               self.config.EPOCHS + start_epoch + 1):
                tf.logging.info("Training for epoch {}.".format(epoch))
                train_pr_bar = tf.contrib.keras.utils.Progbar(target= \
                                                                  int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE))
                sess.run(init_op)
                for i in range(
                        int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE)):
                    batch_z = np.random.normal(
                        size=(self.config.BATCH_SIZE,
                              self.config.Z_DIM)).astype(np.float32)
                    # Fetch a data batch
                    if not self.config.Y_LABLE:
                        image_batch_o = sess.run(image_batch)
                    else:
                        image_batch_o, label_batch_o = sess.run(
                            [image_batch, label_batch])

                    ## for numpy input
                    # image_batch_o, label_batch_o = image[i * self.config.BATCH_SIZE : (i + 1) * self.config.BATCH_SIZE], \
                    #                                label[i * self.config.BATCH_SIZE : (i + 1) * self.config.BATCH_SIZE]

                    if not self.config.Y_LABLE:
                        # Update discriminator
                        _, d_loss_o = sess.run([d_optim, d_loss],
                                               feed_dict={
                                                   x: image_batch_o,
                                                   z: batch_z
                                               })
                        # Update generator
                        _ = sess.run([g_optim],
                                     feed_dict={
                                         x: image_batch_o,
                                         z: batch_z
                                     })
                        _, g_loss_o = sess.run([g_optim, g_loss],
                                               feed_dict={
                                                   x: image_batch_o,
                                                   z: batch_z
                                               })
                    else:
                        # Update discriminator
                        _, d_loss_o = sess.run([d_optim, d_loss],
                                               feed_dict={
                                                   x: image_batch_o,
                                                   y: label_batch_o,
                                                   z: batch_z
                                               })

                        # Update generator
                        _ = sess.run([g_optim],
                                     feed_dict={
                                         x: image_batch_o,
                                         y: label_batch_o,
                                         z: batch_z
                                     })
                        _, g_loss_o = sess.run([g_optim, g_loss],
                                               feed_dict={
                                                   x: image_batch_o,
                                                   y: label_batch_o,
                                                   z: batch_z
                                               })

                    # Update progress bar
                    train_pr_bar.update(i)
                    if i % 100 == 0:
                        if self.config.DEBUG:
                            ## Sample image for every 100 update in debug mode
                            if not self.config.Y_LABLE:
                                samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                                    [samples, d_loss, g_loss, merged_summary],
                                    feed_dict={
                                        x: sample_x,
                                        z: sample_z
                                    })
                            else:
                                samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                                    [samples, d_loss, g_loss, merged_summary],
                                    feed_dict={
                                        x: sample_x,
                                        y: sample_y,
                                        z: sample_z
                                    })

                            save_images(samples_o, image_manifold_size(samples_o.shape[0]), \
                                        os.path.join(self.config.SAMPLE_DIR, 'train_{:02d}_{:02d}.png'.format(epoch, i)))

                # Save the model per SAVE_PER_EPOCH
                if epoch % self.config.SAVE_PER_EPOCH == 0:
                    save_name = str(epoch)
                    saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

                print("Epoch: [%2d/%2d], d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, self.config.EPOCHS + start_epoch, d_loss_o, g_loss_o))
                ## Sample image after every epoch
                if not self.config.Y_LABLE:
                    samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                        [samples, d_loss, g_loss, merged_summary],
                        feed_dict={
                            x: sample_x,
                            z: sample_z
                        })
                else:
                    samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                        [samples, d_loss, g_loss, merged_summary],
                        feed_dict={
                            x: sample_x,
                            y: sample_y,
                            z: sample_z
                        })

                if self.config.SUMMARY:
                    self.summary.summary_writer.add_summary(summary_o, epoch)

                print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                      (d_loss_o, g_loss_o))
                save_images(samples_o, image_manifold_size(samples_o.shape[0]), \
                            os.path.join(self.config.SAMPLE_DIR, 'train_{:02d}.png'.format(epoch)))

            if self.config.SUMMARY:
                self.summary.summary_writer.flush()
                self.summary.summary_writer.close()

            # Save the model after all epochs
            save_name = str(epoch)
            saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')
            return
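
self._sigmoid_cross_entopy_w_logits (spelled as the example calls it) is not shown. A minimal sketch of the usual reduction over tf.nn.sigmoid_cross_entropy_with_logits:

    # Hypothetical sketch of the cross-entropy helper called above.
    def _sigmoid_cross_entopy_w_logits(self, labels, logits):
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))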
Example #12
    def main_sampler(self, Model, SAMPLE_Y=None):
        # Reset tf graph.
        tf.reset_default_graph()

        # Build up the graph and loss
        with tf.device('/gpu:0'):
            # Create placeholder
            if self.config.Y_LABEL:
                y = tf.placeholder(tf.float32, [self.config.BATCH_SIZE, self.config.NUM_CLASSES], name='y') # label batch
            else:
                y = None

            z = tf.placeholder(tf.float32, [self.config.BATCH_SIZE, self.config.Z_DIM]) # latent variable

            # Sample the generated image
            model = Model(self.config)
            samples = model.sampler(z, y)

        # Add saver
        saver = Saver(self.save_dir)

        # Create Session
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        # allow_soft_placement lets ops that cannot run on the GPU fall back to the CPU
        with tf.Session(config=sess_config) as sess:
            assert self.config.RESTORE, "RESTORE must be true for the sampler mode!"
            _ = saver.restore(sess, self.config.RUN, self.config.RESTORE_EPOCH)
            # Start Sampling
            tf.logging.info("Start sampling!")
            with tf.python_io.TFRecordWriter(os.path.join(self.config.SAMPLE_DIR, \
                                                          self.config.DATA_NAME + "_sampler.tfrecords")) as record_writer:
                for epoch in range(1, self.config.EPOCHS + 1):
                    if self.config.Y_LABEL:
                        sample_y = SAMPLE_Y
                    sample_z = np.random.normal(size=(self.config.BATCH_SIZE, self.config.Z_DIM))
                    sample_pr_bar = tf.contrib.keras.utils.Progbar(target= self.config.EPOCHS)
                    if not self.config.Y_LABEL:
                        samples_o = sess.run(samples,
                                                       feed_dict={z: sample_z})
                        if self.config.DATA_NAME == "mnist":
                            samples_o = samples_o * 255
                        else:
                            samples_o = (samples_o + 1) * 127.5

                        for i in range(samples_o.shape[0]):
                            image = samples_o[i,...].astype(np.uint8)
                            if self.config.DATA_NAME == "prostate":
                                example = tf.train.Example(features=tf.train.Features(
                                    feature={
                                        'image': self._bytes_feature(image.tobytes()),
                                        'label': self._int64_feature(-1), ## -1 stands for no label
                                        'height': self._int64_feature(image.shape[0]),
                                        'width': self._int64_feature(image.shape[1])
                                    }))
                            else:
                                example = tf.train.Example(features=tf.train.Features(
                                    feature={
                                        'image': self._bytes_feature(image.tobytes()),
                                        'label': self._int64_feature(-1),
                                    }))
                            record_writer.write(example.SerializeToString())
                    else:
                        samples_o = sess.run(samples,
                                                           feed_dict = {y: sample_y,
                                                                        z: sample_z})

                        if self.config.DATA_NAME == "mnist":
                            samples_o = samples_o * 255
                        else:
                            samples_o = (samples_o + 1) * 127.5

                        labels = np.argmax(sample_y, axis = 1)
                        for i in range(samples_o.shape[0]):
                            image = samples_o[i,...].astype(np.uint8)
                            label = labels[i]
                            if self.config.DATA_NAME == "prostate":
                                example = tf.train.Example(features=tf.train.Features(
                                    feature={
                                        'image': self._bytes_feature(image.tobytes()),
                                        'label': self._int64_feature(label),
                                        'height': self._int64_feature(image.shape[0]),
                                        'width': self._int64_feature(image.shape[1])
                                    }))
                            elif self.config.DATA_NAME == "mnist":
                                example = tf.train.Example(features=tf.train.Features(
                                    feature={
                                        'image': self._bytes_feature(image.tobytes()),
                                        'label': self._int64_feature(label),
                                    }))
                            record_writer.write(example.SerializeToString())

                    # Update progress bar
                    sample_pr_bar.update(epoch)
            save_images(samples_o[:64], image_manifold_size(64), \
                        os.path.join(self.config.SAMPLE_DIR, 'samples.png'))
            return samples_o, labels
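
The _bytes_feature and _int64_feature helpers are standard TFRecord wrappers; a minimal sketch of what they typically look like:

    # Hypothetical sketches of the tf.train.Feature wrappers used when writing TFRecords.
    def _bytes_feature(self, value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _int64_feature(self, value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))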
Example #13
    def train(self, Model, DataSet):
        # Reset tf graph.
        tf.reset_default_graph()
        image_batch, label_batch, init_op_train, init_op_var = self._input_fn_w_label(
            DataSet)
        # Build up the graph and loss
        with tf.device('/gpu:0'):
            # Build up the graph
            logits = self._build_train_graph(image_batch, Model)
            # Create the loss:
            loss = self._loss(logits, label_batch)
            # Create metric:
            accuracy, update_op, reset_op = self._metric(logits, label_batch)

        # Create optimizer
        with tf.name_scope('Train'):
            optimizer = self._Adam_optimizer()
            optim = self._train_op(optimizer, loss)

        # Add summary
        if self.config.SUMMARY:
            summary_dict = {}
            if self.config.SUMMARY_SCALAR:
                scaler = {'loss': loss, 'accuracy': accuracy}
                summary_dict['scalar'] = scaler

            merged_summary = self.summary.add_summary(summary_dict)

        # Add saver
        saver = Saver(self.save_dir)

        # Create Session
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        # allow_soft_placement lets ops that cannot run on the GPU fall back to the CPU
        with tf.Session(config=sess_config) as sess:
            # if self.config.DEBUG:
            #     sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
                self.summary._graph_summary(sess.graph)

            if self.config.RESTORE:
                start_epoch = saver.restore(sess)
            else:
                saver.set_save_path(comments=self.comments)
                start_epoch = 0
                # initialize the variables
                init_var = tf.group(tf.global_variables_initializer(), \
                                    tf.local_variables_initializer())
                sess.run(init_var)

            # Start Training
            tf.logging.info("Start traininig!")
            for epoch in range(start_epoch + 1,
                               self.config.EPOCHS + start_epoch + 1):
                tf.logging.info("Training for epoch {}.".format(epoch))
                train_pr_bar = tf.contrib.keras.utils.Progbar(target= \
                                                                  int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE))
                sess.run(init_op_train)
                for i in range(
                        int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE)):
                    _, loss_o, accuracy_o, summary_o, _ = sess.run(
                        [optim, loss, accuracy, merged_summary] + update_op)
                    # Update progress bar
                    train_pr_bar.update(i)
                print("Epoch: [%2d/%2d], training loss: %.8f, training accuracy: %.8f" \
                      % (epoch, self.config.EPOCHS + start_epoch, loss_o, accuracy_o))

                # Do validation
                sess.run(init_op_var + [reset_op])
                for i in range(
                        int(self.config.VAL_SIZE / self.config.BATCH_SIZE)):
                    try:
                        accuracy_o, loss_o, _ = sess.run([accuracy, loss] +
                                                         update_op)
                    except (tf.errors.InvalidArgumentError,
                            tf.errors.OutOfRangeError):
                        break

                print("Epoch: [%2d/%2d], validation loss: %.8f, validation accuracy: %.8f" \
                      % (epoch, self.config.EPOCHS + start_epoch, loss_o, accuracy_o))

                # Save the model per SAVE_PER_EPOCH
                if epoch % self.config.SAVE_PER_EPOCH == 0:
                    save_name = str(epoch)
                    saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

                if self.config.SUMMARY:
                    self.summary.summary_writer.add_summary(summary_o, epoch)

            if self.config.SUMMARY:
                self.summary.summary_writer.flush()
                self.summary.summary_writer.close()

            # Save the model after all epochs
            save_name = str(epoch)
            saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')
            return
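
self._loss is defined elsewhere. A minimal sketch assuming one-hot labels and softmax cross-entropy:

    # Hypothetical sketch of the classification loss used above.
    def _loss(self, logits, labels):
        return tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits))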
Example #14
    def train(self, Model, DataSet, SAMPLE_X=None, SAMPLE_Y=None):
        # Reset tf graph.
        tf.reset_default_graph()

        # Create input node
        if not self.config.Y_LABEL:
            image_batch, init_op, dataset = self._input_fn(DataSet)
        else:
            image_batch, label_batch, init_op, dataset = self._input_fn_w_label(
                DataSet)
            # image, label = self._input_fn_NP()  # using numpy array as feed dict

        # Build up the graph and loss
        with tf.device('/gpu:0'):
            if self.config.LOSS == "PacGAN":
                # TODO: support conditional GAN for PacGAN
                y = None
                if self.config.CROP:
                    image_dims = [
                        self.config.IMAGE_HEIGHT_O, self.config.IMAGE_WIDTH_O,
                        self.config.CHANNEL * self.config.PAC_NUM
                    ]
                else:
                    image_dims = self.config.IMAGE_DIM[:-1] + [
                        self.config.CHANNEL * self.config.PAC_NUM
                    ]
                x = tf.placeholder(tf.float32,
                                   [self.config.BATCH_SIZE] + image_dims,
                                   name='real_images')
                z = []
                for i in range(self.config.PAC_NUM):
                    z.append(tf.placeholder(tf.float32, [self.config.BATCH_SIZE, self.config.Z_DIM], \
                                            name = 'z{}'.format(i))) # latent variable

            else:
                # Create placeholder
                if self.config.Y_LABEL:
                    y = tf.placeholder(
                        tf.float32,
                        [self.config.BATCH_SIZE, self.config.NUM_CLASSES],
                        name='y')  # label batch
                    x = tf.placeholder(tf.float32, [self.config.BATCH_SIZE] +
                                       self.config.IMAGE_DIM,
                                       name='real_images')  # real image
                else:
                    y = None
                    x = tf.placeholder(tf.float32, [self.config.BATCH_SIZE] + [
                        self.config.IMAGE_HEIGHT_O, self.config.IMAGE_WIDTH_O,
                        self.config.CHANNEL
                    ],
                                       name='real_images')  # real image

                z = tf.placeholder(tf.float32,
                                   [self.config.BATCH_SIZE, self.config.Z_DIM
                                    ])  # latent variable

            if self.config.LOSS == "MRGAN":
                # Build up the graph for mrGAN
                G, G_mr, D, D_logits, D_, D_logits_, fm, fm_, D_mr, D_mr_logits, _, model = self._build_train_graph(
                    x, y, z, Model)
                # Create the loss:
                d_loss, g_loss, e_loss = self._loss(D, D_logits, D_, D_logits_, x, G, fm, fm_, model.discriminator, y, \
                                                    G_mr, D_mr_logits)

            else:
                # Build up the graph
                G, D, D_logits, D_, D_logits_, fm, fm_, model = self._build_train_graph(
                    x, y, z, Model)
                # Create the loss:
                d_loss, g_loss = self._loss(D, D_logits, D_, D_logits_, x, G,
                                            fm, fm_, model.discriminator, y)

            # Sample the generated image every epoch
            if self.config.LOSS == "PacGAN":
                samples = model.sampler(z[0], y)
            else:
                samples = model.sampler(z, y)

        # Create optimizer
        with tf.name_scope('Train'):
            t_vars = tf.trainable_variables()
            theta_G = [var for var in t_vars if 'g_' in var.name]
            theta_D = [var for var in t_vars if 'd_' in var.name]
            if self.config.LOSS == "MRGAN":
                theta_E = [var for var in t_vars if 'e_' in var.name]

            if self.config.LOSS in ["WGAN", "WGAN_GP", "FMGAN"]:
                optimizer = self._RMSProp_optimizer()
                d_optim_ = self._train_op(optimizer, d_loss, theta_D)
            elif self.config.LOSS in [
                    "GAN", "LSGAN", "cGPGAN", "MRGAN", "PacGAN"
            ]:
                optimizer = self._Adam_optimizer()

            if self.config.LOSS == "WGAN":
                with tf.control_dependencies([d_optim_]):
                    d_optim = tf.group(*(tf.assign(var, \
                                                   tf.clip_by_value(var, -self.config.WEIGHT_CLIP, \
                                                                    self.config.WEIGHT_CLIP)) for var in theta_D))
            else:
                d_optim = self._train_op(optimizer, d_loss, theta_D)
                if self.config.LOSS == "MRGAN":
                    e_optim = self._train_op(optimizer, e_loss, theta_E)

            g_optim = self._train_op(optimizer, g_loss, theta_G)

        # Add summary
        if self.config.SUMMARY:
            summary_dict = {}
            if self.config.SUMMARY_SCALAR:
                scaler = {
                    'generator_loss': g_loss,
                    'discriminator_loss': d_loss
                }
                if self.config.LOSS == "MRGAN":
                    scaler['encoder_loss'] = e_loss
                summary_dict['scalar'] = scaler

            merged_summary = self.summary.add_summary(summary_dict)

        # Add saver
        saver = Saver(self.save_dir)

        # Create Session
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        # allow_soft_placement lets ops that cannot run on the GPU fall back to the CPU
        with tf.Session(config=sess_config) as sess:
            # if self.config.DEBUG:
            #     sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
                self.summary._graph_summary(sess.graph)

            if self.config.RESTORE:
                start_epoch = saver.restore(sess)
            else:
                saver.set_save_path(comments=self.comments)
                start_epoch = 0
                # initialize the variables
                init_var = tf.group(tf.global_variables_initializer(), \
                                    tf.local_variables_initializer())
                sess.run(init_var)

            sample_z = np.random.normal(size=(self.config.BATCH_SIZE, 100))
            if not self.config.Y_LABEL:
                ## TODO: support PacGAN for conditional case
                sess.run(init_op)
                if self.config.LOSS == "PacGAN":
                    sample_feed_dict_z = {}
                    sample_x_sep = []
                    for i in range(self.config.PAC_NUM):
                        sample_feed_dict_z[z[i]] = np.random.normal(
                            size=(self.config.BATCH_SIZE,
                                  self.config.Z_DIM)).astype(np.float32)
                        sample_x_sep.append(sess.run(image_batch))
                        sample_x = np.concatenate(sample_x_sep, axis=3)
                else:
                    sample_x = sess.run(image_batch)
            else:
                sample_x, sample_y = SAMPLE_X, SAMPLE_Y

            # Start Training
            tf.logging.info("Start traininig!")
            for epoch in range(start_epoch + 1,
                               self.config.EPOCHS + start_epoch + 1):
                tf.logging.info("Training for epoch {}.".format(epoch))
                train_pr_bar = tf.contrib.keras.utils.Progbar(target= \
                                                                  int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE))
                sess.run(init_op)
                for i in range(
                        int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE)):
                    if self.config.LOSS == "PacGAN":
                        image_batch_sep = []
                        feed_dict_z = {}
                        for j in range(self.config.PAC_NUM):
                            feed_dict_z[z[j]] = np.random.normal(
                                size=(self.config.BATCH_SIZE,
                                      self.config.Z_DIM)).astype(np.float32)
                            image_batch_sep.append(sess.run(image_batch))
                            image_batch_o = np.concatenate(image_batch_sep,
                                                           axis=3)
                    else:
                        batch_z = np.random.normal(
                            size=(self.config.BATCH_SIZE,
                                  self.config.Z_DIM)).astype(np.float32)
                        # Fetch a data batch
                        if not self.config.Y_LABEL:
                            image_batch_o = sess.run(image_batch)
                        else:
                            image_batch_o, label_batch_o = sess.run(
                                [image_batch, label_batch])

                    if not self.config.Y_LABEL:
                        if self.config.LOSS == "PacGAN":
                            # Update discriminator
                            _, d_loss_o = sess.run([d_optim, d_loss],
                                                   feed_dict={
                                                       x: image_batch_o,
                                                       **feed_dict_z
                                                   })
                            # Update generator
                            _ = sess.run([g_optim],
                                         feed_dict={
                                             x: image_batch_o,
                                             **feed_dict_z
                                         })
                            _, g_loss_o = sess.run([g_optim, g_loss],
                                                   feed_dict={
                                                       x: image_batch_o,
                                                       **feed_dict_z
                                                   })
                            if self.config.LOSS == "MRGAN":
                                # Update encoder
                                _, e_loss_o = sess.run([e_optim, e_loss],
                                                       feed_dict={
                                                           x: image_batch_o,
                                                           **feed_dict_z
                                                       })
                        else:
                            # Update discriminator
                            _, d_loss_o = sess.run([d_optim, d_loss],
                                                   feed_dict={
                                                       x: image_batch_o,
                                                       z: batch_z
                                                   })
                            # Update generator
                            _ = sess.run([g_optim],
                                         feed_dict={
                                             x: image_batch_o,
                                             z: batch_z
                                         })
                            _, g_loss_o = sess.run([g_optim, g_loss],
                                                   feed_dict={
                                                       x: image_batch_o,
                                                       z: batch_z
                                                   })
                            if self.config.LOSS == "MRGAN":
                                # Update encoder
                                _, e_loss_o = sess.run([e_optim, e_loss],
                                                       feed_dict={
                                                           x: image_batch_o,
                                                           z: batch_z
                                                       })
                    else:
                        # Update discriminator
                        _, d_loss_o = sess.run([d_optim, d_loss],
                                               feed_dict={
                                                   x: image_batch_o,
                                                   y: label_batch_o,
                                                   z: batch_z
                                               })

                        # Update generator
                        _ = sess.run([g_optim],
                                     feed_dict={
                                         x: image_batch_o,
                                         y: label_batch_o,
                                         z: batch_z
                                     })
                        _, g_loss_o = sess.run([g_optim, g_loss],
                                               feed_dict={
                                                   x: image_batch_o,
                                                   y: label_batch_o,
                                                   z: batch_z
                                               })
                        if self.config.LOSS == "MRGAN":
                            # Update encoder
                            _, e_loss_o = sess.run([e_optim, e_loss],
                                                   feed_dict={
                                                       x: image_batch_o,
                                                       y: label_batch_o,
                                                       z: batch_z
                                                   })

                    # Update progress bar
                    train_pr_bar.update(i)
                    if i % 100 == 0:
                        if self.config.DEBUG:
                            ## Sample images every 100 updates in debug mode
                            if not self.config.Y_LABEL:
                                if self.config.LOSS == "PacGAN":
                                    samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                                        [
                                            samples, d_loss, g_loss,
                                            merged_summary
                                        ],
                                        feed_dict={
                                            x: image_batch_o,
                                            **sample_feed_dict_z
                                        })
                                else:
                                    samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                                        [
                                            samples, d_loss, g_loss,
                                            merged_summary
                                        ],
                                        feed_dict={
                                            x: sample_x,
                                            z: sample_z
                                        })
                            else:
                                samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                                    [samples, d_loss, g_loss, merged_summary],
                                    feed_dict={
                                        x: sample_x,
                                        y: sample_y,
                                        z: sample_z
                                    })
                            save_images(samples_o[:64], image_manifold_size(64), \
                                        os.path.join(self.config.SAMPLE_DIR, 'train_{:02d}_{:02d}.png'.format(epoch, i)))

                # Save the model per SAVE_PER_EPOCH
                if epoch % self.config.SAVE_PER_EPOCH == 0:
                    save_name = str(epoch)
                    saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')

                if self.config.LOSS == "MRGAN":
                    print("Epoch: [%2d/%2d], d_loss: %.8f, g_loss: %.8f, e_loss: %.8f" \
                          % (epoch, self.config.EPOCHS + start_epoch, d_loss_o, g_loss_o, e_loss_o))
                else:
                    print("Epoch: [%2d/%2d], d_loss: %.8f, g_loss: %.8f" \
                          % (epoch, self.config.EPOCHS + start_epoch, d_loss_o, g_loss_o))
                ## Sample images after every epoch
                if not self.config.Y_LABEL:
                    if self.config.LOSS == "PacGAN":
                        samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                            [samples, d_loss, g_loss, merged_summary],
                            feed_dict={
                                x: image_batch_o,
                                **sample_feed_dict_z
                            })
                    else:
                        samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                            [samples, d_loss, g_loss, merged_summary],
                            feed_dict={
                                x: sample_x,
                                z: sample_z
                            })
                else:
                    samples_o, d_loss_o, g_loss_o, summary_o = sess.run(
                        [samples, d_loss, g_loss, merged_summary],
                        feed_dict={
                            x: sample_x,
                            y: sample_y,
                            z: sample_z
                        })

                if self.config.SUMMARY:
                    self.summary.summary_writer.add_summary(summary_o, epoch)

                print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                      (d_loss_o, g_loss_o))
                save_images(samples_o[:64], image_manifold_size(64), \
                            os.path.join(self.config.SAMPLE_DIR, 'train_{:02d}.png'.format(epoch)))

            if self.config.SUMMARY:
                self.summary.summary_writer.flush()
                self.summary.summary_writer.close()

            # Save the model after all epochs
            save_name = str(epoch)
            saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')
            return

    def train(self, Dataset, Model, sample_y):
        # Create input node
        init_op_train, init_op_val, NNIO = self._input_fn_train_val(Dataset)
        x_l_c, y_l_c, x_l_d, y_l_d, x_u = NNIO

        with tf.device('/gpu:0'):
            # Build up the graph
            PH, G, D, C, model = self._build_train_graph(Model)

            z_g_ph, y_g_ph, x_l_c_ph, y_l_c_ph, x_l_d_ph, y_l_d_ph, x_u_d_ph, x_u_c_ph, \
            train_ph, Lambda = PH
            if self.config.DATA_NAME == "cifar10":
                C_real_logits, C_unl_logits, C_unl_d_logits, C_fake_logits, C_unl_logits_rep = C
            else:
                C_real_logits, C_unl_logits, C_unl_d_logits, C_fake_logits = C

            lambda_1_ph = Lambda[0]
            if self.config.DATA_NAME == "cifar10":
                lambda_2_ph = Lambda[1]

            ## Sample the images
            sample_z_ph = tf.placeholder(
                tf.float32, [self.config.SAMPLE_SIZE, self.config.Z_DIM],
                name='sample_latent_variable')  # latent variable
            sample_y_ph = tf.placeholder(
                tf.float32, [self.config.SAMPLE_SIZE, self.config.NUM_CLASSES],
                name='condition_label_for_sampler')
            samples = model.good_sampler(sample_z_ph, sample_y_ph)

            # Create Loss function
            d_loss, g_loss, c_loss = self._goodGAN_loss(
                G, D, C, None, [y_g_ph, y_l_c_ph], Lambda, model.discriminator)

            # Create the metric
            accuracy, update_op, reset_op, preds, probs = self._metric(
                C_real_logits, y_l_c_ph)

        # Create optimizer
        with tf.name_scope('Train'):
            t_vars = tf.trainable_variables()

            g_vars = [var for var in t_vars if 'good_generator' in var.name]
            d_vars = [var for var in t_vars if 'discriminator' in var.name]
            c_vars = [var for var in t_vars if 'classifier' in var.name]
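            # The name-scope filters above ensure each optimizer below only
            # updates the variables of its own sub-network.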

            d_optimizer = self._Adam_optimizer(lr=self.lr_ph,
                                               beta1=self.config.BETA1)
            g_optimizer = self._Adam_optimizer(lr=self.lr_ph,
                                               beta1=self.config.BETA1)
            c_optimizer = self._Adam_optimizer(lr=self.cla_lr_ph, beta1=0.5)
            # optimizer = self._Adam_optimizer(self.lr_ph, self.config.BETA1)
            d_solver = self._train_op(d_optimizer, d_loss, \
                                      var_list = d_vars)
            g_solver = self._train_op(g_optimizer, g_loss, \
                                       var_list = g_vars)
            c_solver_ = self._train_op(c_optimizer, c_loss, var_list=c_vars)

            ## add weight normalization for classifier
            # c_weight_loss = self._loss_weight_l2([var for var in t_vars if 'c_h2_lin' in var.name and 'b' not in var.name], eta = 1e-4)
            # c_loss += c_weight_loss

            # add moving average
            ema = tf.train.ExponentialMovingAverage(decay=0.9999)
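            # Chain the EMA update after each classifier gradient step via the
            # control dependency below, so running c_solver also refreshes the
            # averaged weights.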
            with tf.control_dependencies([c_solver_]):
                c_solver = ema.apply(c_vars)

        # Add summary
        if self.config.SUMMARY:
            summary_dict_train = {}
            summary_dict_val = {}
            if self.config.SUMMARY_SCALAR:
                scalar_train = {
                    'g_loss': g_loss,
                    'd_loss': d_loss,
                    'c_loss': c_loss,
                    'train_accuracy': accuracy
                }
                scalar_val = {'val_accuracy': accuracy}
                summary_dict_train['scalar'] = scalar_train
                summary_dict_val['scalar'] = scalar_val

            # Merge summary
            merged_summary_train = \
                self.summary_train.add_summary(summary_dict_train)
            merged_summary_val = \
                self.summary_val.add_summary(summary_dict_val)

        # Add saver
        saver = Saver(self.save_dir)

        # Create a session for training
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.8))
        # Use soft placement so ops that cannot be placed on the GPU fall back to the CPU

        # Build up latent variable for samples
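        # The same sample_z is reused for every epoch so the saved sample grids
        # remain directly comparable across epochs.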
        sample_z = np.random.uniform(low=-1.0,
                                     high=1.0,
                                     size=(self.config.SAMPLE_SIZE,
                                           self.config.Z_DIM)).astype(
                                               np.float32)

        with tf.Session(config=sess_config) as sess:
            if self.config.DEBUG:
                sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            # Add graph to tensorboard
            if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
                self.summary_train._graph_summary(sess.graph)

            lr = self.config.LEARNING_RATE
            cla_lr = self.config.CLA_LEARNINIG_RATE
            # Restore the weights from the previous training
            if self.config.RESTORE:
                start_epoch = saver.restore(sess,
                                            dir_names=self.config.RUN,
                                            epoch=self.config.RESTORE_EPOCH)
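                # Re-apply the per-epoch decay schedule for the epochs already
                # trained so the restored run resumes with the right learning rates.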
                if start_epoch >= 300:
                    lr = lr * 0.995**(start_epoch - 300)
                    cla_lr = cla_lr * 0.99**(start_epoch - 300)
                initialize_uninitialized_vars(sess)
            else:
                # Create a new folder for saving model
                saver.set_save_path(comments=self.comments)
                start_epoch = 0

                # initialize the variables
                init_var = tf.group(tf.global_variables_initializer(), \
                                    tf.local_variables_initializer())
                sess.run(init_var)

            # Start Training
            tf.logging.info("Start training!")
            for epoch in range(1, self.config.EPOCHS + 1):
                tf.logging.info("Training for epoch {}.".format(epoch +
                                                                start_epoch))
                train_pr_bar = tf.contrib.keras.utils.Progbar(
                    target=int(self.config.TRAIN_SIZE / self.config.BATCH_SIZE))

                lambda_1 = self.config.FAKE_G_LAMBDA if (start_epoch +
                                                         epoch) > 200 else 0.
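                # lambda_1 (the FAKE_G_LAMBDA weight) is switched on only after a
                # total of 200 epochs.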

                if self.config.DATA_NAME == "cifar10":
                    # rampup_value = rampup(start_epoch + epoch)
                    # rampdown_value = rampdown(start_epoch + epoch)
                    # lambda_2 = rampup_value * 20 if epoch > 1 else 0
                    lambda_2 = 0.5 if epoch > 67 else 0
                    # b1_c = rampdown_value * 0.9 + (1.0 - rampdown_value) * 0.5
                    # lambda_2 = 0

                if (start_epoch + epoch >= 300):
                    lr = lr * 0.995
                    cla_lr = cla_lr * 0.99

                tf.logging.info("Lambda_1 {}.".format(lambda_1))
                sess.run(init_op_train)

                if self.config.PRE_TRAIN and (start_epoch + epoch <= 30):
                    tf.logging.info("Pre Training!")
                    for i in range(
                            int(self.config.TRAIN_SIZE /
                                self.config.BATCH_SIZE)):
                        # Fetch labeled and unlabeled data
                        x_l_c_o, y_l_c_o, x_l_d_o, y_l_d_o, x_u_o = sess.run(
                            [x_l_c, y_l_c, x_l_d, y_l_d, x_u])
                        # Define latent vector
                        z = np.random.uniform(low=-1.0,
                                              high=1.0,
                                              size=(self.config.BATCH_SIZE_G,
                                                    self.config.Z_DIM)).astype(
                                                        np.float32)
                        # y = y_l_c_o  ## Todo: think about if this is the best way
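                        # Draw random class labels for the generator batch and
                        # one-hot encode them.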
                        y_temp = np.random.randint(
                            low=0,
                            high=self.config.NUM_CLASSES,
                            size=(self.config.BATCH_SIZE_G))
                        y = np.zeros((self.config.BATCH_SIZE_G,
                                      self.config.NUM_CLASSES))
                        y[np.arange(self.config.BATCH_SIZE_G), y_temp] = 1

                        if self.config.DATA_NAME == "cifar10":
                            z_g_ph, y_g_ph, x_l_c_ph, y_l_c_ph, x_l_d_ph, y_l_d_ph, x_u_d_ph, x_u_c_ph, \
                            train_ph, Lambda_ph = PH
                            lambda_1_ph, lambda_2_ph = Lambda_ph
                        else:
                            z_g_ph, y_g_ph, x_l_c_ph, y_l_c_ph, x_l_d_ph, y_l_d_ph, x_u_d_ph, x_u_c_ph, \
                            train_ph, lambda_1_ph = PH
                            lambda_1_ph = lambda_1_ph[0]

                        feed_dict = {
                            z_g_ph: z,
                            y_g_ph: y,
                            x_l_c_ph: x_l_c_o,
                            y_l_c_ph: y_l_c_o,
                            x_l_d_ph: x_l_d_o,
                            y_l_d_ph: y_l_d_o,
                            x_u_d_ph: x_u_o[:self.config.BATCH_SIZE_U_D, ...],
                            x_u_c_ph: x_u_o[self.config.BATCH_SIZE_U_D:
                                            self.config.BATCH_SIZE_U_D +
                                            self.config.BATCH_SIZE_U_C, ...],
                            train_ph: True,
                            lambda_1_ph: lambda_1,
                            self.lr_ph: lr,
                            self.cla_lr_ph: cla_lr
                        }

                        if self.config.DATA_NAME == "cifar10":
                            feed_dict[lambda_2_ph] = lambda_2
                            # feed_dict[self.cla_beta1] = b1_c

                        # Update classifier
                        _, c_loss_o = sess.run([c_solver, c_loss], \
                                               feed_dict=feed_dict)
                        # Update progress bar
                        train_pr_bar.update(i)

                else:
                    tf.logging.info("Good-GAN Training!")
                    for i in range(
                            int(self.config.TRAIN_SIZE /
                                self.config.BATCH_SIZE)):
                        # Fetch labeled and unlabeled data
                        x_l_c_o, y_l_c_o, x_l_d_o, y_l_d_o, x_u_o = sess.run(
                            [x_l_c, y_l_c, x_l_d, y_l_d, x_u])
                        # Define latent vector
                        z = np.random.uniform(low=-1.0,
                                              high=1.0,
                                              size=(self.config.BATCH_SIZE_G,
                                                    self.config.Z_DIM)).astype(
                                                        np.float32)

                        y_temp = np.random.randint(
                            low=0,
                            high=self.config.NUM_CLASSES,
                            size=(self.config.BATCH_SIZE_G))
                        y = np.zeros((self.config.BATCH_SIZE_G,
                                      self.config.NUM_CLASSES))
                        y[np.arange(self.config.BATCH_SIZE_G), y_temp] = 1

                        if self.config.DATA_NAME == "cifar10":
                            z_g_ph, y_g_ph, x_l_c_ph, y_l_c_ph, x_l_d_ph, y_l_d_ph, x_u_d_ph, x_u_c_ph, \
                            train_ph, Lambda_ph = PH
                            lambda_1_ph, lambda_2_ph = Lambda_ph
                        else:
                            z_g_ph, y_g_ph, x_l_c_ph, y_l_c_ph, x_l_d_ph, y_l_d_ph, x_u_d_ph, x_u_c_ph, \
                            train_ph, lambda_1_ph = PH
                            lambda_1_ph = lambda_1_ph[0]
                        feed_dict = {
                            z_g_ph: z,
                            y_g_ph: y,
                            x_l_c_ph: x_l_c_o,
                            y_l_c_ph: y_l_c_o,
                            x_l_d_ph: x_l_d_o,
                            y_l_d_ph: y_l_d_o,
                            x_u_d_ph: x_u_o[:self.config.BATCH_SIZE_U_D, ...],
                            x_u_c_ph: x_u_o[self.config.BATCH_SIZE_U_D:
                                            self.config.BATCH_SIZE_U_D +
                                            self.config.BATCH_SIZE_U_C, ...],
                            train_ph: True,
                            lambda_1_ph: lambda_1,
                            self.lr_ph: lr,
                            self.cla_lr_ph: cla_lr
                        }

                        if self.config.DATA_NAME == "cifar10":
                            feed_dict[lambda_2_ph] = lambda_2
                            # feed_dict[self.cla_beta1] = b1_c

                        # Update discriminator
                        _, d_loss_o = sess.run([d_solver, d_loss], \
                                   feed_dict = feed_dict)
                        # Update generator
                        _, g_loss_o = sess.run([g_solver, g_loss], \
                                    feed_dict = feed_dict)
                        # _, g_loss_o = sess.run([g_solver, g_loss], \
                        #                        feed_dict=feed_dict)
                        # Update classifier
                        _, c_loss_o = sess.run([c_solver, c_loss], \
                                    feed_dict = feed_dict)
                        # Update progress bar
                        train_pr_bar.update(i)

                # Get the training statistics
                summary_train_o, d_loss_o, g_loss_o, c_loss_o = sess.run(
                    [merged_summary_train, d_loss, g_loss, c_loss],
                    feed_dict=feed_dict)

                tf.logging.info("The training stats for epoch {}: g_loss: {:.2f}, d_loss {:.2f}, c_loss: {:.2f}."\
                                .format(epoch + start_epoch, g_loss_o, d_loss_o, c_loss_o))

                if self.config.SUMMARY:
                    # Add summary
                    self.summary_train.summary_writer.add_summary(
                        summary_train_o, epoch + start_epoch)

                # Perform validation
                tf.logging.info("\nValidate for epoch {}.".format(epoch +
                                                                  start_epoch))
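                # Re-initialize the validation iterator and reset the streaming
                # accuracy before evaluating.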
                sess.run(init_op_val + [reset_op])
                if not self.config.VAL_STEP:  # perform full validation
                    while True:
                        try:
                            x_l_c_o, y_l_c_o = sess.run([x_l_c, y_l_c])
                            feed_dict = {
                                z_g_ph: z,
                                y_g_ph: y,
                                x_l_c_ph: x_l_c_o,
                                y_l_c_ph: y_l_c_o,
                                x_l_d_ph: x_l_d_o,
                                y_l_d_ph: y_l_d_o,
                                x_u_d_ph: x_u_o[:self.config.BATCH_SIZE_U_D, ...],
                                x_u_c_ph: x_u_o[self.config.BATCH_SIZE_U_D:
                                                self.config.BATCH_SIZE_U_D +
                                                self.config.BATCH_SIZE_U_C, ...],
                                train_ph: False,
                                lambda_1_ph: lambda_1,
                                self.lr_ph: lr,
                                self.cla_lr_ph: cla_lr
                            }

                            if self.config.DATA_NAME == "cifar10":
                                feed_dict[lambda_2_ph] = lambda_2
                                # feed_dict[self.cla_beta1] = b1_c

                            accuracy_o, summary_val_o, _ = sess.run([accuracy, merged_summary_val] + update_op, \
                                                                    feed_dict = feed_dict)
                        except (tf.errors.InvalidArgumentError,
                                tf.errors.OutOfRangeError, ValueError):
                            break
                else:
                    for i in range(self.config.VAL_STEP):
                        x_l_c_o, y_l_c_o = sess.run([x_l_c, y_l_c])
                        feed_dict = {
                            z_g_ph: z,
                            y_g_ph: y,
                            x_l_c_ph: x_l_c_o,
                            y_l_c_ph: y_l_c_o,
                            x_l_d_ph: x_l_d_o,
                            y_l_d_ph: y_l_d_o,
                            x_u_d_ph: x_u_o[:self.config.BATCH_SIZE_U_D, ...],
                            x_u_c_ph: x_u_o[self.config.BATCH_SIZE_U_D:
                                            self.config.BATCH_SIZE_U_D +
                                            self.config.BATCH_SIZE_U_C, ...],
                            train_ph: False,
                            lambda_1_ph: lambda_1,
                            self.lr_ph: lr,
                            self.cla_lr_ph: cla_lr
                        }

                        if self.config.DATA_NAME == "cifar10":
                            feed_dict[lambda_2_ph] = lambda_2
                            # feed_dict[self.cla_beta1] = b1_c

                        accuracy_o, summary_val_o, _ = sess.run([accuracy, merged_summary_val] + update_op, \
                                                                feed_dict=feed_dict)

                tf.logging.info(
                    "\nThe current validation accuracy for epoch {} is {:.2f}.\n" \
                    .format(epoch + start_epoch, accuracy_o))
                # Add summary to tensorboard
                if self.config.SUMMARY:
                    self.summary_val.summary_writer.add_summary(
                        summary_val_o, epoch + start_epoch)

                # Sample generated images
                samples_o = sess.run(samples,
                                     feed_dict={
                                         sample_z_ph: sample_z,
                                         sample_y_ph: sample_y
                                     })

                if self.config.DATA_NAME == "mnist":
                    samples_o = np.reshape(samples_o, [-1, 28, 28, 1])

                save_images(samples_o[:64], image_manifold_size(64), \
                            os.path.join(self.config.SAMPLE_DIR, 'train_{:02d}.png'.format(epoch + start_epoch)))

                # Save the model per SAVE_PER_EPOCH
                if epoch % self.config.SAVE_PER_EPOCH == 0:
                    save_name = str(epoch + start_epoch)
                    saver.save(sess, 'model_' + save_name.zfill(4) \
                               + '.ckpt')

            if self.config.SUMMARY:
                self.summary_train.summary_writer.flush()
                self.summary_train.summary_writer.close()
                self.summary_val.summary_writer.flush()
                self.summary_val.summary_writer.close()

            # Save the model after all epochs
            save_name = str(epoch + start_epoch)
            saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')
        return
Ejemplo n.º 16
0
    def train(self, model):
        tf.reset_default_graph()
        # Input node
        image, label, init_op_train, init_op_val \
            = self._input_fn_train_val()
        # Build up the train graph
        with tf.device('/cpu:0'):  # TODO: parameterize
            loss, accuracy, update_op, reset_op, histogram \
                = self._build_train_graph(image, label, model)
            with tf.name_scope('Train'):
                optimizer = self._SGD_w_Momentum_optimizer()
                train_op, grads = self._train_op_w_grads(optimizer, loss)

        tf.logging.debug(utils.variable_name_string())
        # Add summary
        if self.config.SUMMARY:
            if self.config.SUMMARY_TRAIN_VAL:
                summary_dict_train = {}
                summary_dict_val = {}
                if self.config.SUMMARY_SCALAR:
                    scalar_train = {'train_loss': loss}
                    scalar_val = {'val_loss': loss, 'val_accuracy': accuracy}
                    summary_dict_train['scalar'] = scalar_train
                    summary_dict_val['scalar'] = scalar_val
                if self.config.SUMMARY_IMAGE:
                    image = {'input_image': image}
                    summary_dict_train['image'] = image
                if self.config.SUMMARY_HISTOGRAM:
                    histogram['Conv_Block_0_Weight'] = \
                        [var for var in tf.global_variables()
                         if var.name == 'conv2d/kernel:0'][0]
                    histogram = utils.grads_dict(grads, histogram)
                    summary_dict_train['histogram'] = histogram
                merged_summary_train = \
                    self.summary_train.add_summary(summary_dict_train)
                merged_summary_val = \
                    self.summary_val.add_summary(summary_dict_val)

        # Add saver
        saver = Saver(self.save_dir)

        with tf.Session() as sess:
            if self.config.SUMMARY and self.config.SUMMARY_GRAPH:
                if self.config.SUMMARY_TRAIN_VAL:
                    self.summary_train._graph_summary(sess.graph)

            if self.config.RESTORE:
                start_epoch = saver.restore(sess)
            else:
                # Create a new folder for saving model
                saver.set_save_path(comments=self.comments)
                start_epoch = 0
                # initialize the variables
                init_var = tf.group(tf.global_variables_initializer(), \
                             tf.local_variables_initializer())
                sess.run(init_var)

            for epoch in range(1, self.config.EPOCHS + 1):
                sess.run(init_op_train)
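                # Run training steps until the iterator is exhausted;
                # OutOfRangeError marks the end of the epoch.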
                while True:
                    try:
                        _, loss_out, summary_out = \
                            sess.run([train_op, loss, merged_summary_train])
                    except tf.errors.OutOfRangeError:
                        break
                if self.config.SUMMARY_TRAIN_VAL:
                    self.summary_train.summary_writer.add_summary(
                        summary_out, epoch)

                ## Perform test on validation
                sess.run([init_op_val, reset_op])
                loss_val = []
                while True:
                    try:
                        loss_out, accuracy_out, _, summary_out = \
                            sess.run([loss, accuracy, update_op, merged_summary_val])
                        loss_val.append(loss_out)
                    except tf.errors.OutOfRangeError:
                        tf.logging.info("The current validation loss for epoch {} is {:.2f}, accuracy is {:.2f}."\
                              .format(epoch, np.mean(loss_val), accuracy_out))
                        break
                if self.config.SUMMARY_TRAIN_VAL:
                    self.summary_val.summary_writer.add_summary(
                        summary_out, epoch)

                # Save the model per SAVE_PER_EPOCH
                if epoch % self.config.SAVE_PER_EPOCH == 0:
                    save_name = str(epoch + start_epoch)
                    saver.save(sess, 'model_' + save_name.zfill(4) \
                               + '.ckpt')

            if self.config.SUMMARY_TRAIN_VAL:
                self.summary_train.summary_writer.flush()
                self.summary_train.summary_writer.close()
                self.summary_val.summary_writer.flush()
                self.summary_val.summary_writer.close()
            # Save the model after all epochs
            save_name = str(epoch + start_epoch)
            saver.save(sess, 'model_' + save_name.zfill(4) + '.ckpt')
        return