Example #1
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Creating train logdir: %s', FLAGS.train_logdir)

    with tf.Graph().as_default() as graph:
        global_step = tf.train.get_or_create_global_step()

        X = tf.placeholder(tf.float32, [None, FLAGS.height, FLAGS.width, 3],
                           name='X')
        ground_truth = tf.placeholder(tf.int64, [None], name='ground_truth')
        is_training = tf.placeholder(tf.bool, name='is_training')
        keep_prob = tf.placeholder(tf.float32, [], name='keep_prob')
        # learning_rate = tf.placeholder(tf.float32, [])

        # apply SENet
        logits, end_points = model.hcd_model(X,
                                             num_classes=num_classes,
                                             is_training=is_training,
                                             keep_prob=keep_prob,
                                             attention_module='se_block')

        logits = tf.cond(
            is_training,
            lambda: tf.identity(logits),
            lambda: tf.reduce_mean(
                tf.reshape(logits, [FLAGS.val_batch_size, TEN_CROP, -1]),
                axis=1))
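        # The tf.cond above leaves the logits untouched during training; at
        # validation time each sample arrives as TEN_CROP crops folded into
        # the batch dimension, so the crop logits are reshaped back out and
        # averaged per sample.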

        # Print name and shape of each tensor.
        tf.logging.info("++++++++++++++++++++++++++++++++++")
        tf.logging.info("Layers")
        tf.logging.info("++++++++++++++++++++++++++++++++++")
        for k, v in end_points.items():
            tf.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))

        # # Print name and shape of parameter nodes  (values not yet initialized)
        # tf.logging.info("++++++++++++++++++++++++++++++++++")
        # tf.logging.info("Parameters")
        # tf.logging.info("++++++++++++++++++++++++++++++++++")
        # for v in slim.get_model_variables():
        #     tf.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        prediction = tf.argmax(logits, axis=1, name='prediction')
        correct_prediction = tf.equal(prediction, ground_truth)
        confusion_matrix = tf.confusion_matrix(ground_truth,
                                               prediction,
                                               num_classes=num_classes)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32),
                                  name='accuracy')
        summaries.add(tf.summary.scalar('accuracy', accuracy))

        # Define loss
        tf.losses.sparse_softmax_cross_entropy(labels=ground_truth,
                                               logits=logits)

        # Gather update_ops. These contain, for example,
        # the updates for the batch_norm variables created by model.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # # Add summaries for model variables.
        # for model_var in slim.get_model_variables():
        #     summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        # optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        total_loss, grads_and_vars = train_utils.optimize(optimizer)
        total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # # Modify the gradients for biases and last layer variables.
        # last_layers = train_utils.get_extra_layer_scopes(
        #     FLAGS.last_layers_contain_logits_only)
        # grad_mult = train_utils.get_model_gradient_multipliers(
        #     last_layers, FLAGS.last_layer_gradient_multiplier)
        # if grad_mult:
        #     grads_and_vars = slim.learning.multiply_gradients(
        #         grads_and_vars, grad_mult)

        # Gradient clipping (optional). Either clip each gradient by value:
        # clipped_gvs = [(tf.clip_by_value(grad, -1., 1.), var)
        #                for grad, var in grads_and_vars]
        # or clip by global norm:
        # gradients, variables = zip(*grads_and_vars)
        # gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
        # grads_and_vars = list(zip(gradients, variables))

        # Add histogram summaries for the gradients (viewable in TensorBoard).
        grad_summ_op = tf.summary.merge([
            tf.summary.histogram("%s-grad" % g[1].name, g[0])
            for g in grads_and_vars
        ])

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(total_loss, name='train_op')
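        # Running train_op therefore applies the gradient update together with
        # the collected update_ops (e.g. batch-norm moving-average updates)
        # before returning the current loss value.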

        # Add the summaries. These contain the summaries
        # created by model and either optimize() or _gather_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))
        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir, graph)
        validation_writer = tf.summary.FileWriter(
            FLAGS.summaries_dir + '/validation', graph)

        ###############
        # Prepare data
        ###############
        # training dataset
        tfrecord_filenames = tf.placeholder(tf.string, shape=[])
        dataset = train_data.Dataset(tfrecord_filenames, FLAGS.batch_size,
                                     FLAGS.how_many_training_epochs,
                                     FLAGS.height, FLAGS.width)
        iterator = dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        # validation dataset
        val_dataset = val_data.Dataset(tfrecord_filenames,
                                       FLAGS.val_batch_size, FLAGS.height,
                                       FLAGS.width)
        val_iterator = val_dataset.dataset.make_initializable_iterator()
        val_next_batch = val_iterator.get_next()

        sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True))
        with tf.Session(config=sess_config) as sess:
            sess.run(tf.global_variables_initializer())

            # Create a saver object which will save all the variables
            saver = tf.train.Saver()
            if FLAGS.saved_checkpoint_dir:
                if tf.gfile.IsDirectory(FLAGS.saved_checkpoint_dir):
                    checkpoint_path = tf.train.latest_checkpoint(
                        FLAGS.saved_checkpoint_dir)
                else:
                    checkpoint_path = FLAGS.saved_checkpoint_dir
                saver.restore(sess, checkpoint_path)

            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            start_epoch = 0
            # Get the number of training/validation steps per epoch
            tr_batches = int(PCAM_TRAIN_DATA_SIZE / FLAGS.batch_size)
            if PCAM_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
                tr_batches += 1
            val_batches = int(PCAM_VALIDATE_DATA_SIZE / FLAGS.val_batch_size)
            if PCAM_VALIDATE_DATA_SIZE % FLAGS.val_batch_size > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can either be a string,
            # a list of strings, or a tf.Tensor of strings.
            train_record_filenames = os.path.join(FLAGS.dataset_dir,
                                                  'train.record')
            validate_record_filenames = os.path.join(FLAGS.dataset_dir,
                                                     'validate.record')
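            # Note: tfrecord_filenames above is a scalar string placeholder,
            # so a single record path is fed per run. tf.data.TFRecordDataset
            # itself also accepts a list of paths or a tf.Tensor of strings,
            # e.g. (hypothetical shard names):
            #   tf.data.TFRecordDataset(['train-00000.record',
            #                            'train-00001.record'])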
            ############################
            # Training loop.
            ############################
            for num_epoch in range(start_epoch,
                                   FLAGS.how_many_training_epochs):
                print("------------------------------------")
                print(" Epoch {} ".format(num_epoch))
                print("------------------------------------")

                sess.run(
                    iterator.initializer,
                    feed_dict={tfrecord_filenames: train_record_filenames})
                for step in range(tr_batches):
                    train_batch_xs, train_batch_ys = sess.run(next_batch)
                    # # Verify image
                    # # assert not np.any(np.isnan(train_batch_xs))
                    # n_batch = train_batch_xs.shape[0]
                    # # n_view = train_batch_xs.shape[1]
                    # for i in range(n_batch):
                    #     img = train_batch_xs[i]
                    #     # scipy.misc.toimage(img).show() Or
                    #     img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
                    #     cv2.imwrite('/home/ace19/Pictures/' + str(i) + '.png', img)
                    #     # cv2.imshow(str(train_batch_ys[idx]), img)
                    #     cv2.waitKey(100)
                    #     cv2.destroyAllWindows()

                    augmented_batch_xs = aug_utils.aug(train_batch_xs)
                    # # Verify image
                    # # assert not np.any(np.isnan(train_batch_xs))
                    # n_batch = augmented_batch_xs.shape[0]
                    # # n_view = train_batch_xs.shape[1]
                    # for i in range(n_batch):
                    #     img = augmented_batch_xs[i]
                    #     # scipy.misc.toimage(img).show() Or
                    #     img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
                    #     cv2.imwrite('/home/ace19/Pictures/' + str(i) + '.png', img)
                    #     # cv2.imshow(str(train_batch_ys[idx]), img)
                    #     cv2.waitKey(100)
                    #     cv2.destroyAllWindows()

                    # Run the graph with this batch of training data and learning rate policy.
                    lr, train_summary, train_accuracy, train_loss, grad_vals, _ = \
                        sess.run([learning_rate, summary_op, accuracy, total_loss, grad_summ_op, train_op],
                                 feed_dict={
                                     X: augmented_batch_xs,
                                     ground_truth: train_batch_ys,
                                     is_training: True,
                                     keep_prob: 0.8
                                 })
                    train_writer.add_summary(train_summary, num_epoch)
                    train_writer.add_summary(grad_vals, num_epoch)
                    tf.logging.info(
                        'Epoch #%d, Step #%d, rate %.10f, accuracy %.1f%%, loss %f'
                        % (num_epoch, step, lr, train_accuracy * 100,
                           train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                tf.logging.info('--------------------------')
                tf.logging.info(' Start validation ')
                tf.logging.info('--------------------------')

                total_val_accuracy = 0
                validation_count = 0
                total_conf_matrix = None
                # Reinitialize iterator with the validation dataset
                sess.run(
                    val_iterator.initializer,
                    feed_dict={tfrecord_filenames: validate_record_filenames})
                for step in range(val_batches):
                    validation_batch_xs, validation_batch_ys = sess.run(
                        val_next_batch)
                    # TTA
                    batch_size, n_crops, h, w, c = validation_batch_xs.shape
                    # fuse batch size and ncrops
                    tencrop_val_batch_xs = np.reshape(validation_batch_xs,
                                                      (-1, h, w, c))
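                    # The crop dimension is folded into the batch dimension so
                    # the network sees batch_size * n_crops images at once;
                    # the tf.cond branch in the graph averages the resulting
                    # crop logits back to one prediction per sample.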

                    val_summary, val_accuracy, conf_matrix = sess.run(
                        [summary_op, accuracy, confusion_matrix],
                        feed_dict={
                            X: tencrop_val_batch_xs,
                            ground_truth: validation_batch_ys,
                            is_training: False,
                            keep_prob: 1.0
                        })

                    validation_writer.add_summary(val_summary, num_epoch)

                    total_val_accuracy += val_accuracy
                    validation_count += 1
                    if total_conf_matrix is None:
                        total_conf_matrix = conf_matrix
                    else:
                        total_conf_matrix += conf_matrix

                total_val_accuracy /= validation_count
                tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
                tf.logging.info(
                    'Validation accuracy = %.1f%% (N=%d)' %
                    (total_val_accuracy * 100, PCAM_VALIDATE_DATA_SIZE))

                # Save the model checkpoint periodically.
                if (num_epoch <= FLAGS.how_many_training_epochs - 1):
                    checkpoint_path = os.path.join(FLAGS.train_logdir,
                                                   FLAGS.ckpt_name_to_save)
                    tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                                    num_epoch)
                    saver.save(sess, checkpoint_path, global_step=num_epoch)
Example #2
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Creating train logdir: %s', FLAGS.train_logdir)

    with tf.Graph().as_default() as graph:
        global_step = tf.train.get_or_create_global_step()

        # Define the model
        X = tf.placeholder(
            tf.float32, [None, FLAGS.num_views, FLAGS.height, FLAGS.width, 3],
            name='X')
        # The 8x8x1536 feature shape below assumes a 299x299 input; adjust it
        # for other input sizes.
        final_X = tf.placeholder(tf.float32,
                                 [FLAGS.num_views, None, 8, 8, 1536],
                                 name='final_X')
        ground_truth = tf.placeholder(tf.int64, [None], name='ground_truth')
        is_training = tf.placeholder(tf.bool)
        is_training2 = tf.placeholder(tf.bool)
        dropout_keep_prob = tf.placeholder(tf.float32)
        grouping_scheme = tf.placeholder(tf.bool, [NUM_GROUP, FLAGS.num_views])
        grouping_weight = tf.placeholder(tf.float32, [NUM_GROUP, 1])
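        # grouping_scheme is a boolean [NUM_GROUP, num_views] mask assigning
        # each view to a group, and grouping_weight holds one weight per
        # group; both are computed outside the graph from the discrimination
        # scores and fed back in via the partial_run calls below.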
        # learning_rate = tf.placeholder(tf.float32)

        # Grouping Module
        d_scores, _, final_desc = gvcnn.discrimination_score(
            X, num_classes, is_training)

        # GVCNN
        logits, _ = gvcnn.gvcnn(final_X, grouping_scheme, grouping_weight,
                                num_classes, is_training2, dropout_keep_prob)

        # Define loss
        tf.reduce_mean(
            tf.losses.sparse_softmax_cross_entropy(labels=ground_truth,
                                                   logits=logits))

        # Gather update_ops. These contain, for example,
        # the updates for the batch_norm variables created by model.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        prediction = tf.argmax(logits, 1, name='prediction')
        correct_prediction = tf.equal(prediction, ground_truth)
        confusion_matrix = tf.confusion_matrix(ground_truth,
                                               prediction,
                                               num_classes=num_classes)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        summaries.add(tf.summary.scalar('accuracy', accuracy))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        # optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        # for variable in slim.get_model_variables():
        #     summaries.add(tf.summary.histogram(variable.op.name, variable))

        total_loss, grads_and_vars = train_utils.optimize(optimizer)
        total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # # Modify the gradients for biases and last layer variables.
        # last_layers = train_utils.get_extra_layer_scopes(
        #     FLAGS.last_layers_contain_logits_only)
        # grad_mult = train_utils.get_model_gradient_multipliers(
        #     last_layers, FLAGS.last_layer_gradient_multiplier)
        # if grad_mult:
        #     grads_and_vars = slim.learning.multiply_gradients(
        #         grads_and_vars, grad_mult)

        grad_summ_op = tf.summary.merge([
            tf.summary.histogram("%s-grad" % g[1].name, g[0])
            for g in grads_and_vars
        ])

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # Add the summaries. These contain the summaries
        # created by model and either optimize() or _gather_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))
        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir, graph)
        validation_writer = tf.summary.FileWriter(
            FLAGS.summaries_dir + '/validation', graph)

        ################
        # Prepare data
        ################
        filenames = tf.placeholder(tf.string, shape=[])
        tr_dataset = data.Dataset(filenames, FLAGS.num_views, FLAGS.height,
                                  FLAGS.width, FLAGS.batch_size)
        iterator = tr_dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True))
        with tf.Session(config=sess_config) as sess:
            sess.run(tf.global_variables_initializer())

            # TODO:
            # Create a saver object which will save all the variables
            saver = tf.train.Saver(keep_checkpoint_every_n_hours=1.0)
            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            start_epoch = 0
            # Get the number of training/validation steps per epoch
            tr_batches = int(MODELNET_TRAIN_DATA_SIZE / FLAGS.batch_size)
            if MODELNET_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
                tr_batches += 1
            val_batches = int(MODELNET_VALIDATE_DATA_SIZE / FLAGS.batch_size)
            if MODELNET_VALIDATE_DATA_SIZE % FLAGS.batch_size > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can either be a string,
            # a list of strings, or a tf.Tensor of strings.
            training_filenames = os.path.join(FLAGS.dataset_dir,
                                              'train.record')
            validate_filenames = os.path.join(FLAGS.dataset_dir,
                                              'validate.record')
            ##################
            # Training loop.
            ##################
            for training_epoch in range(start_epoch,
                                        FLAGS.how_many_training_epochs):
                print("-------------------------------------")
                print(" Epoch {} ".format(training_epoch))
                print("-------------------------------------")

                sess.run(iterator.initializer,
                         feed_dict={filenames: training_filenames})
                for step in range(tr_batches):
                    # Pull the image batch we'll use for training.
                    train_batch_xs, train_batch_ys = sess.run(next_batch)

                    # # Verify image
                    # assert not np.any(np.isnan(train_batch_xs))
                    # n_batch = train_batch_xs.shape[0]
                    # n_view = train_batch_xs.shape[1]
                    # for i in range(n_batch):
                    #     for j in range(n_view):
                    #         img = train_batch_xs[i][j]
                    #         # scipy.misc.toimage(img).show()
                    #         # Or
                    #         img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
                    #         cv2.imwrite('/home/ace19/Pictures/' + str(i) +
                    #                     '_' + str(j) + '.png', img)
                    #         # cv2.imshow(str(train_batch_ys[idx]), img)
                    #         cv2.waitKey(100)
                    #         cv2.destroyAllWindows()

                    # Sets up a graph with feeds and fetches for partial run.
                    handle = sess.partial_run_setup([
                        d_scores, final_desc, learning_rate, summary_op,
                        accuracy, total_loss, grad_summ_op, train_op
                    ], [
                        X, final_X, ground_truth, grouping_scheme,
                        grouping_weight, is_training, is_training2,
                        dropout_keep_prob
                    ])
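                    # Stage 1 of the partial run computes the per-view
                    # discrimination scores; the grouping scheme and weights
                    # are then derived in Python and fed back for stage 2,
                    # which runs the GVCNN head and the training step.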

                    scores, final = sess.partial_run(handle,
                                                     [d_scores, final_desc],
                                                     feed_dict={
                                                         X: train_batch_xs,
                                                         is_training: True
                                                     })
                    schemes = gvcnn.grouping_scheme(scores, NUM_GROUP,
                                                    FLAGS.num_views)
                    weights = gvcnn.grouping_weight(scores, schemes)

                    # Run the graph with this batch of training data.
                    lr, train_summary, train_accuracy, train_loss, grad_vals, _ = \
                        sess.partial_run(handle,
                                         [learning_rate, summary_op, accuracy, total_loss, grad_summ_op, train_op],
                                         feed_dict={
                                             final_X: final,
                                             ground_truth: train_batch_ys,
                                             grouping_scheme: schemes,
                                             grouping_weight: weights,
                                             is_training2: True,
                                             dropout_keep_prob: 0.8}
                                         )

                    train_writer.add_summary(train_summary, training_epoch)
                    train_writer.add_summary(grad_vals, training_epoch)
                    tf.logging.info(
                        'Epoch #%d, Step #%d, rate %.10f, accuracy %.1f%%, loss %f'
                        % (training_epoch, step, lr, train_accuracy * 100,
                           train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                # tf.logging.info('--------------------------')
                # tf.logging.info(' Start validation ')
                # tf.logging.info('--------------------------')
                #
                # # Reinitialize iterator with the validation dataset
                # sess.run(iterator.initializer, feed_dict={filenames: validate_filenames})
                # total_val_accuracy = 0
                # validation_count = 0
                # total_conf_matrix = None
                #
                # for step in range(val_batches):
                #     validation_batch_xs, validation_batch_ys = sess.run(next_batch)
                #
                #     # Sets up a graph with feeds and fetches for partial run.
                #     handle = sess.partial_run_setup([d_scores, final_desc,
                #                                      summary_op, accuracy, confusion_matrix],
                #                                     [X, final_X, ground_truth, learning_rate,
                #                                      grouping_scheme, grouping_weight, is_training,
                #                                      is_training2, dropout_keep_prob])
                #
                #     scores, final = sess.partial_run(handle,
                #                                      [d_scores, final_desc],
                #                                      feed_dict={
                #                                          X: validation_batch_xs,
                #                                          is_training: False}
                #                                      )
                #     schemes = gvcnn.grouping_scheme(scores, NUM_GROUP, FLAGS.num_views)
                #     weights = gvcnn.grouping_weight(scores, schemes)
                #
                #     # Run the graph with this batch of validation data.
                #     val_summary, val_accuracy, conf_matrix = \
                #         sess.partial_run(handle,
                #                          [summary_op, accuracy, confusion_matrix],
                #                          feed_dict={
                #                              final_X: final,
                #                              ground_truth: validation_batch_ys,
                #                              grouping_scheme: schemes,
                #                              grouping_weight: weights,
                #                              is_training2: False,
                #                              dropout_keep_prob: 1.0}
                #                          )
                #
                #     validation_writer.add_summary(val_summary, training_epoch)
                #
                #     total_val_accuracy += val_accuracy
                #     validation_count += 1
                #     if total_conf_matrix is None:
                #         total_conf_matrix = conf_matrix
                #     else:
                #         total_conf_matrix += conf_matrix
                #
                #
                # total_val_accuracy /= validation_count
                # tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
                # tf.logging.info('Validation accuracy = %.1f%% (N=%d)' %
                #                 (total_val_accuracy * 100, MODELNET_VALIDATE_DATA_SIZE))

                # Save the model checkpoint periodically.
                if (training_epoch <= FLAGS.how_many_training_epochs - 1):
                    checkpoint_path = os.path.join(FLAGS.train_logdir,
                                                   FLAGS.ckpt_name_to_save)
                    tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                                    training_epoch)
                    saver.save(sess,
                               checkpoint_path,
                               global_step=training_epoch)
Example #3
def main(unused_argv):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    # tf.compat.v1.logging.info('Creating train logdir: %s', FLAGS.train_logdir)

    with tf.Graph().as_default() as graph:
        global_step = tf.compat.v1.train.get_or_create_global_step()

        X = tf.compat.v1.placeholder(
            tf.float32, [None, FLAGS.num_views, FLAGS.height, FLAGS.width, 3],
            name='X')
        ground_truth = tf.compat.v1.placeholder(tf.int64, [None],
                                                name='ground_truth')
        is_training = tf.compat.v1.placeholder(tf.bool, name='is_training')
        dropout_keep_prob = tf.compat.v1.placeholder(tf.float32,
                                                     name='dropout_keep_prob')
        # learning_rate = tf.placeholder(tf.float32, name='lr')

        # metric learning
        logits, features = \
            model.mvcnn_with_deep_cosine_metric_learning(X,
                                                         num_classes,
                                                         is_training=is_training,
                                                         keep_prob=dropout_keep_prob,
                                                         attention_module='se_block')
        # logits, features = mvcnn.mvcnn(X, num_classes)

        cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
            labels=ground_truth, logits=logits)
        tf.compat.v1.summary.scalar("cross_entropy_loss", cross_entropy)

        # Gather update ops. These contain, for example, the updates for the
        # batch_norm variables created by model.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        prediction = tf.argmax(logits, 1, name='prediction')
        correct_prediction = tf.equal(prediction, ground_truth)
        confusion_matrix = tf.math.confusion_matrix(ground_truth,
                                                    prediction,
                                                    num_classes=num_classes)
        # accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        # summaries.add(tf.summary.scalar('accuracy', accuracy))
        accuracy = slim.metrics.accuracy(tf.cast(prediction, tf.int64),
                                         ground_truth)
        tf.compat.v1.summary.scalar("accuracy", accuracy)

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(
                tf.compat.v1.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES):
            summaries.add(
                tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        # optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
        summaries.add(
            tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        total_loss, grads_and_vars = train_utils.optimize(optimizer)
        total_loss = tf.compat.v1.check_numerics(total_loss,
                                                 'Loss is inf or nan')
        summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

        # (Optional) histogram summaries for the gradients, for TensorBoard:
        # grad_summ_op = tf.compat.v1.summary.merge([tf.compat.v1.summary.histogram("%s-grad" % g[1].name, g[0]) for g in grads_and_vars])

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # Add the summaries. These contain the summaries created by model
        # and either optimize() or _gather_loss()
        summaries |= set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        # Merge all summaries together.
        summary_op = tf.compat.v1.summary.merge(list(summaries))
        train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir,
                                                       graph)
        validation_writer = tf.compat.v1.summary.FileWriter(
            FLAGS.summaries_dir + '/validation', graph)

        #####################
        # prepare data
        #####################
        tfrecord_names = tf.compat.v1.placeholder(tf.string, shape=[])
        _dataset = data.Dataset(tfrecord_names, FLAGS.num_views, FLAGS.height,
                                FLAGS.width, FLAGS.batch_size)
        iterator = _dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        sess_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
        with tf.compat.v1.Session(config=sess_config) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())

            saver = tf.compat.v1.train.Saver(keep_checkpoint_every_n_hours=1.0)
            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            start_epoch = 0
            training_batches = int(MODELNET10_TRAIN_DATA_SIZE /
                                   FLAGS.batch_size)
            if MODELNET10_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
                training_batches += 1
            val_batches = int(MODELNET10_VALIDATE_DATA_SIZE / FLAGS.batch_size)
            if MODELNET10_VALIDATE_DATA_SIZE % FLAGS.batch_size > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can either
            # be a string, a list of strings, or a tf.Tensor of strings.
            training_tf_filenames = os.path.join(FLAGS.dataset_dir,
                                                 'train.record')
            val_tf_filenames = os.path.join(FLAGS.dataset_dir,
                                            'validate.record')
            ##################
            # Training loop.
            ##################
            for n_epoch in range(start_epoch, FLAGS.how_many_training_epochs):
                tf.compat.v1.logging.info('--------------------------')
                tf.compat.v1.logging.info(' Epoch %d' % n_epoch)
                tf.compat.v1.logging.info('--------------------------')

                sess.run(iterator.initializer,
                         feed_dict={tfrecord_names: training_tf_filenames})
                for step in range(training_batches):
                    train_batch_xs, train_batch_ys = sess.run(next_batch)
                    # # Verify image
                    # assert not np.any(np.isnan(train_batch_xs))
                    # n_batch = train_batch_xs.shape[0]
                    # n_view = train_batch_xs.shape[1]
                    # for i in range(n_batch):
                    #     for j in range(n_view):
                    #         img = train_batch_xs[i][j]
                    #         # scipy.misc.toimage(img).show()
                    #         # Or
                    #         img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
                    #         cv2.imwrite('/home/ace19/Pictures/' + str(i) +
                    #                     '_' + str(j) + '.png', img)
                    #         # cv2.imshow(str(train_batch_ys[idx]), img)
                    #         cv2.waitKey(100)
                    #         cv2.destroyAllWindows()

                    lr, train_summary, train_accuracy, train_loss, _ = \
                        sess.run([learning_rate, summary_op, accuracy, total_loss, train_op],
                                 feed_dict={X: train_batch_xs,
                                            ground_truth: train_batch_ys,
                                            is_training: True,
                                            dropout_keep_prob: 0.8})

                    # lr, train_summary, train_accuracy, train_loss, grad_vals, _ = \
                    #     sess.run([learning_rate, summary_op, accuracy, total_loss, grad_summ_op, train_op],
                    #     feed_dict={X: train_batch_xs,
                    #                ground_truth: train_batch_ys,
                    #                is_training: True,
                    #                dropout_keep_prob: 0.8})

                    train_writer.add_summary(train_summary, n_epoch)
                    # train_writer.add_summary(grad_vals, n_epoch)
                    tf.compat.v1.logging.info(
                        'Epoch #%d, Step #%d, rate %.10f, accuracy %.1f%%, loss %f'
                        %
                        (n_epoch, step, lr, train_accuracy * 100, train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                tf.compat.v1.logging.info('--------------------------')
                tf.compat.v1.logging.info(' Start validation ')
                tf.compat.v1.logging.info('--------------------------')

                # Reinitialize iterator with the validation dataset
                sess.run(iterator.initializer,
                         feed_dict={tfrecord_names: val_tf_filenames})

                total_val_accuracy = 0
                validation_count = 0
                total_conf_matrix = None
                for step in range(val_batches):
                    validation_batch_xs, validation_batch_ys = sess.run(
                        next_batch)

                    val_summary, val_accuracy, conf_matrix = \
                        sess.run([summary_op, accuracy, confusion_matrix],
                                 feed_dict={X: validation_batch_xs,
                                            ground_truth: validation_batch_ys,
                                            is_training: False,
                                            dropout_keep_prob: 1.0})

                    validation_writer.add_summary(val_summary, n_epoch)

                    total_val_accuracy += val_accuracy
                    validation_count += 1
                    if total_conf_matrix is None:
                        total_conf_matrix = conf_matrix
                    else:
                        total_conf_matrix += conf_matrix

                total_val_accuracy /= validation_count
                tf.compat.v1.logging.info('Confusion Matrix:\n %s' %
                                          (total_conf_matrix))
                tf.compat.v1.logging.info(
                    'Validation accuracy = %.1f%% (N=%d)' %
                    (total_val_accuracy * 100, MODELNET10_VALIDATE_DATA_SIZE))

                # Save the model checkpoint periodically.
                if (n_epoch <= FLAGS.how_many_training_epochs - 1):
                    checkpoint_path = os.path.join(FLAGS.train_logdir,
                                                   FLAGS.ckpt_name_to_save)
                    tf.compat.v1.logging.info('Saving to "%s-%d"',
                                              checkpoint_path, n_epoch)
                    saver.save(sess, checkpoint_path, global_step=n_epoch)
Example #4
def main(unused_argv):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    with tf.Graph().as_default() as graph:
        global_step = tf.compat.v1.train.get_or_create_global_step()

        # Define the model
        X = tf.compat.v1.placeholder(
            tf.float32, [None, FLAGS.num_views, FLAGS.height, FLAGS.width, 3],
            name='X')
        ground_truth = tf.compat.v1.placeholder(tf.int64, [None],
                                                name='ground_truth')
        is_training = tf.compat.v1.placeholder(tf.bool, name='is_training')
        dropout_keep_prob = tf.compat.v1.placeholder(tf.float32,
                                                     name='dropout_keep_prob')
        g_scheme = tf.compat.v1.placeholder(tf.int32,
                                            [FLAGS.num_group, FLAGS.num_views])
        g_weight = tf.compat.v1.placeholder(tf.float32, [FLAGS.num_group])

        # GVCNN
        view_scores, _, logits = model.gvcnn(X, num_classes, g_scheme,
                                             g_weight, is_training,
                                             dropout_keep_prob)

        # # basic - for verification
        # _, logits = model.basic(X,
        #                         num_classes,
        #                         is_training,
        #                         dropout_keep_prob)

        # Define loss
        _loss = tf.losses.sparse_softmax_cross_entropy(labels=ground_truth,
                                                       logits=logits)

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        prediction = tf.argmax(logits, 1, name='prediction')
        correct_prediction = tf.equal(prediction, ground_truth)
        confusion_matrix = tf.math.confusion_matrix(ground_truth,
                                                    prediction,
                                                    num_classes=num_classes)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        summaries.add(tf.compat.v1.summary.scalar('accuracy', accuracy))

        # # Add summaries for model variables.
        # for model_var in slim.get_model_variables():
        #     summaries.add(tf.compat.v1.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES):
            summaries.add(
                tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        optimizer = tf.compat.v1.train.MomentumOptimizer(
            learning_rate, FLAGS.momentum)
        summaries.add(
            tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        total_loss, grads_and_vars = train_utils.optimize(optimizer)
        total_loss = tf.debugging.check_numerics(total_loss,
                                                 'Loss is inf or nan.')
        summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

        # Gather update_ops.
        # These contain, for example, the updates for the batch_norm variables created by model.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)

        # Create gradient update op.
        update_ops.append(
            optimizer.apply_gradients(grads_and_vars, global_step=global_step))
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(total_loss, name='train_op')

        ################
        # Prepare data
        ################
        filenames = tf.compat.v1.placeholder(tf.string, shape=[])
        tr_dataset = train_data.Dataset(filenames, FLAGS.num_views,
                                        FLAGS.height, FLAGS.width,
                                        FLAGS.batch_size)
        iterator = tr_dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        # validation dataset
        val_dataset = val_data.Dataset(filenames, FLAGS.num_views,
                                       FLAGS.height, FLAGS.width,
                                       FLAGS.val_batch_size)  # val_batch_size
        val_iterator = val_dataset.dataset.make_initializable_iterator()
        val_next_batch = val_iterator.get_next()

        sess_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
        with tf.compat.v1.Session(config=sess_config) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())

            # Add the summaries. These contain the summaries
            # created by model and either optimize() or _gather_loss().
            summaries |= set(
                tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

            # Merge all summaries together.
            summary_op = tf.compat.v1.summary.merge(list(summaries))
            train_writer = tf.compat.v1.summary.FileWriter(
                FLAGS.summaries_dir, graph)
            validation_writer = tf.compat.v1.summary.FileWriter(
                FLAGS.summaries_dir + '/validation', graph)

            # Create a saver object which will save all the variables
            saver = tf.compat.v1.train.Saver(keep_checkpoint_every_n_hours=1.0)
            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            if FLAGS.saved_checkpoint_dir:
                if tf.gfile.IsDirectory(FLAGS.saved_checkpoint_dir):
                    checkpoint_path = tf.train.latest_checkpoint(
                        FLAGS.saved_checkpoint_dir)
                else:
                    checkpoint_path = FLAGS.saved_checkpoint_dir
                saver.restore(sess, checkpoint_path)

            start_epoch = 0
            # Get the number of training/validation steps per epoch
            tr_batches = int(MODELNET_TRAIN_DATA_SIZE / FLAGS.batch_size)
            if MODELNET_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
                tr_batches += 1
            val_batches = int(MODELNET_VALIDATE_DATA_SIZE /
                              FLAGS.val_batch_size)
            if MODELNET_VALIDATE_DATA_SIZE % FLAGS.val_batch_size > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can either be a string,
            # a list of strings, or a tf.Tensor of strings.
            training_filenames = os.path.join(FLAGS.dataset_dir,
                                              'modelnet5_6view_train.record')
            validate_filenames = os.path.join(FLAGS.dataset_dir,
                                              'modelnet5_6view_test.record')

            ###################################
            # Training loop.
            ###################################
            for num_epoch in range(start_epoch,
                                   FLAGS.how_many_training_epochs):
                print("-------------------------------------")
                print(" Epoch {} ".format(num_epoch))
                print("-------------------------------------")

                sess.run(iterator.initializer,
                         feed_dict={filenames: training_filenames})
                for step in range(tr_batches):
                    # Pull the image batch we'll use for training.
                    train_batch_xs, train_batch_ys = sess.run(next_batch)

                    # Sets up a graph with feeds and fetches for partial run.
                    handle = sess.partial_run_setup(
                        [
                            view_scores,
                            learning_rate,
                            # summary_op, top1_acc, loss, optimize_op, dummy],
                            summary_op,
                            accuracy,
                            _loss,
                            train_op
                        ],
                        [
                            X, ground_truth, g_scheme, g_weight, is_training,
                            dropout_keep_prob
                        ])

                    _view_scores = sess.partial_run(handle, [view_scores],
                                                    feed_dict={
                                                        X: train_batch_xs,
                                                        is_training: True,
                                                        dropout_keep_prob: 0.8
                                                    })
                    _g_schemes = model.group_scheme(_view_scores,
                                                    FLAGS.num_group,
                                                    FLAGS.num_views)
                    _g_weights = model.group_weight(_g_schemes)
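                    # The view scores from stage 1 are converted into a group
                    # scheme and per-group weights in NumPy, then fed back to
                    # complete the forward/backward pass in stage 2 below.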

                    # Run the graph with this batch of training data.
                    lr, train_summary, train_accuracy, train_loss, _ = \
                        sess.partial_run(handle,
                                         [learning_rate, summary_op, accuracy, _loss, train_op],
                                         feed_dict={
                                             ground_truth: train_batch_ys,
                                             g_scheme: _g_schemes,
                                             g_weight: _g_weights}
                                         )

                    # for verification
                    # lr, train_summary, train_accuracy, train_loss, _ = \
                    #     sess.run([learning_rate, summary_op, accuracy, _loss, train_op],
                    #              feed_dict={
                    #                  X: train_batch_xs,
                    #                  ground_truth: train_batch_ys,
                    #                  is_training: True,
                    #                  dropout_keep_prob: 0.8}
                    #              )

                    train_writer.add_summary(train_summary, num_epoch)
                    tf.compat.v1.logging.info(
                        'Epoch #%d, Step #%d, rate %.6f, top1_acc %.3f%%, loss %.5f'
                        % (num_epoch, step, lr, train_accuracy * 100,
                           train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                tf.compat.v1.logging.info('--------------------------')
                tf.compat.v1.logging.info(' Start validation ')
                tf.compat.v1.logging.info('--------------------------')

                total_val_losses = 0.0
                total_val_top1_acc = 0.0
                val_count = 0
                total_conf_matrix = None

                # Reinitialize val_iterator with the validation dataset
                sess.run(val_iterator.initializer,
                         feed_dict={filenames: validate_filenames})
                for step in range(val_batches):
                    validation_batch_xs, validation_batch_ys = sess.run(
                        val_next_batch)

                    # Sets up a graph with feeds and fetches for partial run.
                    handle = sess.partial_run_setup([
                        view_scores, summary_op, accuracy, _loss,
                        confusion_matrix
                    ], [
                        X, g_scheme, g_weight, ground_truth, is_training,
                        dropout_keep_prob
                    ])

                    _view_scores = sess.partial_run(handle, [view_scores],
                                                    feed_dict={
                                                        X: validation_batch_xs,
                                                        is_training: False,
                                                        dropout_keep_prob: 1.0
                                                    })
                    _g_schemes = model.group_scheme(_view_scores,
                                                    FLAGS.num_group,
                                                    FLAGS.num_views)
                    _g_weights = model.group_weight(_g_schemes)

                    # Run the graph with this batch of validation data.
                    val_summary, val_accuracy, val_loss, conf_matrix = \
                        sess.partial_run(handle,
                                         [summary_op, accuracy, _loss, confusion_matrix],
                                         feed_dict={
                                             ground_truth: validation_batch_ys,
                                             g_scheme: _g_schemes,
                                             g_weight: _g_weights}
                                         )

                    # for verification
                    # val_summary, val_accuracy, val_loss, conf_matrix = \
                    #     sess.run([summary_op, accuracy, _loss, confusion_matrix],
                    #              feed_dict={
                    #                  X: validation_batch_xs,
                    #                  ground_truth: validation_batch_ys,
                    #                  is_training: False,
                    #                  dropout_keep_prob: 1.0}
                    #              )

                    validation_writer.add_summary(val_summary, num_epoch)

                    total_val_losses += val_loss
                    total_val_top1_acc += val_accuracy
                    val_count += 1
                    if total_conf_matrix is None:
                        total_conf_matrix = conf_matrix
                    else:
                        total_conf_matrix += conf_matrix

                total_val_losses /= val_count
                total_val_top1_acc /= val_count

                tf.compat.v1.logging.info('Confusion Matrix:\n %s' %
                                          total_conf_matrix)
                tf.compat.v1.logging.info('Validation loss = %.5f' %
                                          total_val_losses)
                tf.compat.v1.logging.info(
                    'Validation accuracy = %.3f%% (N=%d)' %
                    (total_val_top1_acc * 100, MODELNET_VALIDATE_DATA_SIZE))

                # Save the model checkpoint periodically.
                if (num_epoch <= FLAGS.how_many_training_epochs - 1):
                    checkpoint_path = os.path.join(FLAGS.train_logdir,
                                                   FLAGS.ckpt_name_to_save)
                    tf.compat.v1.logging.info('Saving to "%s-%d"',
                                              checkpoint_path, num_epoch)
                    saver.save(sess, checkpoint_path, global_step=num_epoch)
Example #5
def main(unused_argv):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    with tf.Graph().as_default() as graph:
        global_step = tf.compat.v1.train.get_or_create_global_step()

        X = tf.compat.v1.placeholder(tf.float32,
                                     [None, FLAGS.height, FLAGS.width, 3],
                                     name='X')
        ground_truth = tf.compat.v1.placeholder(tf.int64, [None],
                                                name='ground_truth')
        is_training = tf.compat.v1.placeholder(tf.bool, name='is_training')
        keep_prob = tf.compat.v1.placeholder(tf.float32, [], name='keep_prob')
        tfrecord_filenames = tf.compat.v1.placeholder(tf.string, shape=[])
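        # A single filename placeholder is fed at iterator-initialization time,
        # so the same graph can switch between the train and validation records.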

        # # Print name and shape of each tensor.
        # tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
        # tf.compat.v1.logging.info("Layers")
        # tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
        # for k, v in end_points.items():
        #     tf.compat.v1.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))
        #

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))
        # # Add summaries for model variables.
        # for variable in slim.get_model_variables():
        #     summaries.add(tf.compat.v1.summary.histogram(variable.op.name, variable))
        #
        # # Add summaries for losses.
        # for loss in tf.compat.v1.get_collection(tf.GraphKeys.LOSSES):
        #     summaries.add(tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        summaries.add(
            tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        # optimizers = \
        #     [tf.train.RMSPropOptimizer(learning_rate, decay=0.9, momentum=0.9) for _ in range(FLAGS.num_gpu)]
        # optimizers = \
        #     [tf.compat.v1.train.MomentumOptimizer(learning_rate, FLAGS.momentum) for _ in range(FLAGS.num_gpu)]
        optimizers = \
            [tf.compat.v1.train.GradientDescentOptimizer(learning_rate) for _ in range(FLAGS.num_gpu)]
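        # One optimizer instance per GPU tower; each tower computes its own
        # gradients, which are all-reduced across towers further below.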

        logits = []
        losses = []
        grad_list = []
        filename_batch = []
        image_batch = []
        gt_batch = []
        for gpu_idx in range(FLAGS.num_gpu):
            tf.compat.v1.logging.info('creating gpu tower @ %d' %
                                      (gpu_idx + 1))
            image_batch.append(X)
            gt_batch.append(ground_truth)

            scope_name = 'tower%d' % gpu_idx
            with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_idx)), \
                 tf.compat.v1.variable_scope(scope_name):
                # apply SENet
                _, logit = model.deep_cosine_softmax(
                    X,
                    num_classes=num_classes,
                    is_training=is_training,
                    is_reuse=False,
                    keep_prob=keep_prob,
                    attention_module='se_block')

                # Print name and shape of parameter nodes (values not yet initialized)
                tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
                tf.compat.v1.logging.info("Parameters")
                tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
                for v in slim.get_model_variables():
                    tf.compat.v1.logging.info('name = %s, shape = %s' %
                                              (v.name, v.get_shape()))

                # # TTA
                # logit = tf.cond(is_training,
                #                 lambda: tf.identity(logit),
                #                 lambda: tf.reduce_mean(tf.reshape(logit, [FLAGS.val_batch_size // FLAGS.num_gpu, TEN_CROP, -1]), axis=1))
                logits.append(logit)

                l = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=ground_truth, logits=logit)
                losses.append(l)
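                # Add this tower's regularization losses (scoped to the tower)
                # to the summed cross-entropy before computing gradients.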
                loss_w_reg = tf.reduce_sum(l) + tf.add_n(
                    slim.losses.get_regularization_losses(scope=scope_name))

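                # Keep only gradients that are actually defined for this
                # tower's variables (compute_gradients returns None otherwise).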
                grad_list.append([
                    x
                    for x in optimizers[gpu_idx].compute_gradients(loss_w_reg)
                    if x[0] is not None
                ])

        y_hat = tf.concat(logits, axis=0)
        image_batch = tf.concat(image_batch, axis=0)
        gt_batch = tf.concat(gt_batch, axis=0)

        # Acc
        top1_acc = tf.reduce_mean(
            tf.cast(tf.nn.in_top_k(y_hat, gt_batch, k=1), dtype=tf.float32))
        summaries.add(tf.compat.v1.summary.scalar('top1_acc', top1_acc))
        # top5_acc = tf.reduce_mean(
        #     tf.cast(tf.nn.in_top_k(y_hat, gt_batch, k=5), dtype=tf.float32)
        # )
        # summaries.add(tf.compat.v1.summary.scalar('top5_acc', top5_acc))
        prediction = tf.argmax(y_hat, axis=1, name='prediction')
        confusion_matrix = tf.math.confusion_matrix(gt_batch,
                                                    prediction,
                                                    num_classes=num_classes)
        confusion_matrix = tf.div(confusion_matrix, FLAGS.num_gpu)
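        # Every tower is fed the same batch via X, so prediction/label pairs
        # are duplicated num_gpu times; the tf.div above apparently divides the
        # counts back down to per-batch values.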

        loss = tf.reduce_mean(losses)
        loss = tf.compat.v1.check_numerics(loss, 'Loss is inf or nan.')
        summaries.add(tf.compat.v1.summary.scalar('loss', loss))

        # All-reduce the per-tower gradients with NCCL (averaged across
        # towers), then re-pair the reduced gradients with each tower's
        # variables.
        grads, all_vars = train_helper.split_grad_list(grad_list)
        reduced_grad = train_helper.allreduce_grads(grads, average=True)
        grads = train_helper.merge_grad_list(reduced_grad, all_vars)

        # optimizer using NCCL
        train_ops = []
        for idx, grad_and_vars in enumerate(grads):
            # apply_gradients may create variables. Make them LOCAL_VARIABLES.
            with tf.name_scope('apply_gradients'), tf.device(
                    tf.DeviceSpec(device_type="GPU", device_index=idx)):
                update_ops = tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.UPDATE_OPS, scope='tower%d' % idx)
                with tf.control_dependencies(update_ops):
                    train_ops.append(optimizers[idx].apply_gradients(
                        grad_and_vars,
                        name='apply_grad_{}'.format(idx),
                        global_step=global_step))
                # TODO:
                # TensorBoard: How to plot histogram for gradients
                # grad_summ_op = tf.summary.merge([tf.summary.histogram("%s-grad" % g[1].name, g[0]) for g in grads_and_vars])

        optimize_op = tf.group(*train_ops, name='train_op')

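        # sync_op presumably copies tower 0's variables to the other towers so
        # all replicas start from identical weights
        # (see train_helper.get_post_init_ops).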
        sync_op = train_helper.get_post_init_ops()

        # Create a saver object which will save all the variables
        saver = tf.compat.v1.train.Saver()
        best_ckpt_saver = BestCheckpointSaver(save_dir=FLAGS.train_logdir,
                                              num_to_keep=100,
                                              maximize=False,
                                              saver=saver)
        best_val_loss = 99999
        best_val_acc = 0

        start_epoch = 0
        epoch_count = tf.Variable(start_epoch, trainable=False)
        epoch_count_add = tf.compat.v1.assign(epoch_count, epoch_count + 1)
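        # Non-trainable epoch counter; passed to best_ckpt_saver.handle() below
        # so the best checkpoints are indexed by epoch rather than global step.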

        ###############
        # Prepare data
        ###############
        # training dataset
        tr_dataset = train_data.Dataset(tfrecord_filenames,
                                        FLAGS.batch_size // FLAGS.num_gpu,
                                        num_classes,
                                        FLAGS.how_many_training_epochs,
                                        TRAIN_DATA_SIZE, FLAGS.height,
                                        FLAGS.width)
        iterator = tr_dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()
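        # The initializable iterators are re-initialized each epoch with the
        # corresponding TFRecord filename fed through tfrecord_filenames.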

        # validation dataset
        val_dataset = val_data.Dataset(tfrecord_filenames,
                                       FLAGS.val_batch_size // FLAGS.num_gpu,
                                       num_classes,
                                       FLAGS.how_many_training_epochs,
                                       VALIDATE_DATA_SIZE, FLAGS.height,
                                       FLAGS.width)
        # 256,  # 256 ~ 480
        # 256)
        val_iterator = val_dataset.dataset.make_initializable_iterator()
        val_next_batch = val_iterator.get_next()

        sess_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
        with tf.compat.v1.Session(config=sess_config) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())

            # Add the summaries. These contain the summaries
            # created by model and either optimize() or _gather_loss().
            summaries |= set(
                tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

            # Merge all summaries together.
            summary_op = tf.compat.v1.summary.merge(list(summaries))
            train_writer = tf.compat.v1.summary.FileWriter(
                FLAGS.summaries_dir, graph)
            validation_writer = tf.compat.v1.summary.FileWriter(
                FLAGS.summaries_dir + '/validation', graph)

            # TODO: supports multi gpu -> add scope ('tower%d' % gpu_idx)
            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            if FLAGS.saved_checkpoint_dir:
                if tf.compat.v1.gfile.IsDirectory(FLAGS.saved_checkpoint_dir):
                    checkpoint_path = tf.train.latest_checkpoint(
                        FLAGS.saved_checkpoint_dir)
                else:
                    checkpoint_path = FLAGS.saved_checkpoint_dir
                saver.restore(sess, checkpoint_path)

            # global_step = checkpoint_path.split('/')[-1].split('-')[-1]

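            # Synchronize tower variables (e.g. after restoring a checkpoint
            # into tower 0) before the training loop starts.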
            sess.run(sync_op)

            # Get the number of training/validation steps per epoch
            tr_batches = int(TRAIN_DATA_SIZE /
                             (FLAGS.batch_size // FLAGS.num_gpu))
            if TRAIN_DATA_SIZE % (FLAGS.batch_size // FLAGS.num_gpu) > 0:
                tr_batches += 1
            val_batches = int(VALIDATE_DATA_SIZE /
                              (FLAGS.val_batch_size // FLAGS.num_gpu))
            if VALIDATE_DATA_SIZE % (FLAGS.val_batch_size //
                                     FLAGS.num_gpu) > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can either be a string,
            # a list of strings, or a tf.Tensor of strings.
            train_record_filenames = os.path.join(FLAGS.dataset_dir,
                                                  'train.record')
            validate_record_filenames = os.path.join(FLAGS.dataset_dir,
                                                     'validate.record')

            ############################
            # Training loop.
            ############################
            for num_epoch in range(start_epoch,
                                   FLAGS.how_many_training_epochs):
                print("------------------------------------")
                print(" Epoch {} ".format(num_epoch))
                print("------------------------------------")

                sess.run(epoch_count_add)
                sess.run(
                    iterator.initializer,
                    feed_dict={tfrecord_filenames: train_record_filenames})
                for step in range(tr_batches):
                    filenames, train_batch_xs, train_batch_ys = sess.run(
                        next_batch)
                    # show_batch_data(filenames, train_batch_xs, train_batch_ys)
                    #
                    # augmented_batch_xs = aug_utils.aug(train_batch_xs)
                    # show_batch_data(filenames, augmented_batch_xs,
                    #                 train_batch_ys, 'aug')

                    # Run the graph with this batch of training data and learning rate policy.
                    lr, train_summary, train_top1_acc, train_loss, _ = \
                        sess.run([learning_rate, summary_op, top1_acc, loss, optimize_op],
                                 feed_dict={
                                     X: train_batch_xs,
                                     ground_truth: train_batch_ys,
                                     is_training: True,
                                     keep_prob: 0.8
                                 })
                    train_writer.add_summary(train_summary, num_epoch)
                    # train_writer.add_summary(grad_vals, num_epoch)
                    tf.compat.v1.logging.info(
                        'Epoch #%d, Step #%d, rate %.6f, top1_acc %.3f%%, loss %.5f'
                        % (num_epoch, step, lr, train_top1_acc, train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                tf.compat.v1.logging.info('--------------------------')
                tf.compat.v1.logging.info(' Start validation ')
                tf.compat.v1.logging.info('--------------------------')

                total_val_losses = 0.0
                total_val_top1_acc = 0.0
                val_count = 0
                total_conf_matrix = None

                sess.run(
                    val_iterator.initializer,
                    feed_dict={tfrecord_filenames: validate_record_filenames})
                for step in range(val_batches):
                    filenames, validation_batch_xs, validation_batch_ys = sess.run(
                        val_next_batch)
                    # # TTA
                    # batch_size, n_crops, c, h, w = validation_batch_xs.shape
                    # # fuse batch size and ncrops
                    # tencrop_val_batch_xs = np.reshape(validation_batch_xs, (-1, c, h, w))
                    # show_batch_data(filenames, tencrop_val_batch_xs, validation_batch_ys)

                    # augmented_val_batch_xs = aug_utils.aug(tencrop_val_batch_xs)
                    # show_batch_data(filenames, augmented_val_batch_xs,
                    #                 validation_batch_ys, 'aug')

                    val_summary, val_loss, val_top1_acc, _confusion_matrix = sess.run(
                        [summary_op, loss, top1_acc, confusion_matrix],
                        feed_dict={
                            X: validation_batch_xs,
                            ground_truth: validation_batch_ys,
                            is_training: False,
                            keep_prob: 1.0
                        })
                    validation_writer.add_summary(val_summary, num_epoch)

                    total_val_losses += val_loss
                    total_val_top1_acc += val_top1_acc

                    # total_val_accuracy += val_top1_acc
                    val_count += 1
                    if total_conf_matrix is None:
                        total_conf_matrix = _confusion_matrix
                    else:
                        total_conf_matrix += _confusion_matrix

                total_val_losses /= val_count
                total_val_top1_acc /= val_count

                # total_val_accuracy /= val_count
                tf.compat.v1.logging.info('Confusion Matrix:\n %s' %
                                          total_conf_matrix)
                tf.compat.v1.logging.info('Validation loss = %.5f' %
                                          total_val_losses)
                tf.compat.v1.logging.info(
                    'Validation top1 accuracy = %.3f%% (N=%d)' %
                    (total_val_top1_acc, VALIDATE_DATA_SIZE))

                # periodic synchronization
                sess.run(sync_op)

                # Save the model checkpoint periodically.
                if (num_epoch <= FLAGS.how_many_training_epochs - 1):
                    # best_checkpoint_path = os.path.join(FLAGS.train_logdir, 'best_' + FLAGS.ckpt_name_to_save)
                    # tf.compat.v1.logging.info('Saving to "%s"', best_checkpoint_path)
                    # saver.save(sess, best_checkpoint_path, global_step=num_epoch)

                    # save & keep best model wrt. validation loss
                    best_ckpt_saver.handle(total_val_losses, sess, epoch_count)

                    if best_val_loss > total_val_losses:
                        best_val_loss = total_val_losses
                        best_val_acc = total_val_top1_acc

            chk_path = get_best_checkpoint(FLAGS.train_logdir,
                                           select_maximum_value=False)
            tf.compat.v1.logging.info(
                'training done. best_model val_loss=%.5f, top1_acc=%.3f%%, ckpt=%s'
                % (best_val_loss, best_val_acc, chk_path))