Example #1
def _train_deeplab_model(iterator, num_of_classes, ignore_label):
    """Trains the deeplab model.

  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.

  Returns:
    train_tensor: A tensor to update the model variables.
    summary_op: An operation to log the summaries.
  """
    global_step = tf.train.get_or_create_global_step()

    learning_rate = train_utils.get_model_learning_rate(
        FLAGS.learning_policy, FLAGS.base_learning_rate,
        FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
        FLAGS.training_number_of_steps, FLAGS.learning_power,
        FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

    tower_losses = []
    tower_grads = []
    for i in range(FLAGS.num_clones):
        with tf.device('/gpu:%d' % i):
            # First tower has default name scope.
            name_scope = ('clone_%d' % i) if i else ''
            with tf.name_scope(name_scope) as scope:
                loss = _tower_loss(iterator=iterator,
                                   num_of_classes=num_of_classes,
                                   ignore_label=ignore_label,
                                   scope=scope,
                                   reuse_variable=(i != 0))
                tower_losses.append(loss)

    if FLAGS.quantize_delay_step >= 0:
        if FLAGS.num_clones > 1:
            raise ValueError('Quantization doesn\'t support multi-clone yet.')
        tf.contrib.quantize.create_training_graph(
            quant_delay=FLAGS.quantize_delay_step)

    for i in range(FLAGS.num_clones):
        with tf.device('/gpu:%d' % i):
            name_scope = ('clone_%d' % i) if i else ''
            with tf.name_scope(name_scope) as scope:
                grads = optimizer.compute_gradients(tower_losses[i])
                tower_grads.append(grads)

    with tf.device('/cpu:0'):
        grads_and_vars = _average_gradients(tower_grads)

        # Modify the gradients for biases and last layer variables.
        last_layers = model.get_extra_layer_scopes(
            FLAGS.last_layers_contain_logits_only)
        grad_mult = train_utils.get_model_gradient_multipliers(
            last_layers, FLAGS.last_layer_gradient_multiplier)
        if grad_mult:
            grads_and_vars = tf.contrib.training.multiply_gradients(
                grads_and_vars, grad_mult)

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)

        # Gather update_ops. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)

        total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

        # Print total loss to the terminal.
        # This implementation is mirrored from tf.slim.summaries.
        should_log = math_ops.equal(math_ops.mod(global_step, FLAGS.log_steps),
                                    0)
        total_loss = tf.cond(
            should_log,
            lambda: tf.Print(total_loss, [total_loss], 'Total loss is: '),
            lambda: total_loss)

        tf.summary.scalar('total_loss', total_loss)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # Excludes summaries from towers other than the first one.
        summary_op = tf.summary.merge_all(scope='(?!clone_)')

    return train_tensor, summary_op
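
Example #1 calls an _average_gradients helper that is not shown here. Below is a minimal sketch of such a helper, assuming the standard multi-tower layout where tower_grads holds one (gradient, variable) list per clone and all clones share the same variables; it is an illustration, not the author's actual implementation.

import tensorflow as tf  # TF 1.x API assumed, as in the example above

def _average_gradients(tower_grads):
    """Averages the gradient of each variable across all towers."""
    averaged_grads_and_vars = []
    for grads_and_vars in zip(*tower_grads):
        # grads_and_vars is ((grad_gpu0, var), (grad_gpu1, var), ...).
        # Assumes dense gradients; IndexedSlices would need special handling.
        grads = [tf.expand_dims(g, 0) for g, _ in grads_and_vars if g is not None]
        mean_grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # The variable is shared, so take it from the first tower's entry.
        averaged_grads_and_vars.append((mean_grad, grads_and_vars[0][1]))
    return averaged_grads_and_vars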
Example #2
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Creating train logdir: %s', FLAGS.train_logdir)

    with tf.Graph().as_default() as graph:
        global_step = tf.train.get_or_create_global_step()

        X = tf.placeholder(tf.float32, [None, FLAGS.height, FLAGS.width, 3],
                           name='X')
        ground_truth = tf.placeholder(tf.int64, [None], name='ground_truth')
        is_training = tf.placeholder(tf.bool, name='is_training')
        keep_prob = tf.placeholder(tf.float32, [], name='keep_prob')
        # learning_rate = tf.placeholder(tf.float32, [])

        # apply SENet
        logits, end_points = model.hcd_model(X,
                                             num_classes=num_classes,
                                             is_training=is_training,
                                             keep_prob=keep_prob,
                                             attention_module='se_block')

        logits = tf.cond(
            is_training,
            lambda: tf.identity(logits),
            lambda: tf.reduce_mean(
                tf.reshape(logits, [FLAGS.val_batch_size, TEN_CROP, -1]),
                axis=1))

        # Print name and shape of each tensor.
        tf.logging.info("++++++++++++++++++++++++++++++++++")
        tf.logging.info("Layers")
        tf.logging.info("++++++++++++++++++++++++++++++++++")
        for k, v in end_points.items():
            tf.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))

        # # Print name and shape of parameter nodes  (values not yet initialized)
        # tf.logging.info("++++++++++++++++++++++++++++++++++")
        # tf.logging.info("Parameters")
        # tf.logging.info("++++++++++++++++++++++++++++++++++")
        # for v in slim.get_model_variables():
        #     tf.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        prediction = tf.argmax(logits, axis=1, name='prediction')
        correct_prediction = tf.equal(prediction, ground_truth)
        confusion_matrix = tf.confusion_matrix(ground_truth,
                                               prediction,
                                               num_classes=num_classes)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32),
                                  name='accuracy')
        summaries.add(tf.summary.scalar('accuracy', accuracy))

        # Define loss
        tf.losses.sparse_softmax_cross_entropy(labels=ground_truth,
                                               logits=logits)

        # Gather update_ops. These contain, for example,
        # the updates for the batch_norm variables created by model.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # # Add summaries for model variables.
        # for model_var in slim.get_model_variables():
        #     summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        # optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        total_loss, grads_and_vars = train_utils.optimize(optimizer)
        total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # # Modify the gradients for biases and last layer variables.
        # last_layers = train_utils.get_extra_layer_scopes(
        #     FLAGS.last_layers_contain_logits_only)
        # grad_mult = train_utils.get_model_gradient_multipliers(
        #     last_layers, FLAGS.last_layer_gradient_multiplier)
        # if grad_mult:
        #     grads_and_vars = slim.learning.multiply_gradients(
        #         grads_and_vars, grad_mult)

        # Gradient clipping
        # clipped_gvs = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in grads_and_vars]
        # Otherwise ->
        # gradients, variables = zip(*optimizer.compute_gradients(loss))
        # gradients, _ = tf.clip_by_global_norm(grads_and_vars[0], 5.0)
        # optimize = optimizer.apply_gradients(zip(gradients, grads_and_vars[1]))

        # TensorBoard: How to plot histogram for gradients
        grad_summ_op = tf.summary.merge([
            tf.summary.histogram("%s-grad" % g[1].name, g[0])
            for g in grads_and_vars
        ])

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # Add the summaries. These contain the summaries
        # created by model and either optimize() or _gather_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))
        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir, graph)
        validation_writer = tf.summary.FileWriter(
            FLAGS.summaries_dir + '/validation', graph)

        ###############
        # Prepare data
        ###############
        # training dataset
        tfrecord_filenames = tf.placeholder(tf.string, shape=[])
        dataset = train_data.Dataset(tfrecord_filenames, FLAGS.batch_size,
                                     FLAGS.how_many_training_epochs,
                                     FLAGS.height, FLAGS.width)
        iterator = dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        # validation dataset
        val_dataset = val_data.Dataset(tfrecord_filenames,
                                       FLAGS.val_batch_size, FLAGS.height,
                                       FLAGS.width)
        val_iterator = val_dataset.dataset.make_initializable_iterator()
        val_next_batch = val_iterator.get_next()

        sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True))
        with tf.Session(config=sess_config) as sess:
            sess.run(tf.global_variables_initializer())

            # Create a saver object which will save all the variables
            saver = tf.train.Saver()
            if FLAGS.saved_checkpoint_dir:
                if tf.gfile.IsDirectory(FLAGS.saved_checkpoint_dir):
                    checkpoint_path = tf.train.latest_checkpoint(
                        FLAGS.saved_checkpoint_dir)
                else:
                    checkpoint_path = FLAGS.saved_checkpoint_dir
                saver.restore(sess, checkpoint_path)

            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            start_epoch = 0
            # Get the number of training/validation steps per epoch
            tr_batches = int(PCAM_TRAIN_DATA_SIZE / FLAGS.batch_size)
            if PCAM_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
                tr_batches += 1
            val_batches = int(PCAM_VALIDATE_DATA_SIZE / FLAGS.val_batch_size)
            if PCAM_VALIDATE_DATA_SIZE % FLAGS.val_batch_size > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can either be a string,
            # a list of strings, or a tf.Tensor of strings.
            train_record_filenames = os.path.join(FLAGS.dataset_dir,
                                                  'train.record')
            validate_record_filenames = os.path.join(FLAGS.dataset_dir,
                                                     'validate.record')
            ############################
            # Training loop.
            ############################
            for num_epoch in range(start_epoch,
                                   FLAGS.how_many_training_epochs):
                print("------------------------------------")
                print(" Epoch {} ".format(num_epoch))
                print("------------------------------------")

                sess.run(
                    iterator.initializer,
                    feed_dict={tfrecord_filenames: train_record_filenames})
                for step in range(tr_batches):
                    train_batch_xs, train_batch_ys = sess.run(next_batch)
                    # # Verify image
                    # # assert not np.any(np.isnan(train_batch_xs))
                    # n_batch = train_batch_xs.shape[0]
                    # # n_view = train_batch_xs.shape[1]
                    # for i in range(n_batch):
                    #     img = train_batch_xs[i]
                    #     # scipy.misc.toimage(img).show() Or
                    #     img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
                    #     cv2.imwrite('/home/ace19/Pictures/' + str(i) + '.png', img)
                    #     # cv2.imshow(str(train_batch_ys[idx]), img)
                    #     cv2.waitKey(100)
                    #     cv2.destroyAllWindows()

                    augmented_batch_xs = aug_utils.aug(train_batch_xs)
                    # # Verify image
                    # # assert not np.any(np.isnan(train_batch_xs))
                    # n_batch = augmented_batch_xs.shape[0]
                    # # n_view = train_batch_xs.shape[1]
                    # for i in range(n_batch):
                    #     img = augmented_batch_xs[i]
                    #     # scipy.misc.toimage(img).show() Or
                    #     img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
                    #     cv2.imwrite('/home/ace19/Pictures/' + str(i) + '.png', img)
                    #     # cv2.imshow(str(train_batch_ys[idx]), img)
                    #     cv2.waitKey(100)
                    #     cv2.destroyAllWindows()

                    # Run the graph with this batch of training data and learning rate policy.
                    lr, train_summary, train_accuracy, train_loss, grad_vals, _ = \
                        sess.run([learning_rate, summary_op, accuracy, total_loss, grad_summ_op, train_op],
                                 feed_dict={
                                     X: augmented_batch_xs,
                                     ground_truth: train_batch_ys,
                                     is_training: True,
                                     keep_prob: 0.8
                                 })
                    train_writer.add_summary(train_summary, num_epoch)
                    train_writer.add_summary(grad_vals, num_epoch)
                    tf.logging.info(
                        'Epoch #%d, Step #%d, rate %.10f, accuracy %.1f%%, loss %f'
                        % (num_epoch, step, lr, train_accuracy * 100,
                           train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                tf.logging.info('--------------------------')
                tf.logging.info(' Start validation ')
                tf.logging.info('--------------------------')

                total_val_accuracy = 0
                validation_count = 0
                total_conf_matrix = None
                # Reinitialize iterator with the validation dataset
                sess.run(
                    val_iterator.initializer,
                    feed_dict={tfrecord_filenames: validate_record_filenames})
                for step in range(val_batches):
                    validation_batch_xs, validation_batch_ys = sess.run(
                        val_next_batch)
                    # TTA
                    batch_size, n_crops, h, w, c = validation_batch_xs.shape
                    # fuse batch size and n_crops
                    tencrop_val_batch_xs = np.reshape(validation_batch_xs,
                                                      (-1, h, w, c))

                    val_summary, val_accuracy, conf_matrix = sess.run(
                        [summary_op, accuracy, confusion_matrix],
                        feed_dict={
                            X: tencrop_val_batch_xs,
                            ground_truth: validation_batch_ys,
                            is_training: False,
                            keep_prob: 1.0
                        })

                    validation_writer.add_summary(val_summary, num_epoch)

                    total_val_accuracy += val_accuracy
                    validation_count += 1
                    if total_conf_matrix is None:
                        total_conf_matrix = conf_matrix
                    else:
                        total_conf_matrix += conf_matrix

                total_val_accuracy /= validation_count
                tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
                tf.logging.info(
                    'Validation accuracy = %.1f%% (N=%d)' %
                    (total_val_accuracy * 100, PCAM_VALIDATE_DATA_SIZE))

                # Save the model checkpoint periodically.
                if (num_epoch <= FLAGS.how_many_training_epochs - 1):
                    checkpoint_path = os.path.join(FLAGS.train_logdir,
                                                   FLAGS.ckpt_name_to_save)
                    tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                                    num_epoch)
                    saver.save(sess, checkpoint_path, global_step=num_epoch)
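
Example #2 feeds ten-crop validation batches by flattening the crops into the batch dimension and, via the tf.cond branch defined near the top, averaging the per-crop logits back to one prediction per image. The NumPy round trip below sketches those two reshapes; all shapes are illustrative, not taken from the dataset.

import numpy as np

TEN_CROP = 10
val_batch_size, h, w, c, num_classes = 4, 96, 96, 3, 2

# (batch, crops, h, w, c) -> (batch * crops, h, w, c), as done before feeding X.
crops = np.random.rand(val_batch_size, TEN_CROP, h, w, c).astype(np.float32)
flat = crops.reshape(-1, h, w, c)                       # (40, 96, 96, 3)

# Stand-in for the network output: one logit vector per crop.
logits = np.random.rand(flat.shape[0], num_classes)

# Mirror of the eval branch of tf.cond: regroup crops per image and average.
per_image_logits = logits.reshape(val_batch_size, TEN_CROP, -1).mean(axis=1)
print(per_image_logits.shape)                           # (4, 2)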
Example #3
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Creating train logdir: %s', FLAGS.train_logdir)

    with tf.Graph().as_default() as graph:
        global_step = tf.train.get_or_create_global_step()

        # Define the model
        X = tf.placeholder(
            tf.float32, [None, FLAGS.num_views, FLAGS.height, FLAGS.width, 3],
            name='X')
        # for a 299 input size; otherwise modify the shape for your size.
        final_X = tf.placeholder(tf.float32,
                                 [FLAGS.num_views, None, 8, 8, 1536],
                                 name='final_X')
        ground_truth = tf.placeholder(tf.int64, [None], name='ground_truth')
        is_training = tf.placeholder(tf.bool)
        is_training2 = tf.placeholder(tf.bool)
        dropout_keep_prob = tf.placeholder(tf.float32)
        grouping_scheme = tf.placeholder(tf.bool, [NUM_GROUP, FLAGS.num_views])
        grouping_weight = tf.placeholder(tf.float32, [NUM_GROUP, 1])
        # learning_rate = tf.placeholder(tf.float32)

        # Grouping Module
        d_scores, _, final_desc = gvcnn.discrimination_score(
            X, num_classes, is_training)

        # GVCNN
        logits, _ = gvcnn.gvcnn(final_X, grouping_scheme, grouping_weight,
                                num_classes, is_training2, dropout_keep_prob)

        # Define loss
        tf.losses.sparse_softmax_cross_entropy(labels=ground_truth,
                                               logits=logits)

        # Gather update_ops. These contain, for example,
        # the updates for the batch_norm variables created by model.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        prediction = tf.argmax(logits, 1, name='prediction')
        correct_prediction = tf.equal(prediction, ground_truth)
        confusion_matrix = tf.confusion_matrix(ground_truth,
                                               prediction,
                                               num_classes=num_classes)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        summaries.add(tf.summary.scalar('accuracy', accuracy))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        # optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        # for variable in slim.get_model_variables():
        #     summaries.add(tf.summary.histogram(variable.op.name, variable))

        total_loss, grads_and_vars = train_utils.optimize(optimizer)
        total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # # Modify the gradients for biases and last layer variables.
        # last_layers = train_utils.get_extra_layer_scopes(
        #     FLAGS.last_layers_contain_logits_only)
        # grad_mult = train_utils.get_model_gradient_multipliers(
        #     last_layers, FLAGS.last_layer_gradient_multiplier)
        # if grad_mult:
        #     grads_and_vars = slim.learning.multiply_gradients(
        #         grads_and_vars, grad_mult)

        grad_summ_op = tf.summary.merge([
            tf.summary.histogram("%s-grad" % g[1].name, g[0])
            for g in grads_and_vars
        ])

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # Add the summaries. These contain the summaries
        # created by model and either optimize() or _gather_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))
        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir, graph)
        validation_writer = tf.summary.FileWriter(
            FLAGS.summaries_dir + '/validation', graph)

        ################
        # Prepare data
        ################
        filenames = tf.placeholder(tf.string, shape=[])
        tr_dataset = data.Dataset(filenames, FLAGS.num_views, FLAGS.height,
                                  FLAGS.width, FLAGS.batch_size)
        iterator = tr_dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True))
        with tf.Session(config=sess_config) as sess:
            sess.run(tf.global_variables_initializer())

            # TODO:
            # Create a saver object which will save all the variables
            saver = tf.train.Saver(keep_checkpoint_every_n_hours=1.0)
            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            start_epoch = 0
            # Get the number of training/validation steps per epoch
            tr_batches = int(MODELNET_TRAIN_DATA_SIZE / FLAGS.batch_size)
            if MODELNET_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
                tr_batches += 1
            val_batches = int(MODELNET_VALIDATE_DATA_SIZE / FLAGS.batch_size)
            if MODELNET_VALIDATE_DATA_SIZE % FLAGS.batch_size > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can either be a string,
            # a list of strings, or a tf.Tensor of strings.
            training_filenames = os.path.join(FLAGS.dataset_dir,
                                              'train.record')
            validate_filenames = os.path.join(FLAGS.dataset_dir,
                                              'validate.record')
            ##################
            # Training loop.
            ##################
            for training_epoch in range(start_epoch,
                                        FLAGS.how_many_training_epochs):
                print("-------------------------------------")
                print(" Epoch {} ".format(training_epoch))
                print("-------------------------------------")

                sess.run(iterator.initializer,
                         feed_dict={filenames: training_filenames})
                for step in range(tr_batches):
                    # Pull the image batch we'll use for training.
                    train_batch_xs, train_batch_ys = sess.run(next_batch)

                    # # Verify image
                    # assert not np.any(np.isnan(train_batch_xs))
                    # n_batch = train_batch_xs.shape[0]
                    # n_view = train_batch_xs.shape[1]
                    # for i in range(n_batch):
                    #     for j in range(n_view):
                    #         img = train_batch_xs[i][j]
                    #         # scipy.misc.toimage(img).show()
                    #         # Or
                    #         img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
                    #         cv2.imwrite('/home/ace19/Pictures/' + str(i) +
                    #                     '_' + str(j) + '.png', img)
                    #         # cv2.imshow(str(train_batch_ys[idx]), img)
                    #         cv2.waitKey(100)
                    #         cv2.destroyAllWindows()

                    # Sets up a graph with feeds and fetches for partial run.
                    handle = sess.partial_run_setup([
                        d_scores, final_desc, learning_rate, summary_op,
                        accuracy, total_loss, grad_summ_op, train_op
                    ], [
                        X, final_X, ground_truth, grouping_scheme,
                        grouping_weight, is_training, is_training2,
                        dropout_keep_prob
                    ])

                    scores, final = sess.partial_run(handle,
                                                     [d_scores, final_desc],
                                                     feed_dict={
                                                         X: train_batch_xs,
                                                         is_training: True
                                                     })
                    schemes = gvcnn.grouping_scheme(scores, NUM_GROUP,
                                                    FLAGS.num_views)
                    weights = gvcnn.grouping_weight(scores, schemes)

                    # Run the graph with this batch of training data.
                    lr, train_summary, train_accuracy, train_loss, grad_vals, _ = \
                        sess.partial_run(handle,
                                         [learning_rate, summary_op, accuracy, total_loss, grad_summ_op, train_op],
                                         feed_dict={
                                             final_X: final,
                                             ground_truth: train_batch_ys,
                                             grouping_scheme: schemes,
                                             grouping_weight: weights,
                                             is_training2: True,
                                             dropout_keep_prob: 0.8}
                                         )

                    train_writer.add_summary(train_summary, training_epoch)
                    train_writer.add_summary(grad_vals, training_epoch)
                    tf.logging.info(
                        'Epoch #%d, Step #%d, rate %.10f, accuracy %.1f%%, loss %f'
                        % (training_epoch, step, lr, train_accuracy * 100,
                           train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                # tf.logging.info('--------------------------')
                # tf.logging.info(' Start validation ')
                # tf.logging.info('--------------------------')
                #
                # # Reinitialize iterator with the validation dataset
                # sess.run(iterator.initializer, feed_dict={filenames: validate_filenames})
                # total_val_accuracy = 0
                # validation_count = 0
                # total_conf_matrix = None
                #
                # for step in range(val_batches):
                #     validation_batch_xs, validation_batch_ys = sess.run(next_batch)
                #
                #     # Sets up a graph with feeds and fetches for partial run.
                #     handle = sess.partial_run_setup([d_scores, final_desc,
                #                                      summary_op, accuracy, confusion_matrix],
                #                                     [X, final_X, ground_truth, learning_rate,
                #                                      grouping_scheme, grouping_weight, is_training,
                #                                      is_training2, dropout_keep_prob])
                #
                #     scores, final = sess.partial_run(handle,
                #                                      [d_scores, final_desc],
                #                                      feed_dict={
                #                                          X: validation_batch_xs,
                #                                          is_training: False}
                #                                      )
                #     schemes = gvcnn.grouping_scheme(scores, NUM_GROUP, FLAGS.num_views)
                #     weights = gvcnn.grouping_weight(scores, schemes)
                #
                #     # Run the graph with this batch of training data.
                #     val_summary, val_accuracy, conf_matrix = \
                #         sess.partial_run(handle,
                #                          [summary_op, accuracy, confusion_matrix],
                #                          feed_dict={
                #                              final_X: final,
                #                              ground_truth: validation_batch_ys,
                #                              grouping_scheme: schemes,
                #                              grouping_weight: weights,
                #                              is_training2: False,
                #                              dropout_keep_prob: 1.0}
                #                          )
                #
                #     validation_writer.add_summary(val_summary, training_epoch)
                #
                #     total_val_accuracy += val_accuracy
                #     validation_count += 1
                #     if total_conf_matrix is None:
                #         total_conf_matrix = conf_matrix
                #     else:
                #         total_conf_matrix += conf_matrix
                #
                #
                # total_val_accuracy /= validation_count
                # tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
                # tf.logging.info('Validation accuracy = %.1f%% (N=%d)' %
                #                 (total_val_accuracy * 100, MODELNET_VALIDATE_DATA_SIZE))

                # Save the model checkpoint periodically.
                if (training_epoch <= FLAGS.how_many_training_epochs - 1):
                    checkpoint_path = os.path.join(FLAGS.train_logdir,
                                                   FLAGS.ckpt_name_to_save)
                    tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                                    training_epoch)
                    saver.save(sess,
                               checkpoint_path,
                               global_step=training_epoch)
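
Examples #3 and #6 run each training step in two stages with tf.Session.partial_run: the first stage computes the view scores, Python code builds the grouping scheme and weights from them, and the second stage finishes the step without recomputing the first stage. A minimal self-contained sketch of that pattern (TF 1.x API assumed; the tensors only stand in for the real ones):

import tensorflow as tf

a = tf.placeholder(tf.float32, shape=[], name='a')
b = tf.placeholder(tf.float32, shape=[], name='b')
intermediate = a * 2.0      # stands in for d_scores / final_desc
result = intermediate + b   # stands in for the training fetches

with tf.Session() as sess:
    handle = sess.partial_run_setup(fetches=[intermediate, result],
                                    feeds=[a, b])
    mid = sess.partial_run(handle, intermediate, feed_dict={a: 3.0})
    # ... Python-side work on mid (grouping_scheme / grouping_weight above) ...
    out = sess.partial_run(handle, result, feed_dict={b: mid + 1.0})
    print(out)              # 13.0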
Example #4
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    # Get dataset-dependent information.
    dataset = segmentation_dataset.get_dataset(FLAGS.dataset,
                                               FLAGS.train_split,
                                               dataset_dir=FLAGS.dataset_dir)

    tf.gfile.MakeDirs(FLAGS.train_logdir)
    tf.logging.info('Training on %s set', FLAGS.train_split)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            samples = input_generator.get(
                dataset,
                FLAGS.train_crop_size,
                clone_batch_size,
                min_resize_value=FLAGS.min_resize_value,
                max_resize_value=FLAGS.max_resize_value,
                resize_factor=FLAGS.resize_factor,
                min_scale_factor=FLAGS.min_scale_factor,
                max_scale_factor=FLAGS.max_scale_factor,
                scale_factor_step_size=FLAGS.scale_factor_step_size,
                dataset_split=FLAGS.train_split,
                is_training=True,
                model_variant=FLAGS.model_variant)
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=128 *
                                                         config.num_clones)
            #samples, capacity=12 * config.num_clones)

        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()

            # Define the model and create clones.
            model_fn = _build_unet

            #model_args = (inputs_queue, {
            #    common.OUTPUT_TYPE: dataset.num_classes
            #}, dataset.ignore_label)
            model_args = (inputs_queue, dataset, dataset.ignore_label)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)
            #input('stop!')

        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for images, labels, semantic predictions
        if FLAGS.save_summaries_images:
            summary_image = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.IMAGE)).strip('/'))
            summaries.add(
                tf.summary.image('samples/%s' % common.IMAGE, summary_image))

            first_clone_label = graph.get_tensor_by_name(
                ('%s/%s:0' % (first_clone_scope, common.LABEL)).strip('/'))
            # Scale up summary image pixel values for better visualization.
            pixel_scaling = max(1, 255 // dataset.num_classes)
            summary_label = tf.cast(first_clone_label * pixel_scaling,
                                    tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.LABEL, summary_label))

            first_clone_output = graph.get_tensor_by_name(
                ('%s/%s:0' %
                 (first_clone_scope, common.OUTPUT_TYPE)).strip('/'))
            predictions = tf.expand_dims(tf.argmax(first_clone_output, 3), -1)

            summary_predictions = tf.cast(predictions * pixel_scaling,
                                          tf.uint8)
            summaries.add(
                tf.summary.image('samples/%s' % common.OUTPUT_TYPE,
                                 summary_predictions))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor,
                FLAGS.training_number_of_steps, FLAGS.learning_power,
                FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            last_layers = model.get_extra_layer_scopes(
                FLAGS.last_layers_contain_logits_only)
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        #input('no training')
        # Start the training.
        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_logdir,
                            log_every_n_steps=FLAGS.log_steps,
                            master=FLAGS.master,
                            number_of_steps=FLAGS.training_number_of_steps,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_logdir,
                                FLAGS.tf_initial_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs)
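
Every example above builds its schedule with train_utils.get_model_learning_rate. The sketch below is a rough, hedged approximation of what a 'poly' policy with a slow-start phase could look like; it does not claim to reproduce the DeepLab helper, and the argument names are only illustrative.

import tensorflow as tf  # TF 1.x API assumed

def poly_learning_rate(base_learning_rate, training_number_of_steps,
                       learning_power, slow_start_step,
                       slow_start_learning_rate):
    global_step = tf.train.get_or_create_global_step()
    poly_rate = tf.train.polynomial_decay(base_learning_rate,
                                          global_step,
                                          decay_steps=training_number_of_steps,
                                          end_learning_rate=0.0,
                                          power=learning_power)
    # Use a small constant rate for the first slow_start_step steps.
    return tf.where(global_step < slow_start_step,
                    slow_start_learning_rate, poly_rate)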
Example #5
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
  config = model_deploy.DeploymentConfig(
      num_clones=FLAGS.num_clones,
      clone_on_cpu=FLAGS.clone_on_cpu,
      replica_id=FLAGS.task,
      num_replicas=FLAGS.num_replicas,
      num_ps_tasks=FLAGS.num_ps_tasks)

  # Split the batch across GPUs.
  assert FLAGS.train_batch_size % config.num_clones == 0, (
      'Training batch size not divisible by number of clones (GPUs).')

  clone_batch_size = FLAGS.train_batch_size // config.num_clones

  # Get dataset-dependent information.
  dataset = segmentation_dataset.get_dataset(
      FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

  tf.gfile.MakeDirs(FLAGS.train_logdir)
  tf.logging.info('Training on %s set', FLAGS.train_split)

  with tf.Graph().as_default():
    with tf.device(config.inputs_device()):
      samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)
      inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

    # Create the global step on the device storing the variables.
    with tf.device(config.variables_device()):
      global_step = tf.train.get_or_create_global_step()

      # Define the model and create clones.
      model_fn = _build_deeplab
      model_args = (inputs_queue, {
          common.OUTPUT_TYPE: dataset.num_classes
      }, dataset.ignore_label)
      clones = model_deploy.create_clones(config, model_fn, args=model_args)

      # Gather update_ops from the first clone. These contain, for example,
      # the updates for the batch_norm variables created by model_fn.
      first_clone_scope = config.clone_scope(0)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Add summaries for model variables.
    for model_var in slim.get_model_variables():
      summaries.add(tf.summary.histogram(model_var.op.name, model_var))

    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Build the optimizer based on the device specification.
    with tf.device(config.optimizer_device()):
      learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.training_number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
      optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))

    with tf.device(config.variables_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, optimizer)
      total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
      summaries.add(tf.summary.scalar('total_loss', total_loss))

      # Modify the gradients for biases and last layer variables.
      last_layers = model.get_extra_layer_scopes(
          FLAGS.last_layers_contain_logits_only)
      grad_mult = train_utils.get_model_gradient_multipliers(
          last_layers, FLAGS.last_layer_gradient_multiplier)
      if grad_mult:
        grads_and_vars = slim.learning.multiply_gradients(
            grads_and_vars, grad_mult)

      # Create gradient update op.
      grad_updates = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops)
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries))

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    # Start the training.
    slim.learning.train(
        train_tensor,
        logdir=FLAGS.train_logdir,
        log_every_n_steps=FLAGS.log_steps,
        master=FLAGS.master,
        number_of_steps=FLAGS.training_number_of_steps,
        is_chief=(FLAGS.task == 0),
        session_config=session_config,
        startup_delay_steps=startup_delay_steps,
        init_fn=train_utils.get_model_init_fn(
            FLAGS.train_logdir,
            FLAGS.tf_initial_checkpoint,
            FLAGS.initialize_last_layer,
            last_layers,
            ignore_missing_vars=True),
        summary_op=summary_op,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
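
Examples #4 and #5 boost the gradients of the bias and last-layer variables through train_utils.get_model_gradient_multipliers and slim.learning.multiply_gradients. Below is a minimal sketch of how such per-variable multipliers can be applied to a (gradient, variable) list; the mapping from variable names to factors is assumed for illustration, not taken from the library.

def multiply_gradients_sketch(grads_and_vars, grad_mult):
    """Scales selected gradients, e.g. grad_mult = {'logits/weights': 10.0}."""
    scaled = []
    for grad, var in grads_and_vars:
        multiplier = grad_mult.get(var.op.name)
        if grad is not None and multiplier is not None:
            # Assumes dense gradients; IndexedSlices would need special handling.
            grad = grad * multiplier
        scaled.append((grad, var))
    return scaled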
Example #6
def main(unused_argv):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    with tf.Graph().as_default() as graph:
        global_step = tf.compat.v1.train.get_or_create_global_step()

        # Define the model
        X = tf.compat.v1.placeholder(
            tf.float32, [None, FLAGS.num_views, FLAGS.height, FLAGS.width, 3],
            name='X')
        ground_truth = tf.compat.v1.placeholder(tf.int64, [None],
                                                name='ground_truth')
        is_training = tf.compat.v1.placeholder(tf.bool, name='is_training')
        dropout_keep_prob = tf.compat.v1.placeholder(tf.float32,
                                                     name='dropout_keep_prob')
        g_scheme = tf.compat.v1.placeholder(tf.int32,
                                            [FLAGS.num_group, FLAGS.num_views])
        g_weight = tf.compat.v1.placeholder(tf.float32, [FLAGS.num_group])

        # GVCNN
        view_scores, _, logits = model.gvcnn(X, num_classes, g_scheme,
                                             g_weight, is_training,
                                             dropout_keep_prob)

        # # basic - for verification
        # _, logits = model.basic(X,
        #                         num_classes,
        #                         is_training,
        #                         dropout_keep_prob)

        # Define loss
        _loss = tf.losses.sparse_softmax_cross_entropy(labels=ground_truth,
                                                       logits=logits)

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        prediction = tf.argmax(logits, 1, name='prediction')
        correct_prediction = tf.equal(prediction, ground_truth)
        confusion_matrix = tf.math.confusion_matrix(ground_truth,
                                                    prediction,
                                                    num_classes=num_classes)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        summaries.add(tf.compat.v1.summary.scalar('accuracy', accuracy))

        # # Add summaries for model variables.
        # for model_var in slim.get_model_variables():
        #     summaries.add(tf.compat.v1.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES):
            summaries.add(
                tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        optimizer = tf.compat.v1.train.MomentumOptimizer(
            learning_rate, FLAGS.momentum)
        summaries.add(
            tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        total_loss, grads_and_vars = train_utils.optimize(optimizer)
        total_loss = tf.debugging.check_numerics(total_loss,
                                                 'Loss is inf or nan.')
        summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

        # Gather update_ops.
        # These contain, for example, the updates for the batch_norm variables created by model.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)

        # Create gradient update op.
        update_ops.append(
            optimizer.apply_gradients(grads_and_vars, global_step=global_step))
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(total_loss, name='train_op')

        ################
        # Prepare data
        ################
        filenames = tf.compat.v1.placeholder(tf.string, shape=[])
        tr_dataset = train_data.Dataset(filenames, FLAGS.num_views,
                                        FLAGS.height, FLAGS.width,
                                        FLAGS.batch_size)
        iterator = tr_dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        # validation dataset
        val_dataset = val_data.Dataset(filenames, FLAGS.num_views,
                                       FLAGS.height, FLAGS.width,
                                       FLAGS.val_batch_size)  # val_batch_size
        val_iterator = val_dataset.dataset.make_initializable_iterator()
        val_next_batch = val_iterator.get_next()

        sess_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
        with tf.compat.v1.Session(config=sess_config) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())

            # Add the summaries. These contain the summaries
            # created by model and either optimize() or _gather_loss().
            summaries |= set(
                tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

            # Merge all summaries together.
            summary_op = tf.compat.v1.summary.merge(list(summaries))
            train_writer = tf.compat.v1.summary.FileWriter(
                FLAGS.summaries_dir, graph)
            validation_writer = tf.compat.v1.summary.FileWriter(
                FLAGS.summaries_dir + '/validation', graph)

            # Create a saver object which will save all the variables
            saver = tf.compat.v1.train.Saver(keep_checkpoint_every_n_hours=1.0)
            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            if FLAGS.saved_checkpoint_dir:
                if tf.compat.v1.gfile.IsDirectory(FLAGS.saved_checkpoint_dir):
                    checkpoint_path = tf.train.latest_checkpoint(
                        FLAGS.saved_checkpoint_dir)
                else:
                    checkpoint_path = FLAGS.saved_checkpoint_dir
                saver.restore(sess, checkpoint_path)

            start_epoch = 0
            # Get the number of training/validation steps per epoch
            tr_batches = int(MODELNET_TRAIN_DATA_SIZE / FLAGS.batch_size)
            if MODELNET_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
                tr_batches += 1
            val_batches = int(MODELNET_VALIDATE_DATA_SIZE /
                              FLAGS.val_batch_size)
            if MODELNET_VALIDATE_DATA_SIZE % FLAGS.val_batch_size > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can either be a string,
            # a list of strings, or a tf.Tensor of strings.
            training_filenames = os.path.join(FLAGS.dataset_dir,
                                              'modelnet5_6view_train.record')
            validate_filenames = os.path.join(FLAGS.dataset_dir,
                                              'modelnet5_6view_test.record')

            ###################################
            # Training loop.
            ###################################
            for num_epoch in range(start_epoch,
                                   FLAGS.how_many_training_epochs):
                print("-------------------------------------")
                print(" Epoch {} ".format(num_epoch))
                print("-------------------------------------")

                sess.run(iterator.initializer,
                         feed_dict={filenames: training_filenames})
                for step in range(tr_batches):
                    # Pull the image batch we'll use for training.
                    train_batch_xs, train_batch_ys = sess.run(next_batch)

                    # Sets up a graph with feeds and fetches for partial run.
                    handle = sess.partial_run_setup(
                        [
                            view_scores,
                            learning_rate,
                            # summary_op, top1_acc, loss, optimize_op, dummy],
                            summary_op,
                            accuracy,
                            _loss,
                            train_op
                        ],
                        [
                            X, ground_truth, g_scheme, g_weight, is_training,
                            dropout_keep_prob
                        ])

                    _view_scores = sess.partial_run(handle, [view_scores],
                                                    feed_dict={
                                                        X: train_batch_xs,
                                                        is_training: True,
                                                        dropout_keep_prob: 0.8
                                                    })
                    _g_schemes = model.group_scheme(_view_scores,
                                                    FLAGS.num_group,
                                                    FLAGS.num_views)
                    _g_weights = model.group_weight(_g_schemes)

                    # Run the graph with this batch of training data.
                    lr, train_summary, train_accuracy, train_loss, _ = \
                        sess.partial_run(handle,
                                         [learning_rate, summary_op, accuracy, _loss, train_op],
                                         feed_dict={
                                             ground_truth: train_batch_ys,
                                             g_scheme: _g_schemes,
                                             g_weight: _g_weights}
                                         )

                    # for verification
                    # lr, train_summary, train_accuracy, train_loss, _ = \
                    #     sess.run([learning_rate, summary_op, accuracy, _loss, train_op],
                    #              feed_dict={
                    #                  X: train_batch_xs,
                    #                  ground_truth: train_batch_ys,
                    #                  is_training: True,
                    #                  dropout_keep_prob: 0.8}
                    #              )

                    train_writer.add_summary(train_summary, num_epoch)
                    tf.compat.v1.logging.info(
                        'Epoch #%d, Step #%d, rate %.6f, top1_acc %.3f%%, loss %.5f'
                        % (num_epoch, step, lr, train_accuracy, train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                tf.compat.v1.logging.info('--------------------------')
                tf.compat.v1.logging.info(' Start validation ')
                tf.compat.v1.logging.info('--------------------------')

                total_val_losses = 0.0
                total_val_top1_acc = 0.0
                val_count = 0
                total_conf_matrix = None

                # Reinitialize val_iterator with the validation dataset
                sess.run(val_iterator.initializer,
                         feed_dict={filenames: validate_filenames})
                for step in range(val_batches):
                    validation_batch_xs, validation_batch_ys = sess.run(
                        val_next_batch)

                    # Sets up a graph with feeds and fetches for partial run.
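                    # As in the training loop, view scores are fetched first so
                    # the grouping scheme/weights can be computed and fed back
                    # within the same partial run.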
                    handle = sess.partial_run_setup([
                        view_scores, summary_op, accuracy, _loss,
                        confusion_matrix
                    ], [
                        X, g_scheme, g_weight, ground_truth, is_training,
                        dropout_keep_prob
                    ])

                    _view_scores = sess.partial_run(handle, [view_scores],
                                                    feed_dict={
                                                        X: validation_batch_xs,
                                                        is_training: False,
                                                        dropout_keep_prob: 1.0
                                                    })
                    _g_schemes = model.group_scheme(_view_scores,
                                                    FLAGS.num_group,
                                                    FLAGS.num_views)
                    _g_weights = model.group_weight(_g_schemes)

                    # Run the graph with this batch of validation data.
                    val_summary, val_accuracy, val_loss, conf_matrix = \
                        sess.partial_run(handle,
                                         [summary_op, accuracy, _loss, confusion_matrix],
                                         feed_dict={
                                             ground_truth: validation_batch_ys,
                                             g_scheme: _g_schemes,
                                             g_weight: _g_weights}
                                         )

                    # for verification
                    # val_summary, val_accuracy, val_loss, conf_matrix = \
                    #     sess.run([summary_op, accuracy, _loss, confusion_matrix],
                    #              feed_dict={
                    #                  X: validation_batch_xs,
                    #                  ground_truth: validation_batch_ys,
                    #                  is_training: False,
                    #                  dropout_keep_prob: 1.0}
                    #              )

                    validation_writer.add_summary(val_summary, num_epoch)

                    total_val_losses += val_loss
                    total_val_top1_acc += val_accuracy
                    val_count += 1
                    if total_conf_matrix is None:
                        total_conf_matrix = conf_matrix
                    else:
                        total_conf_matrix += conf_matrix

                total_val_losses /= val_count
                total_val_top1_acc /= val_count

                tf.compat.v1.logging.info('Confusion Matrix:\n %s' %
                                          total_conf_matrix)
                tf.compat.v1.logging.info('Validation loss = %.5f' %
                                          total_val_losses)
                tf.compat.v1.logging.info(
                    'Validation accuracy = %.3f%% (N=%d)' %
                    (total_val_top1_acc, MODELNET_VALIDATE_DATA_SIZE))

                # Save the model checkpoint periodically.
                if (num_epoch <= FLAGS.how_many_training_epochs - 1):
                    checkpoint_path = os.path.join(FLAGS.train_logdir,
                                                   FLAGS.ckpt_name_to_save)
                    tf.compat.v1.logging.info('Saving to "%s-%d"',
                                              checkpoint_path, num_epoch)
                    saver.save(sess, checkpoint_path, global_step=num_epoch)
Example #7
0
def main(unused_arg):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Set up deployment (i.e., multi-GPUs and/or multi-replicas).
    config = model_deploy.DeploymentConfig(num_clones=FLAGS.num_clones,
                                           clone_on_cpu=FLAGS.clone_on_cpu,
                                           replica_id=FLAGS.task,
                                           num_replicas=FLAGS.num_replicas,
                                           num_ps_tasks=FLAGS.num_ps_tasks)

    # Split the batch across GPUs.
    assert FLAGS.train_batch_size % config.num_clones == 0, (
        'Training batch size not divisible by number of clones (GPUs).')

    clone_batch_size = FLAGS.train_batch_size // config.num_clones

    tf.gfile.MakeDirs(FLAGS.train_dir)

    with tf.Graph().as_default() as graph:
        with tf.device(config.inputs_device()):
            samples, num_samples = get_dataset.get_dataset(
                FLAGS.dataset,
                FLAGS.dataset_dir,
                split_name=FLAGS.train_split,
                is_training=True,
                image_size=[FLAGS.image_size, FLAGS.image_size],
                batch_size=clone_batch_size,
                channel=FLAGS.input_channel)
            tf.logging.info('Training on %s set: %d', FLAGS.train_split,
                            num_samples)
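            # Buffer batches in a prefetch queue (capacity scales with the
            # number of clones) so the input pipeline keeps the GPU towers fed.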
            inputs_queue = prefetch_queue.prefetch_queue(samples,
                                                         capacity=128 *
                                                         config.num_clones)
        # Create the global step on the device storing the variables.
        with tf.device(config.variables_device()):
            global_step = tf.train.get_or_create_global_step()
            # Define the model and create clones.
            model_fn = _build_model
            model_args = (inputs_queue, clone_batch_size)
            clones = model_deploy.create_clones(config,
                                                model_fn,
                                                args=model_args)

            # Gather update_ops from the first clone. These contain, for example,
            # the updates for the batch_norm variables created by model_fn.
            first_clone_scope = config.clone_scope(0)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                           first_clone_scope)
        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        # Add summaries for model variables.
        if FLAGS.save_summaries_variables:
            for model_var in slim.get_model_variables():
                summaries.add(
                    tf.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
        # Build the optimizer based on the device specification.
        with tf.device(config.optimizer_device()):
            learning_rate = train_utils.get_model_learning_rate(
                FLAGS.learning_policy, FLAGS.base_learning_rate,
                FLAGS.learning_rate_decay_step,
                FLAGS.learning_rate_decay_factor, FLAGS.number_of_steps,
                FLAGS.learning_power, FLAGS.slow_start_step,
                FLAGS.slow_start_learning_rate)
            optimizer = tf.train.AdamOptimizer(learning_rate)
            #optimizer = tf.train.RMSPropOptimizer(learning_rate, momentum=FLAGS.momentum)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
        with tf.device(config.variables_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, optimizer)
            total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
            summaries.add(tf.summary.scalar('losses/total_loss', total_loss))

            # Modify the gradients for biases and last layer variables.
            if (FLAGS.dataset == 'protein') and FLAGS.add_counts_logits:
                last_layers = ['Logits', 'Counts_logits']
            else:
                last_layers = ['Logits']
            grad_mult = train_utils.get_model_gradient_multipliers(
                last_layers, FLAGS.last_layer_gradient_multiplier)
            if grad_mult:
                grads_and_vars = slim.learning.multiply_gradients(
                    grads_and_vars, grad_mult)

            # Create gradient update op.
            grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
            update_ops.append(grad_updates)
            update_op = tf.group(*update_ops)
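            # Fetching train_tensor below first runs update_op (gradient
            # application plus batch-norm updates) through the control
            # dependency, then returns total_loss.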
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries))

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = True
        session_config.gpu_options.per_process_gpu_memory_fraction = 0.9

        # Start the training.
        slim.learning.train(train_tensor,
                            FLAGS.train_dir,
                            is_chief=(FLAGS.task == 0),
                            master=FLAGS.master,
                            graph=graph,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            session_config=session_config,
                            startup_delay_steps=startup_delay_steps,
                            number_of_steps=FLAGS.number_of_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            save_interval_secs=FLAGS.save_interval_secs,
                            init_fn=train_utils.get_model_init_fn(
                                FLAGS.train_dir,
                                FLAGS.fine_tune_checkpoint,
                                FLAGS.initialize_last_layer,
                                last_layers,
                                ignore_missing_vars=True),
                            summary_op=summary_op,
                            saver=tf.train.Saver(max_to_keep=50))
Example #8
0
def build_model():
  """Builds graph for model to train with rewrites for quantization.
  Returns:
    g: Graph with fake quantization ops and batch norm folding suitable for
    training quantized weights.
    train_tensor: Train op for execution during training.
  """
  g = tf.Graph()
  with g.as_default(), tf.device(
      tf.train.replica_device_setter(FLAGS.ps_tasks)):
    samples, _ = get_dataset.get_dataset(FLAGS.dataset, FLAGS.dataset_dir,
                                         split_name=FLAGS.train_split,
                                         is_training=True,
                                         image_size=[FLAGS.image_size, FLAGS.image_size],
                                         batch_size=FLAGS.batch_size,
                                         channel=FLAGS.input_channel)

    inputs = tf.identity(samples['image'], name='image')
    labels = tf.identity(samples['label'], name='label')
    model_options = common.ModelOptions(output_stride=FLAGS.output_stride)
    net, end_points = model.get_features(
        inputs,
        model_options=model_options,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm)
    logits, _ = model.classification(net, end_points, 
                                     num_classes=FLAGS.num_classes,
                                     is_training=True)
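    # Convert logits to per-class probabilities before the focal loss; this
    # assumes train_utils.focal_loss expects probabilities rather than raw logits.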
    logits = slim.softmax(logits)
    focal_loss_tensor = train_utils.focal_loss(labels, logits, weights=1.0)
    # f1_loss_tensor = train_utils.f1_loss(labels, logits, weights=1.0)
    # cls_loss = f1_loss_tensor
    cls_loss = focal_loss_tensor

    # Gather update_ops
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    global_step = tf.train.get_or_create_global_step()
    learning_rate = train_utils.get_model_learning_rate(
          FLAGS.learning_policy, FLAGS.base_learning_rate,
          FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
          FLAGS.number_of_steps, FLAGS.learning_power,
          FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)
    # opt = tf.train.RMSPropOptimizer(learning_rate, momentum=FLAGS.momentum)
    summaries.add(tf.summary.scalar('learning_rate', learning_rate))

    for loss in tf.get_collection(tf.GraphKeys.LOSSES):
      summaries.add(tf.summary.scalar('sub_losses/%s'%(loss.op.name), loss))
    classification_loss = tf.identity(cls_loss, name='classification_loss')
    summaries.add(tf.summary.scalar('losses/classification_loss', classification_loss))
    regularization_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    regularization_loss = tf.add_n(regularization_loss, name='regularization_loss')
    summaries.add(tf.summary.scalar('losses/regularization_loss', regularization_loss))

    total_loss = tf.add(cls_loss, regularization_loss, name='total_loss')
    grads_and_vars = opt.compute_gradients(total_loss)

    total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')
    summaries.add(tf.summary.scalar('losses/total_loss', total_loss))

    grad_updates = opt.apply_gradients(grads_and_vars, global_step=global_step)
    update_ops.append(grad_updates)
    update_op = tf.group(*update_ops, name='update_barrier')
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

  # Merge all summaries together.
  summary_op = tf.summary.merge(list(summaries))
  return g, train_tensor, summary_op
Example #9
0
def _train_pgn_model(iterator,
                     num_of_classes,
                     model_options,
                     ignore_label,
                     reuse=None):
    """Trains the pgn model.
  Args:
    iterator: An iterator of type tf.data.Iterator for images and labels.
    num_of_classes: Number of classes for the dataset.
    ignore_label: Ignore label for the dataset.
  Returns:
    train_tensor: A tensor to update the model variables.
    summary_op: An operation to log the summaries.
  """
    global_step = tf.train.get_or_create_global_step()
    summaries = []

    learning_rate = train_utils.get_model_learning_rate(
        FLAGS.learning_policy, FLAGS.base_learning_rate,
        FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
        FLAGS.training_number_of_steps, FLAGS.learning_power,
        FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)

    optimizer = tf.train.AdamOptimizer(learning_rate)
    # optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)

    tower_grads = []
    total_loss, total_seg_loss = 0, 0
    tower_summaries = None
    for i in range(FLAGS.num_clones):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('clone_%d' % i) as scope:
                loss, seg_loss = _tower_loss(iterator=iterator,
                                             num_of_classes=num_of_classes,
                                             model_options=model_options,
                                             ignore_label=ignore_label,
                                             scope=scope,
                                             reuse_variable=(i != 0)
                                             # reuse_variable=reuse
                                             )
                total_loss += loss
                total_seg_loss += seg_loss

                grads = optimizer.compute_gradients(loss)
                tower_grads.append(grads)

    tower_summaries = tf.summary.merge_all()
    summaries.append(tf.summary.scalar('learning_rate', learning_rate))

    with tf.device('/cpu:0'):
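        # Average the per-tower gradients on the CPU so a single update is
        # applied to the shared variables.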
        grads_and_vars = _average_gradients(tower_grads)
        if tower_summaries is not None:
            summaries.append(tower_summaries)

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)

        # Gather update_ops. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)

        should_log = tf.equal(math_ops.mod(global_step, FLAGS.log_steps), 0)
        total_loss = tf.cond(
            should_log, lambda: tf.Print(
                total_loss, [total_loss, total_seg_loss, global_step],
                'Total loss, Segmentation loss and Global step:'),
            lambda: total_loss)

        summaries.append(tf.summary.scalar('total_loss', total_loss))

        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')
        summary_op = tf.summary.merge(summaries)

    return train_tensor, summary_op
Example #10
0
def main(unused_argv):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    # tf.compat.v1.logging.info('Creating train logdir: %s', FLAGS.train_logdir)

    with tf.Graph().as_default() as graph:
        global_step = tf.compat.v1.train.get_or_create_global_step()

        X = tf.compat.v1.placeholder(
            tf.float32, [None, FLAGS.num_views, FLAGS.height, FLAGS.width, 3],
            name='X')
        ground_truth = tf.compat.v1.placeholder(tf.int64, [None],
                                                name='ground_truth')
        is_training = tf.compat.v1.placeholder(tf.bool, name='is_training')
        dropout_keep_prob = tf.compat.v1.placeholder(tf.float32,
                                                     name='dropout_keep_prob')
        # learning_rate = tf.placeholder(tf.float32, name='lr')

        # metric learning
        logits, features = \
            model.mvcnn_with_deep_cosine_metric_learning(X,
                                                         num_classes,
                                                         is_training=is_training,
                                                         keep_prob=dropout_keep_prob,
                                                         attention_module='se_block')
        # logits, features = mvcnn.mvcnn(X, num_classes)
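        # Note: 'features' is presumably the embedding produced by the metric
        # learning head (inferred from the function name); only 'logits' feeds
        # the classification loss below.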

        cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
            labels=ground_truth, logits=logits)
        tf.compat.v1.summary.scalar("cross_entropy_loss", cross_entropy)

        # Gather update ops. These contain, for example, the updates for the
        # batch_norm variables created by model.
        update_ops = tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.UPDATE_OPS)

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        prediction = tf.argmax(logits, 1, name='prediction')
        correct_prediction = tf.equal(prediction, ground_truth)
        confusion_matrix = tf.math.confusion_matrix(ground_truth,
                                                    prediction,
                                                    num_classes=num_classes)
        # accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        # summaries.add(tf.summary.scalar('accuracy', accuracy))
        accuracy = slim.metrics.accuracy(tf.cast(prediction, tf.int64),
                                         ground_truth)
        tf.compat.v1.summary.scalar("accuracy", accuracy)

        # Add summaries for model variables.
        for model_var in slim.get_model_variables():
            summaries.add(
                tf.compat.v1.summary.histogram(model_var.op.name, model_var))

        # Add summaries for losses.
        for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES):
            summaries.add(
                tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        # optimizer = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate)
        summaries.add(
            tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        total_loss, grads_and_vars = train_utils.optimize(optimizer)
        total_loss = tf.compat.v1.check_numerics(total_loss,
                                                 'Loss is inf or nan')
        summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss))

        # TensorBoard: How to plot histogram for gradients
        # grad_summ_op = tf.compat.v1.summary.merge([tf.compat.v1.summary.histogram("%s-grad" % g[1].name, g[0]) for g in grads_and_vars])

        # Create gradient update op.
        grad_updates = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=global_step)
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # Add the summaries. These contain the summaries created by model
        # and either optimize() or _gather_loss()
        summaries |= set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

        # Merge all summaries together.
        summary_op = tf.compat.v1.summary.merge(list(summaries))
        train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir,
                                                       graph)
        validation_writer = tf.compat.v1.summary.FileWriter(
            FLAGS.summaries_dir + '/validation', graph)

        #####################
        # prepare data
        #####################
        tfrecord_names = tf.compat.v1.placeholder(tf.string, shape=[])
        _dataset = data.Dataset(tfrecord_names, FLAGS.num_views, FLAGS.height,
                                FLAGS.width, FLAGS.batch_size)
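        # An initializable iterator lets the same pipeline be re-initialized
        # each epoch (and again for validation) with a different record file
        # fed through the tfrecord_names placeholder.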
        iterator = _dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        sess_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
        with tf.compat.v1.Session(config=sess_config) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())

            saver = tf.compat.v1.train.Saver(keep_checkpoint_every_n_hours=1.0)
            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            start_epoch = 0
            training_batches = int(MODELNET10_TRAIN_DATA_SIZE /
                                   FLAGS.batch_size)
            if MODELNET10_TRAIN_DATA_SIZE % FLAGS.batch_size > 0:
                training_batches += 1
            val_batches = int(MODELNET10_VALIDATE_DATA_SIZE / FLAGS.batch_size)
            if MODELNET10_VALIDATE_DATA_SIZE % FLAGS.batch_size > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can either
            # be a string, a list of strings, or a tf.Tensor of strings.
            training_tf_filenames = os.path.join(FLAGS.dataset_dir,
                                                 'train.record')
            val_tf_filenames = os.path.join(FLAGS.dataset_dir,
                                            'validate.record')
            ##################
            # Training loop.
            ##################
            for n_epoch in range(start_epoch, FLAGS.how_many_training_epochs):
                tf.compat.v1.logging.info('--------------------------')
                tf.compat.v1.logging.info(' Epoch %d' % n_epoch)
                tf.compat.v1.logging.info('--------------------------')

                sess.run(iterator.initializer,
                         feed_dict={tfrecord_names: training_tf_filenames})
                for step in range(training_batches):
                    train_batch_xs, train_batch_ys = sess.run(next_batch)
                    # # Verify image
                    # assert not np.any(np.isnan(train_batch_xs))
                    # n_batch = train_batch_xs.shape[0]
                    # n_view = train_batch_xs.shape[1]
                    # for i in range(n_batch):
                    #     for j in range(n_view):
                    #         img = train_batch_xs[i][j]
                    #         # scipy.misc.toimage(img).show()
                    #         # Or
                    #         img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
                    #         cv2.imwrite('/home/ace19/Pictures/' + str(i) +
                    #                     '_' + str(j) + '.png', img)
                    #         # cv2.imshow(str(train_batch_ys[idx]), img)
                    #         cv2.waitKey(100)
                    #         cv2.destroyAllWindows()

                    lr, train_summary, train_accuracy, train_loss, _ = \
                        sess.run([learning_rate, summary_op, accuracy, total_loss, train_op],
                                 feed_dict={X: train_batch_xs,
                                            ground_truth: train_batch_ys,
                                            is_training: True,
                                            dropout_keep_prob: 0.8})

                    # lr, train_summary, train_accuracy, train_loss, grad_vals, _ = \
                    #     sess.run([learning_rate, summary_op, accuracy, total_loss, grad_summ_op, train_op],
                    #     feed_dict={X: train_batch_xs,
                    #                ground_truth: train_batch_ys,
                    #                is_training: True,
                    #                dropout_keep_prob: 0.8})

                    train_writer.add_summary(train_summary, n_epoch)
                    # train_writer.add_summary(grad_vals, n_epoch)
                    tf.compat.v1.logging.info(
                        'Epoch #%d, Step #%d, rate %.10f, accuracy %.1f%%, loss %f'
                        %
                        (n_epoch, step, lr, train_accuracy * 100, train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                tf.compat.v1.logging.info('--------------------------')
                tf.compat.v1.logging.info(' Start validation ')
                tf.compat.v1.logging.info('--------------------------')

                # Reinitialize iterator with the validation dataset
                sess.run(iterator.initializer,
                         feed_dict={tfrecord_names: val_tf_filenames})

                total_val_accuracy = 0
                validation_count = 0
                total_conf_matrix = None
                for step in range(val_batches):
                    validation_batch_xs, validation_batch_ys = sess.run(
                        next_batch)

                    val_summary, val_accuracy, conf_matrix = \
                        sess.run([summary_op, accuracy, confusion_matrix],
                                 feed_dict={X: validation_batch_xs,
                                            ground_truth: validation_batch_ys,
                                            is_training: False,
                                            dropout_keep_prob: 1.0})

                    validation_writer.add_summary(val_summary, n_epoch)

                    total_val_accuracy += val_accuracy
                    validation_count += 1
                    if total_conf_matrix is None:
                        total_conf_matrix = conf_matrix
                    else:
                        total_conf_matrix += conf_matrix

                total_val_accuracy /= validation_count
                tf.compat.v1.logging.info('Confusion Matrix:\n %s' %
                                          (total_conf_matrix))
                tf.compat.v1.logging.info(
                    'Validation accuracy = %.1f%% (N=%d)' %
                    (total_val_accuracy * 100, MODELNET10_VALIDATE_DATA_SIZE))

                # Save the model checkpoint periodically.
                if (n_epoch <= FLAGS.how_many_training_epochs - 1):
                    checkpoint_path = os.path.join(FLAGS.train_logdir,
                                                   FLAGS.ckpt_name_to_save)
                    tf.compat.v1.logging.info('Saving to "%s-%d"',
                                              checkpoint_path, n_epoch)
                    saver.save(sess, checkpoint_path, global_step=n_epoch)
Example #11
0
def main(unused_argv):
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    labels = FLAGS.labels.split(',')
    num_classes = len(labels)

    with tf.Graph().as_default() as graph:
        global_step = tf.compat.v1.train.get_or_create_global_step()

        X = tf.compat.v1.placeholder(tf.float32,
                                     [None, FLAGS.height, FLAGS.width, 3],
                                     name='X')
        ground_truth = tf.compat.v1.placeholder(tf.int64, [None],
                                                name='ground_truth')
        is_training = tf.compat.v1.placeholder(tf.bool, name='is_training')
        keep_prob = tf.compat.v1.placeholder(tf.float32, [], name='keep_prob')
        tfrecord_filenames = tf.compat.v1.placeholder(tf.string, shape=[])

        # # Print name and shape of each tensor.
        # tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
        # tf.compat.v1.logging.info("Layers")
        # tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
        # for k, v in end_points.items():
        #     tf.compat.v1.logging.info('name = %s, shape = %s' % (v.name, v.get_shape()))
        #

        # Gather initial summaries.
        summaries = set(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))
        # # Add summaries for model variables.
        # for variable in slim.get_model_variables():
        #     summaries.add(tf.compat.v1.summary.histogram(variable.op.name, variable))
        #
        # # Add summaries for losses.
        # for loss in tf.compat.v1.get_collection(tf.GraphKeys.LOSSES):
        #     summaries.add(tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss))

        learning_rate = train_utils.get_model_learning_rate(
            FLAGS.learning_policy, FLAGS.base_learning_rate,
            FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor,
            FLAGS.training_number_of_steps, FLAGS.learning_power,
            FLAGS.slow_start_step, FLAGS.slow_start_learning_rate)
        summaries.add(
            tf.compat.v1.summary.scalar('learning_rate', learning_rate))

        # optimizers = \
        #     [tf.train.RMSPropOptimizer(learning_rate, decay=0.9, momentum=0.9) for _ in range(FLAGS.num_gpu)]
        # optimizers = \
        #     [tf.compat.v1.train.MomentumOptimizer(learning_rate, FLAGS.momentum) for _ in range(FLAGS.num_gpu)]
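        # One optimizer instance per GPU tower; after the NCCL all-reduce below,
        # each tower applies its own copy of the averaged gradients.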
        optimizers = \
            [tf.compat.v1.train.GradientDescentOptimizer(learning_rate) for _ in range(FLAGS.num_gpu)]

        logits = []
        losses = []
        grad_list = []
        filename_batch = []
        image_batch = []
        gt_batch = []
        for gpu_idx in range(FLAGS.num_gpu):
            tf.compat.v1.logging.info('creating gpu tower @ %d' %
                                      (gpu_idx + 1))
            image_batch.append(X)
            gt_batch.append(ground_truth)

            scope_name = 'tower%d' % gpu_idx
            with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_idx)), \
                 tf.compat.v1.variable_scope(scope_name):
                # apply SENet
                _, logit = model.deep_cosine_softmax(
                    X,
                    num_classes=num_classes,
                    is_training=is_training,
                    is_reuse=False,
                    keep_prob=keep_prob,
                    attention_module='se_block')

                # # Print name and shape of parameter nodes  (values not yet initialized)
                tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
                tf.compat.v1.logging.info("Parameters")
                tf.compat.v1.logging.info("++++++++++++++++++++++++++++++++++")
                for v in slim.get_model_variables():
                    tf.compat.v1.logging.info('name = %s, shape = %s' %
                                              (v.name, v.get_shape()))

                # # TTA
                # logit = tf.cond(is_training,
                #                 lambda: tf.identity(logit),
                #                 lambda: tf.reduce_mean(tf.reshape(logit, [FLAGS.val_batch_size // FLAGS.num_gpu, TEN_CROP, -1]), axis=1))
                logits.append(logit)

                l = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=ground_truth, logits=logit)
                losses.append(l)
                loss_w_reg = tf.reduce_sum(l) + tf.add_n(
                    slim.losses.get_regularization_losses(scope=scope_name))

                grad_list.append([
                    x
                    for x in optimizers[gpu_idx].compute_gradients(loss_w_reg)
                    if x[0] is not None
                ])

        y_hat = tf.concat(logits, axis=0)
        image_batch = tf.concat(image_batch, axis=0)
        gt_batch = tf.concat(gt_batch, axis=0)

        # Acc
        top1_acc = tf.reduce_mean(
            tf.cast(tf.nn.in_top_k(y_hat, gt_batch, k=1), dtype=tf.float32))
        summaries.add(tf.compat.v1.summary.scalar('top1_acc', top1_acc))
        # top5_acc = tf.reduce_mean(
        #     tf.cast(tf.nn.in_top_k(y_hat, gt_batch, k=5), dtype=tf.float32)
        # )
        # summaries.add(tf.compat.v1.summary.scalar('top5_acc', top5_acc))
        prediction = tf.argmax(y_hat, axis=1, name='prediction')
        confusion_matrix = tf.math.confusion_matrix(gt_batch,
                                                    prediction,
                                                    num_classes=num_classes)
        confusion_matrix = tf.div(confusion_matrix, FLAGS.num_gpu)

        loss = tf.reduce_mean(losses)
        loss = tf.compat.v1.check_numerics(loss, 'Loss is inf or nan.')
        summaries.add(tf.compat.v1.summary.scalar('loss', loss))

        # use NCCL
        grads, all_vars = train_helper.split_grad_list(grad_list)
        reduced_grad = train_helper.allreduce_grads(grads, average=True)
        grads = train_helper.merge_grad_list(reduced_grad, all_vars)
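        # After the all-reduce, every tower holds the same averaged gradients,
        # so the per-tower apply_gradients calls below perform identical updates.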

        # optimizer using NCCL
        train_ops = []
        for idx, grad_and_vars in enumerate(grads):
            # apply_gradients may create variables. Make them LOCAL_VARIABLES.
            with tf.name_scope('apply_gradients'), tf.device(
                    tf.DeviceSpec(device_type="GPU", device_index=idx)):
                update_ops = tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.UPDATE_OPS, scope='tower%d' % idx)
                with tf.control_dependencies(update_ops):
                    train_ops.append(optimizers[idx].apply_gradients(
                        grad_and_vars,
                        name='apply_grad_{}'.format(idx),
                        global_step=global_step))
                # TODO:
                # TensorBoard: How to plot histogram for gradients
                # grad_summ_op = tf.summary.merge([tf.summary.histogram("%s-grad" % g[1].name, g[0]) for g in grads_and_vars])

        optimize_op = tf.group(*train_ops, name='train_op')

        sync_op = train_helper.get_post_init_ops()
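        # get_post_init_ops presumably returns ops that copy tower 0's variables
        # to the other towers (inferred from the helper's name); sync_op is run
        # after initialization and again after each epoch.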

        # Create a saver object which will save all the variables
        saver = tf.compat.v1.train.Saver()
        best_ckpt_saver = BestCheckpointSaver(save_dir=FLAGS.train_logdir,
                                              num_to_keep=100,
                                              maximize=False,
                                              saver=saver)
        best_val_loss = 99999
        best_val_acc = 0

        start_epoch = 0
        epoch_count = tf.Variable(start_epoch, trainable=False)
        epoch_count_add = tf.assign(epoch_count, epoch_count + 1)
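        # epoch_count mirrors the epoch loop; it is the step value handed to
        # BestCheckpointSaver when a checkpoint is kept (see best_ckpt_saver.handle below).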

        ###############
        # Prepare data
        ###############
        # training dataset
        tr_dataset = train_data.Dataset(tfrecord_filenames,
                                        FLAGS.batch_size // FLAGS.num_gpu,
                                        num_classes,
                                        FLAGS.how_many_training_epochs,
                                        TRAIN_DATA_SIZE, FLAGS.height,
                                        FLAGS.width)
        iterator = tr_dataset.dataset.make_initializable_iterator()
        next_batch = iterator.get_next()

        # validation dataset
        val_dataset = val_data.Dataset(tfrecord_filenames,
                                       FLAGS.val_batch_size // FLAGS.num_gpu,
                                       num_classes,
                                       FLAGS.how_many_training_epochs,
                                       VALIDATE_DATA_SIZE, FLAGS.height,
                                       FLAGS.width)
        # 256,  # 256 ~ 480
        # 256)
        val_iterator = val_dataset.dataset.make_initializable_iterator()
        val_next_batch = val_iterator.get_next()

        sess_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
        with tf.compat.v1.Session(config=sess_config) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())

            # Add the summaries. These contain the summaries
            # created by model and either optimize() or _gather_loss().
            summaries |= set(
                tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES))

            # Merge all summaries together.
            summary_op = tf.compat.v1.summary.merge(list(summaries))
            train_writer = tf.compat.v1.summary.FileWriter(
                FLAGS.summaries_dir, graph)
            validation_writer = tf.compat.v1.summary.FileWriter(
                FLAGS.summaries_dir + '/validation', graph)

            # TODO: supports multi gpu -> add scope ('tower%d' % gpu_idx)
            if FLAGS.pre_trained_checkpoint:
                train_utils.restore_fn(FLAGS)

            if FLAGS.saved_checkpoint_dir:
                if tf.gfile.IsDirectory(FLAGS.saved_checkpoint_dir):
                    checkpoint_path = tf.train.latest_checkpoint(
                        FLAGS.saved_checkpoint_dir)
                else:
                    checkpoint_path = FLAGS.saved_checkpoint_dir
                saver.restore(sess, checkpoint_path)

            # global_step = checkpoint_path.split('/')[-1].split('-')[-1]

            sess.run(sync_op)

            # Get the number of training/validation steps per epoch
            tr_batches = int(TRAIN_DATA_SIZE /
                             (FLAGS.batch_size // FLAGS.num_gpu))
            if TRAIN_DATA_SIZE % (FLAGS.batch_size // FLAGS.num_gpu) > 0:
                tr_batches += 1
            val_batches = int(VALIDATE_DATA_SIZE /
                              (FLAGS.val_batch_size // FLAGS.num_gpu))
            if VALIDATE_DATA_SIZE % (FLAGS.val_batch_size //
                                     FLAGS.num_gpu) > 0:
                val_batches += 1

            # The filenames argument to the TFRecordDataset initializer can either be a string,
            # a list of strings, or a tf.Tensor of strings.
            train_record_filenames = os.path.join(FLAGS.dataset_dir,
                                                  'train.record')
            validate_record_filenames = os.path.join(FLAGS.dataset_dir,
                                                     'validate.record')

            ############################
            # Training loop.
            ############################
            for num_epoch in range(start_epoch,
                                   FLAGS.how_many_training_epochs):
                print("------------------------------------")
                print(" Epoch {} ".format(num_epoch))
                print("------------------------------------")

                sess.run(epoch_count_add)
                sess.run(
                    iterator.initializer,
                    feed_dict={tfrecord_filenames: train_record_filenames})
                for step in range(tr_batches):
                    filenames, train_batch_xs, train_batch_ys = sess.run(
                        next_batch)
                    # show_batch_data(filenames, train_batch_xs, train_batch_ys)
                    #
                    # augmented_batch_xs = aug_utils.aug(train_batch_xs)
                    # show_batch_data(filenames, augmented_batch_xs,
                    #                 train_batch_ys, 'aug')

                    # Run the graph with this batch of training data and learning rate policy.
                    lr, train_summary, train_top1_acc, train_loss, _ = \
                        sess.run([learning_rate, summary_op, top1_acc, loss, optimize_op],
                                 feed_dict={
                                     X: train_batch_xs,
                                     ground_truth: train_batch_ys,
                                     is_training: True,
                                     keep_prob: 0.8
                                 })
                    train_writer.add_summary(train_summary, num_epoch)
                    # train_writer.add_summary(grad_vals, num_epoch)
                    tf.compat.v1.logging.info(
                        'Epoch #%d, Step #%d, rate %.6f, top1_acc %.3f%%, loss %.5f'
                        % (num_epoch, step, lr, train_top1_acc, train_loss))

                ###################################################
                # Validate the model on the validation set
                ###################################################
                tf.compat.v1.logging.info('--------------------------')
                tf.compat.v1.logging.info(' Start validation ')
                tf.compat.v1.logging.info('--------------------------')

                total_val_losses = 0.0
                total_val_top1_acc = 0.0
                val_count = 0
                total_conf_matrix = None

                sess.run(
                    val_iterator.initializer,
                    feed_dict={tfrecord_filenames: validate_record_filenames})
                for step in range(val_batches):
                    filenames, validation_batch_xs, validation_batch_ys = sess.run(
                        val_next_batch)
                    # # TTA
                    # batch_size, n_crops, c, h, w = validation_batch_xs.shape
                    # # fuse batch size and ncrops
                    # tencrop_val_batch_xs = np.reshape(validation_batch_xs, (-1, c, h, w))
                    # show_batch_data(filenames, tencrop_val_batch_xs, validation_batch_ys)

                    # augmented_val_batch_xs = aug_utils.aug(tencrop_val_batch_xs)
                    # show_batch_data(filenames, augmented_val_batch_xs,
                    #                 validation_batch_ys, 'aug')

                    val_summary, val_loss, val_top1_acc, _confusion_matrix = sess.run(
                        [summary_op, loss, top1_acc, confusion_matrix],
                        feed_dict={
                            X: validation_batch_xs,
                            ground_truth: validation_batch_ys,
                            is_training: False,
                            keep_prob: 1.0
                        })
                    validation_writer.add_summary(val_summary, num_epoch)

                    total_val_losses += val_loss
                    total_val_top1_acc += val_top1_acc

                    # total_val_accuracy += val_top1_acc
                    val_count += 1
                    if total_conf_matrix is None:
                        total_conf_matrix = _confusion_matrix
                    else:
                        total_conf_matrix += _confusion_matrix

                total_val_losses /= val_count
                total_val_top1_acc /= val_count

                # total_val_accuracy /= val_count
                tf.compat.v1.logging.info('Confusion Matrix:\n %s' %
                                          total_conf_matrix)
                tf.compat.v1.logging.info('Validation loss = %.5f' %
                                          total_val_losses)
                tf.compat.v1.logging.info(
                    'Validation top1 accuracy = %.3f%% (N=%d)' %
                    (total_val_top1_acc, VALIDATE_DATA_SIZE))

                # periodic synchronization
                sess.run(sync_op)

                # Save the model checkpoint periodically.
                if (num_epoch <= FLAGS.how_many_training_epochs - 1):
                    # best_checkpoint_path = os.path.join(FLAGS.train_logdir, 'best_' + FLAGS.ckpt_name_to_save)
                    # tf.compat.v1.logging.info('Saving to "%s"', best_checkpoint_path)
                    # saver.save(sess, best_checkpoint_path, global_step=num_epoch)

                    # save & keep best model wrt. validation loss
                    best_ckpt_saver.handle(total_val_losses, sess, epoch_count)

                    if best_val_loss > total_val_losses:
                        best_val_loss = total_val_losses
                        best_val_acc = total_val_top1_acc

            chk_path = get_best_checkpoint(FLAGS.train_logdir,
                                           select_maximum_value=False)
            tf.compat.v1.logging.info(
                'training done. best_model val_loss=%.5f, top1_acc=%.3f%%, ckpt=%s'
                % (best_val_loss, best_val_acc, chk_path))