# Example 1
0
def train(tfrecords, logdir, cfg, pretrained_model_path=None, trainable_scopes=None, checkpoint_exclude_scopes=None, restore_variables_with_moving_averages=False, restore_moving_averages=False, read_images=False):
    """Build the training graph for ``cfg.MODEL_NAME`` and run ``slim.learning.train``.

    Args:
        tfrecords (list): tfrecord files read by ``input_nodes``.
        logdir (str): directory where slim writes checkpoints and summaries;
            also passed to ``get_init_function``.
        cfg (EasyDict): training configuration (model, input queue, optimizer,
            moving-average, session and checkpointing settings).
        pretrained_model_path (str): path to a pretrained network (originally
            documented as an Inception Network); forwarded to
            ``get_init_function``.
        trainable_scopes: forwarded to ``get_trainable_variables`` to select the
            variables that are optimized (format is defined there).
        checkpoint_exclude_scopes: forwarded to ``get_init_function``.
        restore_variables_with_moving_averages (bool): forwarded to
            ``get_init_function``. When moving averages are disabled in cfg, it
            still forces creation of an ``ExponentialMovingAverage`` object so
            averages can be loaded.
        restore_moving_averages (bool): same handling as the previous flag.
        read_images (bool): forwarded to ``input_nodes`` as ``read_filenames``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    graph = tf.Graph()

    with graph.as_default():

        # Global step counter; passed to create_train_op, which increments it
        # once per optimization step.
        global_step = slim.get_or_create_global_step()

        # Build the input pipeline and the one-hot labels on the CPU. Only these
        # ops are pinned here; the network below uses default device placement.
        with tf.device('/cpu:0'):
            batch_dict = input_nodes(
                tfrecords=tfrecords,
                cfg=cfg.IMAGE_PROCESSING,
                num_epochs=None,
                batch_size=cfg.BATCH_SIZE,
                num_threads=cfg.NUM_INPUT_THREADS,
                shuffle_batch=cfg.SHUFFLE_QUEUE,
                random_seed=cfg.RANDOM_SEED,
                capacity=cfg.QUEUE_CAPACITY,
                min_after_dequeue=cfg.QUEUE_MIN,
                add_summaries=True,
                input_type='train',
                read_filenames=read_images
            )

            batched_one_hot_labels = slim.one_hot_encoding(
                batch_dict['labels'], num_classes=cfg.NUM_CLASSES)

        # GVH: wrapping the inputs in slim.prefetch_queue.prefetch_queue did not
        # seem to help the poor queueing performance, so it is not used.

        arg_scope = nets_factory.arg_scopes_map[cfg.MODEL_NAME](
            weight_decay=cfg.WEIGHT_DECAY,
            batch_norm_decay=cfg.BATCHNORM_MOVING_AVERAGE_DECAY,
            batch_norm_epsilon=cfg.BATCHNORM_EPSILON
        )

        with slim.arg_scope(arg_scope):
            logits, end_points = nets_factory.networks_map[cfg.MODEL_NAME](
                inputs=batch_dict['inputs'],
                num_classes=cfg.NUM_CLASSES,
                dropout_keep_prob=cfg.DROPOUT_KEEP_PROB,
                is_training=True
            )

            # Add the losses. tf.losses.* register them in the LOSSES
            # collection, which tf.losses.get_total_loss() reads below.
            if 'AuxLogits' in end_points:
                # Auxiliary head loss, down-weighted by 0.4.
                tf.losses.softmax_cross_entropy(
                    logits=end_points['AuxLogits'], onehot_labels=batched_one_hot_labels,
                    label_smoothing=cfg.LABEL_SMOOTHING, weights=0.4, scope='aux_loss')

            tf.losses.softmax_cross_entropy(
                logits=logits, onehot_labels=batched_one_hot_labels,
                label_smoothing=cfg.LABEL_SMOOTHING, weights=1.0)

        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Summarize the losses
        for loss in tf.get_collection(tf.GraphKeys.LOSSES):
            summaries.add(tf.summary.scalar(name='losses/%s' % loss.op.name, tensor=loss))

        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if regularization_losses:
            # Only used for the summary; get_total_loss() does its own sum.
            regularization_loss = tf.add_n(regularization_losses, name='regularization_loss')
            summaries.add(tf.summary.scalar(name='losses/regularization_loss', tensor=regularization_loss))

        total_loss = tf.losses.get_total_loss()
        summaries.add(tf.summary.scalar(name='losses/total_loss', tensor=total_loss))

        if 'MOVING_AVERAGE_DECAY' in cfg and cfg.MOVING_AVERAGE_DECAY > 0:
            moving_average_variables = slim.get_model_variables()
            ema = tf.train.ExponentialMovingAverage(
                decay=cfg.MOVING_AVERAGE_DECAY,
                num_updates=global_step
            )
        elif restore_variables_with_moving_averages or restore_moving_averages:
            # Perhaps we are finetuning the last layer of a pretrained model?
            # Then we only need an EMA object to load the moving averages in
            # get_init_function(); no averages are updated during training.
            moving_average_variables = None
            ema = tf.train.ExponentialMovingAverage(
                decay=1,
                num_updates=global_step
            )
        else:
            moving_average_variables = None
            ema = None

        # Calculate the learning rate schedule.
        lr = _configure_learning_rate(global_step, cfg)

        # Create an optimizer that performs gradient descent.
        optimizer = _configure_optimizer(lr, cfg)

        summaries.add(tf.summary.scalar(tensor=lr,
                                        name='learning_rate'))

        # Add the moving average update ops to UPDATE_OPS so they run with
        # every train step.
        if ema is not None and moving_average_variables is not None:
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema.apply(moving_average_variables))

        trainable_vars = get_trainable_variables(trainable_scopes)
        train_op = slim.learning.create_train_op(total_loss=total_loss,
                                                 optimizer=optimizer,
                                                 global_step=global_step,
                                                 variables_to_train=trainable_vars,
                                                 clip_gradient_norm=cfg.CLIP_GRADIENT_NORM)

        # Merge all of the summaries, including any added while building the
        # train op.
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        summary_op = tf.summary.merge(inputs=list(summaries), name='summary_op')

        sess_config = tf.ConfigProto(
            log_device_placement=cfg.SESSION_CONFIG.LOG_DEVICE_PLACEMENT,
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=cfg.SESSION_CONFIG.PER_PROCESS_GPU_MEMORY_FRACTION
            ),
            intra_op_parallelism_threads=cfg.SESSION_CONFIG.INTRA_OP_PARALLELISM_THREADS if 'INTRA_OP_PARALLELISM_THREADS' in cfg.SESSION_CONFIG else None,
            inter_op_parallelism_threads=cfg.SESSION_CONFIG.INTER_OP_PARALLELISM_THREADS if 'INTER_OP_PARALLELISM_THREADS' in cfg.SESSION_CONFIG else None
        )

        saver = tf.train.Saver(
            # Save all variables
            max_to_keep=cfg.MAX_TO_KEEP,
            keep_checkpoint_every_n_hours=cfg.KEEP_CHECKPOINT_EVERY_N_HOURS
        )

        # Run training.
        slim.learning.train(
            train_op=train_op,
            logdir=logdir,
            init_fn=get_init_function(logdir, pretrained_model_path, checkpoint_exclude_scopes, restore_variables_with_moving_averages=restore_variables_with_moving_averages, restore_moving_averages=restore_moving_averages, ema=ema),
            number_of_steps=cfg.NUM_TRAIN_ITERATIONS,
            save_summaries_secs=cfg.SAVE_SUMMARY_SECS,
            save_interval_secs=cfg.SAVE_INTERVAL_SECS,
            saver=saver,
            session_config=sess_config,
            summary_op=summary_op,
            log_every_n_steps=cfg.LOG_EVERY_N_STEPS
        )
# Example 2
0
def test(tfrecords, checkpoint_path, save_dir, max_iterations,
         eval_interval_secs, cfg):
    """Evaluate a trained model on ``tfrecords`` with TF-Slim's evaluation helpers.

    Reports streaming accuracy, mean cross-entropy loss and, for every k in
    ``cfg.ACCURACY_AT_K_METRIC`` with 1 < k <= NUM_CLASSES, top-k accuracy.

    Args:
        tfrecords (list): tfrecord files to evaluate on (read for one epoch).
        checkpoint_path (str): a checkpoint file or a directory of checkpoints.
            It must be a directory when ``eval_interval_secs > 0``.
        save_dir (str): logdir where evaluation summaries are written.
        max_iterations (int): number of batches to evaluate. A value <= 0 means
            one pass over ``cfg.NUM_TEST_EXAMPLES``, rounded down to whole
            batches.
        eval_interval_secs (int): if > 0, repeatedly evaluate checkpoints in
            ``checkpoint_path`` every that many seconds. Otherwise evaluate a
            single checkpoint once.
        cfg (EasyDict): evaluation configuration.

    Raises:
        ValueError: if ``checkpoint_path`` is not a directory when evaluating in
            a loop, or is a directory without any checkpoint when evaluating
            once.
    """
    print(tfrecords, checkpoint_path, save_dir, max_iterations, cfg)
    tf.logging.set_verbosity(tf.logging.DEBUG)

    graph = tf.Graph()

    with graph.as_default():

        global_step = slim.get_or_create_global_step()

        with tf.device('/cpu:0'):
            batch_dict = inputs.input_nodes(tfrecords=tfrecords,
                                            cfg=cfg.IMAGE_PROCESSING,
                                            num_epochs=1,
                                            batch_size=cfg.BATCH_SIZE,
                                            num_threads=cfg.NUM_INPUT_THREADS,
                                            shuffle_batch=cfg.SHUFFLE_QUEUE,
                                            random_seed=cfg.RANDOM_SEED,
                                            capacity=cfg.QUEUE_CAPACITY,
                                            min_after_dequeue=cfg.QUEUE_MIN,
                                            add_summaries=False,
                                            input_type='test')

            batched_one_hot_labels = slim.one_hot_encoding(
                batch_dict['labels'], num_classes=cfg.NUM_CLASSES)

        arg_scope = nets_factory.arg_scopes_map[cfg.MODEL_NAME]()

        with slim.arg_scope(arg_scope):
            logits, end_points = nets_factory.networks_map[cfg.MODEL_NAME](
                inputs=batch_dict['inputs'],
                num_classes=cfg.NUM_CLASSES,
                is_training=False)

            predictions = end_points['Predictions']
            labels = batch_dict['labels']

            # Unsmoothed cross-entropy; reported as a streaming mean below.
            loss = tf.losses.softmax_cross_entropy(
                logits=logits,
                onehot_labels=batched_one_hot_labels,
                label_smoothing=0.,
                weights=1.0)

        if 'MOVING_AVERAGE_DECAY' in cfg and cfg.MOVING_AVERAGE_DECAY > 0:
            # Restore the averaged (shadow) values into the model variables.
            variable_averages = tf.train.ExponentialMovingAverage(
                cfg.MOVING_AVERAGE_DECAY, global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[global_step.op.name] = global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()
            # get_variables_to_restore() may already contain the global step.
            # Listing it twice makes tf.train.Saver raise "At least two
            # variables have the same name".
            restore_names = {v.op.name for v in variables_to_restore}
            if global_step.op.name not in restore_names:
                variables_to_restore.append(global_step)

        # Define the metrics:
        metric_map = {
            'Accuracy':
            slim.metrics.streaming_accuracy(labels=labels,
                                            predictions=tf.argmax(
                                                predictions, 1)),
            loss.op.name:
            slim.metrics.streaming_mean(loss)
        }
        if len(cfg.ACCURACY_AT_K_METRIC) > 0:
            # in_top_k yields a bool per example; comparing it against all-True
            # labels turns streaming_accuracy into a top-k accuracy.
            bool_labels = tf.ones([cfg.BATCH_SIZE], dtype=tf.bool)
            for k in cfg.ACCURACY_AT_K_METRIC:
                if k <= 1 or k > cfg.NUM_CLASSES:
                    continue
                in_top_k = tf.nn.in_top_k(predictions=predictions,
                                          targets=labels,
                                          k=k)
                metric_map['Accuracy_at_%s' %
                           k] = slim.metrics.streaming_accuracy(
                               labels=bool_labels, predictions=in_top_k)

        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
            metric_map)
        # Session.run needs a real list; dict.values() is a view on Python 3.
        eval_ops = list(names_to_updates.values())

        # Print the summaries to screen. The global step is printed once,
        # attached to the first summary op.
        print_global_step = True
        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            if print_global_step:
                op = tf.Print(op, [global_step], "Model Step ")
                print_global_step = False
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

        if max_iterations > 0:
            num_batches = max_iterations
        else:
            # This ensures that we make a single pass over all of the data.
            # We could use ceil if the batch queue is allowed to pad the last batch
            num_batches = int(np.floor(cfg.NUM_TEST_EXAMPLES /
                                       float(cfg.BATCH_SIZE)))

        sess_config = tf.ConfigProto(
            log_device_placement=cfg.SESSION_CONFIG.LOG_DEVICE_PLACEMENT,
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=cfg.SESSION_CONFIG.
                PER_PROCESS_GPU_MEMORY_FRACTION))

        if eval_interval_secs > 0:

            if not os.path.isdir(checkpoint_path):
                raise ValueError("checkpoint_path should be a path to a directory when " \
                                 "evaluating in a loop.")

            slim.evaluation.evaluation_loop(
                master='',
                checkpoint_dir=checkpoint_path,
                logdir=save_dir,
                num_evals=num_batches,
                initial_op=None,
                initial_op_feed_dict=None,
                eval_op=eval_ops,
                eval_op_feed_dict=None,
                final_op=None,
                final_op_feed_dict=None,
                summary_op=tf.summary.merge_all(),
                summary_op_feed_dict=None,
                variables_to_restore=variables_to_restore,
                eval_interval_secs=eval_interval_secs,
                max_number_of_evaluations=None,
                session_config=sess_config,
                timeout=None)

        else:
            if os.path.isdir(checkpoint_path):
                checkpoint_dir = checkpoint_path
                checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)

                if checkpoint_path is None:
                    raise ValueError("Unable to find a model checkpoint in the " \
                                     "directory %s" % (checkpoint_dir,))

            tf.logging.info('Evaluating %s' % checkpoint_path)

            slim.evaluation.evaluate_once(
                master='',
                checkpoint_path=checkpoint_path,
                logdir=save_dir,
                num_evals=num_batches,
                eval_op=eval_ops,
                variables_to_restore=variables_to_restore,
                session_config=sess_config)
# Example 3
0
def classify(tfrecords, checkpoint_path, save_path, max_iterations,
             save_logits, cfg):
    """Run a trained model over ``tfrecords`` and save its predictions with ``np.savez``.

    The archive written to ``save_path`` contains:
      * ``labels``: the predicted class per image (argmax of
        ``end_points['Predictions']``), as int32;
      * ``ids``: the image ids from the input pipeline, as an object array;
      * ``logits``: float32 logits, one row per image (only if ``save_logits``).
    Only images that were actually processed are stored, so stopping early on
    exhausted input never leaves uninitialized rows.

    Args:
        tfrecords (list): tfrecord files to classify (read for one epoch).
        checkpoint_path (str): a checkpoint file or a directory. For a
            directory, its latest checkpoint is used.
        save_path (str): destination passed to ``np.savez``.
        max_iterations (int): maximum number of batches to run. A value <= 0
            runs until the input is exhausted.
        save_logits (bool): also store the raw logits.
        cfg (EasyDict): model / input / session configuration.

    Raises:
        ValueError: if ``checkpoint_path`` is a directory with no checkpoint.
    """
    tf.logging.set_verbosity(tf.logging.DEBUG)

    graph = tf.Graph()

    with graph.as_default():

        global_step = slim.get_or_create_global_step()

        with tf.device('/cpu:0'):
            batch_dict = inputs.input_nodes(tfrecords=tfrecords,
                                            cfg=cfg.IMAGE_PROCESSING,
                                            num_epochs=1,
                                            batch_size=cfg.BATCH_SIZE,
                                            num_threads=cfg.NUM_INPUT_THREADS,
                                            shuffle_batch=cfg.SHUFFLE_QUEUE,
                                            random_seed=cfg.RANDOM_SEED,
                                            capacity=cfg.QUEUE_CAPACITY,
                                            min_after_dequeue=cfg.QUEUE_MIN,
                                            add_summaries=False,
                                            input_type='classification')

        arg_scope = nets_factory.arg_scopes_map[cfg.MODEL_NAME]()

        with slim.arg_scope(arg_scope):
            logits, end_points = nets_factory.networks_map[cfg.MODEL_NAME](
                inputs=batch_dict['inputs'],
                num_classes=cfg.NUM_CLASSES,
                is_training=False)

            predicted_labels = tf.argmax(end_points['Predictions'], 1)

        if 'MOVING_AVERAGE_DECAY' in cfg and cfg.MOVING_AVERAGE_DECAY > 0:
            # Restore the averaged (shadow) values into the model variables.
            variable_averages = tf.train.ExponentialMovingAverage(
                cfg.MOVING_AVERAGE_DECAY, global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[global_step.op.name] = global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()
            # get_variables_to_restore() may already contain the global step.
            # Listing it twice makes tf.train.Saver raise "At least two
            # variables have the same name".
            restore_names = {v.op.name for v in variables_to_restore}
            if global_step.op.name not in restore_names:
                variables_to_restore.append(global_step)

        saver = tf.train.Saver(variables_to_restore, reshape=True)

        fetches = [predicted_labels, batch_dict['ids']]
        if save_logits:
            fetches.append(logits)

        if os.path.isdir(checkpoint_path):
            checkpoint_dir = checkpoint_path
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)

            if checkpoint_path is None:
                raise ValueError("Unable to find a model checkpoint in the " \
                                 "directory %s" % (checkpoint_dir,))

        tf.logging.info('Classifying records using %s' % checkpoint_path)

        sess_config = tf.ConfigProto(
            log_device_placement=cfg.SESSION_CONFIG.LOG_DEVICE_PLACEMENT,
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=cfg.SESSION_CONFIG.
                PER_PROCESS_GPU_MEMORY_FRACTION))

        # Per-batch outputs; concatenated once the loop ends. Collecting
        # batches avoids pre-sizing the arrays from max_iterations, which is
        # meaningless when max_iterations <= 0 or the input ends early.
        label_batches = []
        id_batches = []
        logit_batches = []

        coord = tf.train.Coordinator()

        # The with-statement makes the session the default and closes it.
        with tf.Session(graph=graph, config=sess_config) as sess:

            tf.global_variables_initializer().run()
            tf.local_variables_initializer().run()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:

                # Restore from checkpoint
                saver.restore(sess, checkpoint_path)

                print_str = ', '.join(['Step: %d', 'Time/image (ms): %.1f'])

                step = 0
                while not coord.should_stop():

                    t = time.time()
                    outputs = sess.run(fetches)
                    dt = time.time() - t

                    label_batches.append(outputs[0])
                    id_batches.append(outputs[1])
                    if save_logits:
                        logit_batches.append(outputs[2])

                    step += 1
                    images_in_batch = max(len(outputs[0]), 1)
                    print(print_str % (step, (dt / images_in_batch) * 1000))

                    if max_iterations > 0 and step == max_iterations:
                        break

            except tf.errors.OutOfRangeError:
                # num_epochs=1: the input has been fully consumed. This is the
                # normal way the loop ends when max_iterations <= 0.
                pass
            finally:
                # Stop and join the queue-runner threads even if restore or a
                # step failed, while the session is still open.
                coord.request_stop()
                coord.join(threads)

        # Assemble only the rows that were actually produced.
        if label_batches:
            label_array = np.concatenate(label_batches).astype(np.int32)
            id_array = np.concatenate(id_batches).astype(object)
        else:
            label_array = np.empty(0, dtype=np.int32)
            id_array = np.empty(0, dtype=object)

        # save the results
        if save_logits:
            if logit_batches:
                logits_array = np.concatenate(logit_batches).astype(np.float32)
            else:
                logits_array = np.empty((0, cfg.NUM_CLASSES), dtype=np.float32)
            np.savez(save_path,
                     labels=label_array,
                     ids=id_array,
                     logits=logits_array)
        else:
            np.savez(save_path, labels=label_array, ids=id_array)
def visualize_train_inputs(tfrecords, cfg, show_text_labels=False):
    """Interactively show original vs. modified training images one at a time.

    Reads one epoch of ``tfrecords`` through ``input_nodes`` (input_type
    'visualize'). For each image, a matplotlib figure shows the original and
    modified versions, titled with the image id and label, plus the text label
    when ``show_text_labels`` is set. Pressing Enter advances to the next image.
    Typing anything before Enter stops, as does exhausting the epoch.

    Args:
        tfrecords (list): tfrecord files to read.
        cfg (EasyDict): configuration; the input-queue settings and
            ``cfg.IMAGE_PROCESSING`` are used.
        show_text_labels (bool): also fetch and display ``text_labels``.
    """
    # raw_input only exists on Python 2; fall back to input() on Python 3.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    graph = tf.Graph()

    # Run a session to look at the images. The with-statement makes it the
    # default session and closes it on exit.
    with graph.as_default(), tf.Session(graph=graph) as sess:

        # Input Nodes
        with tf.device('/cpu:0'):
            batch_dict = input_nodes(tfrecords=tfrecords,
                                     cfg=cfg.IMAGE_PROCESSING,
                                     num_epochs=1,
                                     batch_size=cfg.BATCH_SIZE,
                                     num_threads=cfg.NUM_INPUT_THREADS,
                                     shuffle_batch=cfg.SHUFFLE_QUEUE,
                                     random_seed=cfg.RANDOM_SEED,
                                     capacity=cfg.QUEUE_CAPACITY,
                                     min_after_dequeue=cfg.QUEUE_MIN,
                                     add_summaries=False,
                                     input_type='visualize',
                                     fetch_text_labels=show_text_labels)

        # Convert a single float image to uint8 for display. Height and width
        # are left unspecified so images need not be exactly INPUT_SIZE square.
        image_to_convert = tf.placeholder(dtype=tf.float32,
                                          shape=[None, None, 3])
        uint8_image = tf.image.convert_image_dtype(image_to_convert,
                                                   dtype=tf.uint8)

        coord = tf.train.Coordinator()
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        plt.ion()
        try:
            done = False
            while not done:

                try:
                    output = sess.run(batch_dict)
                except tf.errors.OutOfRangeError:
                    # num_epochs=1: every record has been shown.
                    break

                original_images = output['original_inputs']
                distorted_images = output['inputs']
                image_ids = output['ids']
                labels = output['labels']
                if show_text_labels:
                    text_labels = output['text_labels']

                # Iterate over the fetched batch rather than cfg.BATCH_SIZE,
                # in case input_nodes yields a smaller final batch.
                for b in range(len(image_ids)):

                    original_image = original_images[b]
                    distorted_image = distorted_images[b]

                    if original_image.dtype != np.uint8:
                        original_image = sess.run(
                            uint8_image, {image_to_convert: original_image})

                    if distorted_image.dtype != np.uint8:
                        distorted_image = sess.run(
                            uint8_image, {image_to_convert: distorted_image})

                    image_id = image_ids[b]
                    label = labels[b]

                    fig = plt.figure('Train Inputs')

                    if show_text_labels:
                        text_label = text_labels[b]
                        st = fig.suptitle("Image: %s\nLabel: %d\nText: %s" %
                                          (image_id, label, text_label),
                                          fontsize=12)
                    else:
                        st = fig.suptitle("Image: %s\nLabel: %d" %
                                          (image_id, label),
                                          fontsize=12)

                    plt.subplot(2, 1, 1)
                    plt.imshow(original_image)
                    plt.title("Original")
                    plt.axis('off')

                    plt.subplot(2, 1, 2)
                    plt.imshow(distorted_image)
                    plt.title("Modified")
                    plt.axis('off')

                    # Shift the subplots down a bit to make room for the super title
                    st.set_y(0.95)
                    fig.subplots_adjust(top=0.75)

                    plt.show(block=False)

                    t = read_line("Press Enter to view next image. Press any key followed " \
                                  "by Enter to quit: ")
                    if t != '':
                        done = True
                        break
                    plt.clf()
        finally:
            # Stop and join the queue-runner threads before the session closes.
            coord.request_stop()
            coord.join(threads)