def variables_to_restore(scope=None, strip_scope=False):
    """Returns a list of variables to restore for the specified list of methods.
  It is supposed that variable name starts with the method's scope (a prefix
  returned by _method_scope function).
  Args:
    methods_names: a list of names of configurable methods.
    strip_scope: if True will return variable names without method's scope.
      If methods_names is None will return names unchanged.
    model_scope: a scope for a whole model.
  Returns:
    a dictionary mapping variable names to variables for restore.
  """
    if scope:
        variable_map = {}
        method_variables = tf_slim.get_variables_to_restore(include=[scope])
        for var in method_variables:
            if strip_scope:
                var_name = var.op.name[len(scope) + 1:]
            else:
                var_name = var.op.name
            variable_map[var_name] = var

        return variable_map
    else:
        return {v.op.name: v for v in tf_slim.get_variables_to_restore()}
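A minimal usage sketch (hedged; the 'encoder' scope and checkpoint path are illustrative, not from the source):

import tensorflow.compat.v1 as tf

# Restore only 'encoder/*' variables under names stripped of their scope, so
# they can match a checkpoint that was written without the prefix.
var_map = variables_to_restore(scope='encoder', strip_scope=True)
saver = tf.train.Saver(var_list=var_map)
with tf.Session() as sess:
    saver.restore(sess, '/path/to/model.ckpt')  # illustrative path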
Example #2
def reload_checkpoint(checkpoint_path, session):
    """Load TF checkpoint and restore into agent session"""
    global_vars = set([x.name for x in tf.compat.v1.global_variables()])
    ckpt_vars = [
        '{}:0'.format(name)
        for name, _ in tf.train.list_variables(checkpoint_path)
    ]
    include_vars = list(global_vars.intersection(set(ckpt_vars)))
    variables_to_restore = tf_slim.get_variables_to_restore(
        include=include_vars)

    if variables_to_restore:
        reloader = tf.compat.v1.train.Saver(var_list=variables_to_restore)
        reloader.restore(session, checkpoint_path)
        logging.info('Done restoring from %s', checkpoint_path)
    else:
        logging.info('Nothing to restore!')
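The set intersection above keeps only variables that exist both in the graph and in the checkpoint, which avoids restore errors for missing entries. A small sketch of the name matching (values illustrative):

# tf.train.list_variables yields (name, shape) pairs without the ':0' suffix,
# hence the '{}:0'.format(name) before intersecting with graph variable names.
ckpt_entries = [('Online/dense/kernel', [64, 4])]  # illustrative
graph_names = {'Online/dense/kernel:0', 'Target/dense/kernel:0'}
matched = graph_names.intersection(
    {'{}:0'.format(name) for name, _ in ckpt_entries})
# matched == {'Online/dense/kernel:0'}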
Example #3
    def reload_checkpoint(self, checkpoint_path):
        """Reload variables from a fully specified checkpoint.

    Args:
      checkpoint_path: string, full path to a checkpoint to reload.
    """
        assert checkpoint_path
        variables_to_restore = tf_slim.get_variables_to_restore()
        reloader = tf.train.Saver(var_list=variables_to_restore)
        reloader.restore(self._sess, checkpoint_path)

        var = [
            v for v in variables_to_restore
            if v.name == 'Online/fully_connected_1/weights:0'
        ][0]
        wts = self._sess.run(var)
        var = [
            v for v in variables_to_restore
            if v.name == 'Online/fully_connected_1/biases:0'
        ][0]
        biases = self._sess.run(var)
        num_wts = wts.size + biases.size

        target_var = [
            v for v in variables_to_restore
            if v.name == 'Target/fully_connected_1/weights:0'
        ][0]
        target_wts = self._sess.run(target_var)
        target_var = [
            v for v in variables_to_restore
            if v.name == 'Target/fully_connected_1/biases:0'
        ][0]
        target_biases = self._sess.run(target_var)
        self.target_wts = target_wts
        self.target_biases = target_biases

        self.last_layer_weights = wts
        self.last_layer_biases = biases
        self.last_layer_wts = np.append(wts,
                                        np.expand_dims(biases, axis=0),
                                        axis=0)
        self.last_layer_wts = self.last_layer_wts.reshape((num_wts, ),
                                                          order='F')
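The Fortran-order reshape above lays out the flattened vector unit by unit: each output unit's weights are followed by its bias. A toy NumPy check (shapes illustrative):

import numpy as np

wts = np.arange(6.0).reshape(2, 3)       # (inputs=2, units=3) weights
biases = np.array([10.0, 20.0, 30.0])    # (units=3,) biases
stacked = np.append(wts, np.expand_dims(biases, axis=0), axis=0)  # (3, 3)
flat = stacked.reshape((stacked.size,), order='F')
# flat == [0., 3., 10., 1., 4., 20., 2., 5., 30.]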
Example #4
    def reload_checkpoint(self, checkpoint_path, use_legacy_checkpoint=False):
        if use_legacy_checkpoint:
            variables_to_restore = atari_lib.maybe_transform_variable_names(
                tf.compat.v1.global_variables(), legacy_checkpoint_load=True)
        else:
            global_vars = set(
                x.name for x in tf.compat.v1.global_variables())
            ckpt_vars = [
                '{}:0'.format(name)
                for name, _ in tf.train.list_variables(checkpoint_path)
            ]
            include_vars = list(global_vars.intersection(set(ckpt_vars)))
            variables_to_restore = tf_slim.get_variables_to_restore(
                include=include_vars)
        if variables_to_restore:
            reloader = tf.compat.v1.train.Saver(var_list=variables_to_restore)
            reloader.restore(self._sess, checkpoint_path)
            logging.info('Done restoring from %s', checkpoint_path)
        else:
            logging.info('Nothing to restore!')
Example #5
def get_checkpoint_init_fn():
    """Returns the checkpoint init_fn if the checkpoint is provided."""
    if FLAGS.fine_tune_checkpoint:
        variables_to_restore = slim.get_variables_to_restore()
        global_step_reset = tf.assign(tf.train.get_or_create_global_step(), 0)
        # When restoring from a floating point model, the min/max values for
        # quantized weights and activations are not present.
        # We instruct slim to ignore variables that are missing during restoration
        # by setting ignore_missing_vars=True
        slim_init_fn = slim.assign_from_checkpoint_fn(
            FLAGS.fine_tune_checkpoint,
            variables_to_restore,
            ignore_missing_vars=True)

        def init_fn(sess):
            slim_init_fn(sess)
            # If we are restoring from a floating point model, we need to initialize
            # the global step to zero for the exponential decay to result in
            # reasonable learning rates.
            sess.run(global_step_reset)

        return init_fn
    else:
        return None
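The returned init_fn is meant to be handed to slim's training loop; a hedged sketch (total_loss, optimizer, and the flags are assumed to be defined elsewhere):

train_op = slim.learning.create_train_op(total_loss, optimizer)  # assumed
slim.learning.train(
    train_op,
    logdir=FLAGS.train_dir,             # assumed flag
    init_fn=get_checkpoint_init_fn(),   # restores weights, resets global step
    number_of_steps=FLAGS.train_steps)  # assumed flag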
Example #6
def main(unused_argv=None):
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Forces all input processing onto CPU in order to reserve the GPU for the
        # forward inference and back-propagation.
        device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               worker_device=device)):
            # Loads content images.
            content_inputs_, _ = image_utils.imagenet_inputs(
                FLAGS.batch_size, FLAGS.image_size)

            # Loads style images.
            [style_inputs_, _,
             style_inputs_orig_] = image_utils.arbitrary_style_image_inputs(
                 FLAGS.style_dataset_file,
                 batch_size=FLAGS.batch_size,
                 image_size=FLAGS.image_size,
                 shuffle=True,
                 center_crop=FLAGS.center_crop,
                 augment_style_images=FLAGS.augment_style_images,
                 random_style_image_size=FLAGS.random_style_image_size)

        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
            # Process style and content weight flags.
            content_weights = ast.literal_eval(FLAGS.content_weights)
            style_weights = ast.literal_eval(FLAGS.style_weights)

            # Define the model
            (stylized_images, true_loss, _,
             bottleneck_feat) = build_mobilenet_model.build_mobilenet_model(
                content_inputs_,
                style_inputs_,
                mobilenet_trainable=True,
                style_params_trainable=False,
                style_prediction_bottleneck=100,
                adds_losses=True,
                content_weights=content_weights,
                style_weights=style_weights,
                total_variation_weight=FLAGS.total_variation_weight,
            )

            _, inception_bottleneck_feat = build_model.style_prediction(
                style_inputs_,
                [],
                [],
                is_training=False,
                trainable=False,
                inception_end_point='Mixed_6e',
                style_prediction_bottleneck=100,
                reuse=None,
            )

            print('PRINTING TRAINABLE VARIABLES')
            for x in tf.trainable_variables():
                print(x)

            mse_loss = tf.losses.mean_squared_error(inception_bottleneck_feat,
                                                    bottleneck_feat)
            total_loss = mse_loss
            if FLAGS.use_true_loss:
                true_loss = FLAGS.true_loss_weight * true_loss
                total_loss += true_loss
                tf.summary.scalar('mse', mse_loss)
                tf.summary.scalar('true_loss', true_loss)
            tf.summary.scalar('total_loss', total_loss)
            tf.summary.image('image/0_content_inputs', content_inputs_, 3)
            tf.summary.image('image/1_style_inputs_orig', style_inputs_orig_,
                             3)
            tf.summary.image('image/2_style_inputs_aug', style_inputs_, 3)
            tf.summary.image('image/3_stylized_images', stylized_images, 3)

            mobilenet_variables_to_restore = slim.get_variables_to_restore(
                include=['MobilenetV2'], exclude=['global_step'])

            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
            train_op = slim.learning.create_train_op(
                total_loss,
                optimizer,
                clip_gradient_norm=FLAGS.clip_gradient_norm,
                summarize_gradients=False)

            init_fn = slim.assign_from_checkpoint_fn(
                FLAGS.initial_checkpoint,
                slim.get_variables_to_restore(
                    exclude=['MobilenetV2', 'mobilenet_conv', 'global_step']))
            init_pretrained_mobilenet = slim.assign_from_checkpoint_fn(
                FLAGS.mobilenet_checkpoint, mobilenet_variables_to_restore)

            def init_sub_networks(session):
                init_fn(session)
                init_pretrained_mobilenet(session)

            slim.learning.train(train_op=train_op,
                                logdir=os.path.expanduser(FLAGS.train_dir),
                                master=FLAGS.master,
                                is_chief=FLAGS.task == 0,
                                number_of_steps=FLAGS.train_steps,
                                init_fn=init_sub_networks,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)
Example #7
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    images, labels = tf.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

    ####################
    # Define the model #
    ####################
    logits, _ = network_fn(images)

    #if FLAGS.quantize:
    #  tf.contrib.quantize.create_eval_graph()

    if FLAGS.moving_average_decay:
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()

    predictions = tf.argmax(logits, 1)
    labels = tf.squeeze(labels)

    # Define the metrics:
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
        'Recall_5': slim.metrics.streaming_recall_at_k(
            logits, labels, 5),
    })

    # Print the summaries to screen.
    for name, value in names_to_values.items():
      summary_name = 'eval/%s' % name
      op = tf.summary.scalar(summary_name, value, collections=[])
      op = tf.Print(op, [value], summary_name)
      tf.add_to_collection(tf.compat.v1.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.logging.info('Evaluating %s' % checkpoint_path)

    slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        variables_to_restore=variables_to_restore)
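When `moving_average_decay` is set, `variables_to_restore` is a dict keyed by shadow-variable names, so evaluation loads the averaged weights rather than the raw training weights. A sketch of the mapping's shape (variable name illustrative):

ema = tf.train.ExponentialMovingAverage(0.999)
restore_map = ema.variables_to_restore(slim.get_model_variables())
# e.g. {'logits/weights/ExponentialMovingAverage': <tf.Variable 'logits/weights:0'>}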
Example #8
def main(unused_argv=None):
  tf.logging.set_verbosity(tf.logging.INFO)
  if not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MkDir(FLAGS.output_dir)

  with tf.Graph().as_default(), tf.Session() as sess:
    # Defines place holder for the style image.
    style_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
    if FLAGS.style_square_crop:
      style_img_preprocessed = image_utils.center_crop_resize_image(
          style_img_ph, FLAGS.style_image_size)
    else:
      style_img_preprocessed = image_utils.resize_image(style_img_ph,
                                                        FLAGS.style_image_size)

    # Defines place holder for the content image.
    content_img_ph = tf.placeholder(tf.float32, shape=[None, None, 3])
    if FLAGS.content_square_crop:
      content_img_preprocessed = image_utils.center_crop_resize_image(
          content_img_ph, FLAGS.image_size)
    else:
      content_img_preprocessed = image_utils.resize_image(
          content_img_ph, FLAGS.image_size)

    # Defines the model.
    stylized_images, _, _, bottleneck_feat = build_model.build_model(
        content_img_preprocessed,
        style_img_preprocessed,
        trainable=False,
        is_training=False,
        inception_end_point='Mixed_6e',
        style_prediction_bottleneck=100,
        adds_losses=False)

    if tf.gfile.IsDirectory(FLAGS.checkpoint):
      checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint)
    else:
      checkpoint = FLAGS.checkpoint
    tf.logging.info('Loading checkpoint: {}'.format(checkpoint))

    init_fn = slim.assign_from_checkpoint_fn(checkpoint,
                                             slim.get_variables_to_restore())
    sess.run([tf.local_variables_initializer()])
    init_fn(sess)

    # Gets the list of the input style images.
    style_img_list = tf.gfile.Glob(FLAGS.style_images_paths)
    if len(style_img_list) > FLAGS.maximum_styles_to_evaluate:
      np.random.seed(1234)
      style_img_list = np.random.permutation(style_img_list)
      style_img_list = style_img_list[:FLAGS.maximum_styles_to_evaluate]

    # Gets list of input content images.
    content_img_list = tf.gfile.Glob(FLAGS.content_images_paths)

    for content_i, content_img_path in enumerate(content_img_list):
      content_img_np = image_utils.load_np_image_uint8(content_img_path)[:, :, :3]
      content_img_name = os.path.basename(content_img_path)[:-4]

      # Saves preprocessed content image.
      inp_img_cropped_resized_np = sess.run(
          content_img_preprocessed, feed_dict={
              content_img_ph: content_img_np
          })
      image_utils.save_np_image(inp_img_cropped_resized_np,
                                os.path.join(FLAGS.output_dir,
                                             '%s.jpg' % content_img_name))

      # Computes bottleneck features of the style prediction network for the
      # identity transform.
      identity_params = sess.run(
          bottleneck_feat, feed_dict={style_img_ph: content_img_np})

      for style_i, style_img_path in enumerate(style_img_list):
        if style_i > FLAGS.maximum_styles_to_evaluate:
          break
        style_img_name = os.path.basename(style_img_path)[:-4]
        style_image_np = image_utils.load_np_image_uint8(style_img_path)[:, :, :3]

        if style_i % 10 == 0:
          tf.logging.info('Stylizing (%d) %s with (%d) %s' %
                          (content_i, content_img_name, style_i,
                           style_img_name))

        # Saves preprocessed style image.
        style_img_cropped_resized_np = sess.run(
            style_img_preprocessed, feed_dict={
                style_img_ph: style_image_np
            })
        image_utils.save_np_image(style_img_cropped_resized_np,
                                  os.path.join(FLAGS.output_dir,
                                               '%s.jpg' % style_img_name))

        # Computes bottleneck features of the style prediction network for the
        # given style image.
        style_params = sess.run(
            bottleneck_feat, feed_dict={style_img_ph: style_image_np})

        interpolation_weights = ast.literal_eval(FLAGS.interpolation_weights)
        # Interpolates between the parameters of the identity transform and
        # style parameters of the given style image.
        for interp_i, wi in enumerate(interpolation_weights):
          stylized_image_res = sess.run(
              stylized_images,
              feed_dict={
                  bottleneck_feat:
                      identity_params * (1 - wi) + style_params * wi,
                  content_img_ph:
                      content_img_np
              })

          # Saves stylized image.
          image_utils.save_np_image(
              stylized_image_res,
              os.path.join(FLAGS.output_dir, '%s_stylized_%s_%d.jpg' %
                           (content_img_name, style_img_name, interp_i)))
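`FLAGS.interpolation_weights` is parsed with `ast.literal_eval`, so the flag should hold a Python list literal; for example (flag value assumed):

import ast

weights = ast.literal_eval('[0.0, 0.25, 0.5, 0.75, 1.0]')
# wi == 0.0 reproduces the identity (content) parameters, wi == 1.0 uses the
# full style parameters, and intermediate values blend the two bottlenecks.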
Example #9
def run_training(build_graph_fn,
                 train_dir,
                 num_training_steps=None,
                 summary_frequency=10,
                 save_checkpoint_secs=60,
                 checkpoints_to_keep=10,
                 keep_checkpoint_every_n_hours=1,
                 master='',
                 task=0,
                 num_ps_tasks=0,
                 warm_start_bundle_file=None):
    """Runs the training loop.

  Args:
    build_graph_fn: A function that builds the graph ops.
    train_dir: The path to the directory where checkpoints and summary events
        will be written to.
    num_training_steps: The number of steps to train for before exiting.
    summary_frequency: The number of steps between each summary. A summary is
        when graph values from the last step are logged to the console and
        written to disk.
    save_checkpoint_secs: The frequency at which to save checkpoints, in
        seconds.
    checkpoints_to_keep: The number of most recent checkpoints to keep in
       `train_dir`. Keeps all if set to 0.
    keep_checkpoint_every_n_hours: Keep a checkpoint every N hours, even if it
        results in more checkpoints than checkpoints_to_keep.
    master: URL of the Tensorflow master.
    task: Task number for this worker.
    num_ps_tasks: Number of parameter server tasks.
    warm_start_bundle_file: Path to a sequence generator bundle file that will
        be used to initialize the model weights for fine-tuning.
  """
    with tf.Graph().as_default():
        with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
            build_graph_fn()

            global_step = tf.train.get_or_create_global_step()
            loss = tf.get_collection('loss')[0]
            perplexity = tf.get_collection('metrics/perplexity')[0]
            accuracy = tf.get_collection('metrics/accuracy')[0]
            train_op = tf.get_collection('train_op')[0]

            logging_dict = {
                'Global Step': global_step,
                'Loss': loss,
                'Perplexity': perplexity,
                'Accuracy': accuracy
            }
            hooks = [
                tf.train.NanTensorHook(loss),
                tf.train.LoggingTensorHook(logging_dict,
                                           every_n_iter=summary_frequency),
                tf.train.StepCounterHook(output_dir=train_dir,
                                         every_n_steps=summary_frequency)
            ]
            if num_training_steps:
                hooks.append(tf.train.StopAtStepHook(num_training_steps))

            with tempfile.TemporaryDirectory() as tempdir:
                if warm_start_bundle_file:
                    # We are fine-tuning from a pretrained bundle. Unpack the bundle and
                    # save its checkpoint to a temporary directory.
                    warm_start_bundle_file = os.path.expanduser(
                        warm_start_bundle_file)
                    bundle = sequence_generator_bundle.read_bundle_file(
                        warm_start_bundle_file)
                    checkpoint_filename = os.path.join(tempdir, 'model.ckpt')
                    with tf.gfile.Open(checkpoint_filename, 'wb') as f:
                        # For now, we support only 1 checkpoint file.
                        f.write(bundle.checkpoint_file[0])
                    variables_to_restore = tf_slim.get_variables_to_restore(
                        exclude=['global_step', '.*Adam.*', 'beta.*_power'])
                    init_op, init_feed_dict = tf_slim.assign_from_checkpoint(
                        checkpoint_filename, variables_to_restore)
                    init_fn = lambda scaffold, sess: sess.run(
                        init_op, init_feed_dict)
                else:
                    init_fn = None

                scaffold = tf.train.Scaffold(
                    init_fn=init_fn,
                    saver=tf.train.Saver(max_to_keep=checkpoints_to_keep,
                                         keep_checkpoint_every_n_hours=
                                         keep_checkpoint_every_n_hours))

                tf.logging.info('Starting training loop...')
                tf_slim.training.train(
                    train_op=train_op,
                    logdir=train_dir,
                    scaffold=scaffold,
                    hooks=hooks,
                    save_checkpoint_secs=save_checkpoint_secs,
                    save_summaries_steps=summary_frequency,
                    master=master,
                    is_chief=task == 0)
                tf.logging.info('Training complete.')
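A minimal caller sketch for run_training (hedged; the collections registered by build_graph are exactly the ones the loop reads above):

def build_graph():
    # Assumed: builds the model and registers 'loss', 'metrics/perplexity',
    # 'metrics/accuracy', and 'train_op' in the default graph's collections.
    ...

run_training(build_graph, train_dir='/tmp/train',  # illustrative directory
             num_training_steps=10000)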
Example #10
0
    def __init__(self,
                 network_name,
                 checkpoint_path,
                 batch_size,
                 image_size=None):
        self._network_name = network_name
        self._checkpoint_path = checkpoint_path
        self._batch_size = batch_size
        self._image_size = image_size
        self._layer = {}

        self._global_step = tf.train.get_or_create_global_step()

        # Retrieve the function that returns logits and endpoints.
        # NOTE: num_classes is assumed to be defined at module level; it is
        # not a constructor argument here.
        self._network_fn = nets_factory.get_network_fn(self._network_name,
                                                       num_classes=num_classes,
                                                       is_training=False)

        # Retrieve the model scope from network factory
        self._model_scope = nets_factory.arg_scopes_map[self._network_name]

        # Fetch the default image size
        self._image_size = self._network_fn.default_image_size
        self._filename_queue = tf.FIFOQueue(100000, [tf.string],
                                            shapes=[[]],
                                            name="filename_queue")
        self._pl_image_files = tf.placeholder(tf.string,
                                              shape=[None],
                                              name="image_file_list")
        self._enqueue_op = self._filename_queue.enqueue_many(
            [self._pl_image_files])
        self._num_in_queue = self._filename_queue.size()

        self._batch_from_queue, self._batch_filenames = self._preproc_image_batch(
            self._batch_size, num_threads=4)

        #self._image_batch = tf.placeholder_with_default(
        #        self._batch_from_queue, shape=[self._batch_size, _STRIDE, self._image_size, self._image_size, 3])
        self._image_batch = tf.placeholder(
            tf.float32, [batch_size, _STRIDE, image_size, image_size, 3])

        # Retrieve the logits and network endpoints (for extracting activations)
        # Note: endpoints is a dictionary with endpoints[name] = tf.Tensor
        self._logits, self._endpoints = self._network_fn(self._image_batch)

        # Find the checkpoint file
        checkpoint_path = self._checkpoint_path
        if tf.gfile.IsDirectory(self._checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(self._checkpoint_path)

        # Load pre-trained weights into the model
        variables_to_restore = slim.get_variables_to_restore()
        restore_fn = slim.assign_from_checkpoint_fn(checkpoint_path,
                                                    variables_to_restore)

        # Start the session and load the pre-trained weights
        self._sess = tf.Session()
        restore_fn(self._sess)

        # Local variables initializer, needed for queues etc.
        self._sess.run(tf.local_variables_initializer())

        # Managing the queues and threads
        self._coord = tf.train.Coordinator()
        self._threads = tf.train.start_queue_runners(coord=self._coord,
                                                     sess=self._sess)
Example #11
0
def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        _ = slim.get_or_create_global_step()  # Required when creating the session.

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.dataset_split_name,
                                              FLAGS.dataset_dir)

        #########################
        # Configure the network #
        #########################
        inception_params = network_params.InceptionV3FCNParams(
            receptive_field_size=FLAGS.receptive_field_size,
            prelogit_dropout_keep_prob=0.8,
            depth_multiplier=0.1,
            min_depth=16,
            inception_fcn_stride=0,
        )
        conv_params = network_params.ConvScopeParams(
            dropout=False,
            dropout_keep_prob=0.8,
            batch_norm=True,
            batch_norm_decay=0.99,
            l2_weight_decay=4e-05,
        )
        network_fn = inception_v3_fcn.get_inception_v3_fcn_network_fn(
            inception_params,
            conv_params,
            num_classes=dataset.num_classes,
            is_training=False,
        )

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=2 * FLAGS.batch_size,
            common_queue_min=FLAGS.batch_size)
        [image, label] = provider.get(['image', 'label'])

        #####################################
        # Select the preprocessing function #
        #####################################
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            'inception_v3', is_training=False)
        eval_image_size = FLAGS.receptive_field_size
        image = image_preprocessing_fn(image, eval_image_size, eval_image_size)
        images, labels = tf.train.batch([image, label],
                                        batch_size=FLAGS.batch_size,
                                        num_threads=PREPROCESSING_THREADS,
                                        capacity=5 * FLAGS.batch_size)

        ####################
        # Define the model #
        ####################
        logits, _ = network_fn(images)

        variables_to_restore = slim.get_variables_to_restore()

        predictions = tf.argmax(logits, 1)
        labels = tf.squeeze(labels)

        # Define the metrics:
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy':
            slim.metrics.streaming_accuracy(predictions, labels),
            'Recall_2':
            slim.metrics.streaming_recall_at_k(logits, labels, 2),
        })

        # Print the summaries to screen.
        for name, value in names_to_values.items():
            summary_name = 'eval/%s' % name
            op = tf.summary.scalar(summary_name, value, collections=[])
            op = tf.Print(op, [value], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

        # This ensures that we make a single pass over all of the data.
        num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

        if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
        else:
            checkpoint_path = FLAGS.checkpoint_path

        tf.logging.info('Evaluating %s', checkpoint_path)

        slim.evaluation.evaluate_once(
            master='',
            checkpoint_path=checkpoint_path,
            logdir=FLAGS.eval_dir,
            num_evals=num_batches,
            eval_op=list(names_to_updates.values()),
            session_config=tf.ConfigProto(allow_soft_placement=True),
            variables_to_restore=variables_to_restore)
Example #12
    def initialize_backbone_from_pretrained_weights(self, path_to_pretrained_weights):

        variables_to_restore = slim.get_variables_to_restore(exclude=['global_step'])
        valid_prefix = 'backbone/'
        tf.compat.v1.train.init_from_checkpoint(
            path_to_pretrained_weights, {
                v.name[len(valid_prefix):].split(':')[0]: v
                for v in variables_to_restore
                if v.name.startswith(valid_prefix)
            })
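The dict comprehension maps each 'backbone/...' graph variable to its checkpoint name by stripping the prefix and the ':0' tensor suffix; for one illustrative variable:

name = 'backbone/conv1/weights:0'        # illustrative graph variable name
ckpt_name = name[len('backbone/'):].split(':')[0]
assert ckpt_name == 'conv1/weights'      # key used in init_from_checkpoint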
Example #13
def predict(model_root, datasets_dir, model_name, test_image_name):
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        test_image = os.path.join(datasets_dir, test_image_name)

        # dataset = convert_data.get_datasets('test',dataset_dir=datasets_dir)

        network_fn = net_select.get_network_fn(model_name,
                                               num_classes=20,
                                               is_training=False)
        batch_size = 1
        eval_image_size = network_fn.default_image_size

        # images, images_raw, labels = load_batch(datasets_dir,
        #                                         height=eval_image_size,
        #                                         width=eval_image_size)

        image_preprocessing_fn = preprocessing_select.get_preprocessing(
            model_name, is_training=False)

        image_data = tf.io.read_file(test_image)
        image_data = tf.image.decode_jpeg(image_data, channels=3)
        image_data = image_preprocessing_fn(image_data, eval_image_size,
                                            eval_image_size)
        image_data = tf.expand_dims(image_data, 0)

        logits_1, end_points_1 = network_fn(image_data)
        attention_maps = tf.reduce_mean(end_points_1['attention_maps'],
                                        axis=-1,
                                        keepdims=True)
        attention_maps = tf.image.resize(attention_maps,
                                         [eval_image_size, eval_image_size],
                                         method=tf.image.ResizeMethod.BILINEAR)
        bboxes = tf_v1.py_func(mask2bbox, [attention_maps], [tf.float32])
        bboxes = tf.reshape(bboxes, [batch_size, 4])
        # print(bboxes)
        box_ind = tf.range(batch_size, dtype=tf.int32)

        images = tf.image.crop_and_resize(
            image_data,
            bboxes,
            box_ind,
            crop_size=[eval_image_size, eval_image_size])
        logits_2, end_points_2 = network_fn(images, reuse=True)

        logits = tf.math.log(
            tf.nn.softmax(logits_1) * 0.5 + tf.nn.softmax(logits_2) * 0.5)

        checkpoint_path = os.path.join(model_root, model_name)

        if tf.io.gfile.isdir(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint_path, slim.get_variables_to_restore())

        # with tf_v1.Session() as sess:
        #     with slim.queues.QueueRunners(sess):
        #         sess.run(tf_v1.initialize_local_variables())
        #         init_fn(sess)
        #         np_probabilities, np_images_raw, np_labels = sess.run([logits, images_raw, labels])
        #
        #         for i in range(batch_size):
        #             image = np_images_raw[i, :, :, :]
        #             true_label = np_labels[i]
        #             predicted_label = np.argmax(np_probabilities[i, :])
        #             print('true is {}, predict is {}'.format(true_label, predicted_label))

        with tf_v1.Session() as sess:
            with slim.queues.QueueRunners(sess):
                sess.run(tf_v1.initialize_local_variables())
                init_fn(sess)
                np_images, np_probabilities = sess.run([image_data, logits])
                predicted_label = np.argmax(np_probabilities[0, :])
                print(predicted_label)
Example #14
def main(_):
  if not FLAGS.dataset_dir:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

    ####################
    # Select the model #
    ####################
    n_hash = FLAGS.number_hashing_functions
    L_vec = FLAGS.neuron_vector_length
    quant_params = []
    for i in range(len(n_hash)):
        quant_params.append([int(n_hash[i]), int(L_vec[i])])

    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        quant_params=quant_params, is_training=False)
#     network_fn = nets_factory.get_network_fn(
#         FLAGS.model_name,
#         num_classes=(dataset.num_classes - FLAGS.labels_offset),
#         is_training=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=False,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    label -= FLAGS.labels_offset

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size

    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

    images, labels = tf.compat.v1.train.batch(
        [image, label],
        batch_size=FLAGS.batch_size,
        num_threads=FLAGS.num_preprocessing_threads,
        capacity=5 * FLAGS.batch_size)

    ####################
    # Define the model #
    ####################
    logits, _ = network_fn(images)

    if FLAGS.moving_average_decay:
      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, tf_global_step)
      variables_to_restore = variable_averages.variables_to_restore(
          slim.get_model_variables())
      variables_to_restore[tf_global_step.op.name] = tf_global_step
    else:
      variables_to_restore = slim.get_variables_to_restore()

    predictions = tf.argmax(input=logits, axis=1)
    labels = tf.squeeze(labels)

    # Define the metrics:
    names_to_values, names_to_updates = aggregate_metric_map({
        'Accuracy': tf.compat.v1.metrics.accuracy(labels, predictions),
        'Recall_5': slim.metrics.streaming_recall_at_k(logits, labels, 5),
    })

    # Print the summaries to screen.
    for name, value in names_to_values.items():
      summary_name = 'eval/%s' % name
      op = tf.compat.v1.summary.scalar(summary_name, value, collections=[])
      op = tf.compat.v1.Print(op, [value], summary_name)
      tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.SUMMARIES, op)

    # TODO(sguada) use num_epochs=1
    if FLAGS.max_num_batches:
      num_batches = FLAGS.max_num_batches
    else:
      # This ensures that we make a single pass over all of the data.
      num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))

    if tf.io.gfile.isdir(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.compat.v1.logging.info('Evaluating %s' % checkpoint_path)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    # config.log_device_placement = True

    slim.evaluation.evaluate_once(
        master=FLAGS.master,
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        session_config=config,
        variables_to_restore=variables_to_restore)
Example #15
def train(
    config_yaml,
    displayiters,
    saveiters,
    maxiters,
    max_to_keep=5,
    keepdeconvweights=True,
    allow_growth=False,
):
    start_path = os.getcwd()
    os.chdir(
        str(Path(config_yaml).parents[0])
    )  # switch to folder of config_yaml (for logging)

    setup_logging()

    cfg = load_config(config_yaml)
    if cfg["optimizer"] != "adam":
        print(
            "Setting batchsize to 1! Larger batchsize not supported for this loader:",
            cfg["dataset_type"],
        )
        cfg["batch_size"] = 1

    if (
        cfg["partaffinityfield_predict"] and "multi-animal" in cfg["dataset_type"]
    ):  # the PAF code currently just hijacks the pairwise net stuff (for the batch feeding via Batch.pairwise_targets: 5)
        print("Activating limb prediction...")
        cfg["pairwise_predict"] = True

    dataset = PoseDatasetFactory.create(cfg)
    batch_spec = get_batch_spec(cfg)
    batch, enqueue_op, placeholders = setup_preloading(batch_spec)

    losses = PoseNetFactory.create(cfg).train(batch)
    total_loss = losses["total_loss"]

    for k, t in losses.items():
        tf.compat.v1.summary.scalar(k, t)
    merged_summaries = tf.compat.v1.summary.merge_all()
    net_type = cfg["net_type"]

    stem = Path(cfg["init_weights"]).stem
    if "snapshot" in stem and keepdeconvweights:
        print("Loading already trained DLC with backbone:", net_type)
        variables_to_restore = slim.get_variables_to_restore()
        start_iter = int(stem.split("-")[1])
    else:
        print("Loading ImageNet-pretrained", net_type)
        # loading backbone from ResNet, MobileNet etc.
        if "resnet" in net_type:
            variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
        elif "mobilenet" in net_type:
            variables_to_restore = slim.get_variables_to_restore(
                include=["MobilenetV2"]
            )
        elif "efficientnet" in net_type:
            variables_to_restore = slim.get_variables_to_restore(
                include=["efficientnet"]
            )
            variables_to_restore = {
                var.op.name.replace("efficientnet/", "")
                + "/ExponentialMovingAverage": var
                for var in variables_to_restore
            }
        else:
            raise ValueError(
                "Unsupported net_type: {}. Wait for DLC 2.3.".format(net_type))
        start_iter = 0

    restorer = tf.compat.v1.train.Saver(variables_to_restore)
    saver = tf.compat.v1.train.Saver(
        max_to_keep=max_to_keep
    )  # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835

    if allow_growth:
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=config)
    else:
        sess = tf.compat.v1.Session()

    coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
    train_writer = tf.compat.v1.summary.FileWriter(cfg["log_dir"], sess.graph)
    learning_rate, train_op, tstep = get_optimizer(total_loss, cfg)

    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(tf.compat.v1.local_variables_initializer())

    restorer.restore(sess, cfg["init_weights"])
    if maxiters is None:
        max_iter = int(cfg["multi_step"][-1][1])
    else:
        max_iter = min(int(cfg["multi_step"][-1][1]), int(maxiters))
        # display_iters = max(1,int(displayiters))
        print("Max_iters overwritten as", max_iter)

    if displayiters is None:
        display_iters = max(1, int(cfg["display_iters"]))
    else:
        display_iters = max(1, int(displayiters))
        print("Display_iters overwritten as", display_iters)

    if saveiters is None:
        save_iters = max(1, int(cfg["save_iters"]))

    else:
        save_iters = max(1, int(saveiters))
        print("Save_iters overwritten as", save_iters)

    cumloss, partloss, locrefloss, pwloss = 0.0, 0.0, 0.0, 0.0
    lr_gen = LearningRate(cfg)
    stats_path = Path(config_yaml).with_name("learning_stats.csv")
    lrf = open(str(stats_path), "w")

    print("Training parameters:")
    print(cfg)
    print("Starting multi-animal training....")
    max_iter += start_iter  # max_iter is relative to start_iter
    for it in range(start_iter, max_iter + 1):
        if "efficientnet" in net_type:
            lr_dict = {tstep: it - start_iter}
            current_lr = sess.run(learning_rate, feed_dict=lr_dict)
        else:
            current_lr = lr_gen.get_lr(it - start_iter)
            lr_dict = {learning_rate: current_lr}

        # [_, loss_val, summary] = sess.run([train_op, total_loss, merged_summaries],feed_dict={learning_rate: current_lr})
        [_, alllosses, loss_val, summary] = sess.run(
            [train_op, losses, total_loss, merged_summaries], feed_dict=lr_dict
        )

        partloss += alllosses["part_loss"]  # scoremap loss
        if cfg["location_refinement"]:
            locrefloss += alllosses["locref_loss"]
        if cfg["pairwise_predict"]:  # paf loss
            pwloss += alllosses["pairwise_loss"]

        cumloss += loss_val
        train_writer.add_summary(summary, it)

        if it % display_iters == 0 and it > start_iter:
            logging.info(
                "iteration: {} loss: {} scmap loss: {} locref loss: {} limb loss: {} lr: {}".format(
                    it,
                    "{0:.4f}".format(cumloss / display_iters),
                    "{0:.4f}".format(partloss / display_iters),
                    "{0:.4f}".format(locrefloss / display_iters),
                    "{0:.4f}".format(pwloss / display_iters),
                    current_lr,
                )
            )

            lrf.write(
                "iteration: {}, loss: {}, scmap loss: {}, locref loss: {}, limb loss: {}, lr: {}\n".format(
                    it,
                    "{0:.4f}".format(cumloss / display_iters),
                    "{0:.4f}".format(partloss / display_iters),
                    "{0:.4f}".format(locrefloss / display_iters),
                    "{0:.4f}".format(pwloss / display_iters),
                    current_lr,
                )
            )

            cumloss, partloss, locrefloss, pwloss = 0.0, 0.0, 0.0, 0.0
            lrf.flush()

        # Save snapshot
        if (it % save_iters == 0 and it != start_iter) or it == max_iter:
            model_name = cfg["snapshot_prefix"]
            saver.save(sess, model_name, global_step=it)

    lrf.close()

    sess.close()
    coord.request_stop()
    coord.join([thread])

    # return to original path.
    os.chdir(str(start_path))
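Note that tf.compat.v1.train.Saver accepts either a list of variables or a dict mapping checkpoint names to variables; the EfficientNet branch above builds such a dict so that each variable is restored from its '<name>/ExponentialMovingAverage' checkpoint entry. A hedged sketch (names illustrative):

v = tf.Variable(tf.zeros([3, 3, 3, 32]), name='stem/conv/kernel')
restorer = tf.compat.v1.train.Saver(
    {'stem/conv/kernel/ExponentialMovingAverage': v})  # illustrative key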
Example #16
def train(
    config_yaml,
    displayiters,
    saveiters,
    maxiters,
    max_to_keep=5,
    keepdeconvweights=True,
    allow_growth=True,
):
    start_path = os.getcwd()
    os.chdir(str(Path(config_yaml).parents[0])
             )  # switch to folder of config_yaml (for logging)
    setup_logging()

    cfg = load_config(config_yaml)
    net_type = cfg["net_type"]
    if cfg["dataset_type"] in ("scalecrop", "tensorpack", "deterministic"):
        print(
            "Switching batchsize to 1, as tensorpack/scalecrop/deterministic loaders do not support batches >1. Use imgaug/default loader."
        )
        cfg["batch_size"] = 1  # in case this was edited for analysis.-

    dataset = PoseDatasetFactory.create(cfg)
    batch_spec = get_batch_spec(cfg)
    batch, enqueue_op, placeholders = setup_preloading(batch_spec)

    losses = PoseNetFactory.create(cfg).train(batch)
    total_loss = losses["total_loss"]

    for k, t in losses.items():
        tf.compat.v1.summary.scalar(k, t)
    merged_summaries = tf.compat.v1.summary.merge_all()

    stem = Path(cfg["init_weights"]).stem
    if "snapshot" in stem and keepdeconvweights:
        print("Loading already trained DLC with backbone:", net_type)
        variables_to_restore = slim.get_variables_to_restore()
        start_iter = int(stem.split("-")[1])
    else:
        print("Loading ImageNet-pretrained", net_type)
        # loading backbone from ResNet, MobileNet etc.
        if "resnet" in net_type:
            variables_to_restore = slim.get_variables_to_restore(
                include=["resnet_v1"])
        elif "mobilenet" in net_type:
            variables_to_restore = slim.get_variables_to_restore(
                include=["MobilenetV2"])
        elif "efficientnet" in net_type:
            variables_to_restore = slim.get_variables_to_restore(
                include=["efficientnet"])
            variables_to_restore = {
                var.op.name.replace("efficientnet/", "") +
                "/ExponentialMovingAverage": var
                for var in variables_to_restore
            }
        else:
            raise ValueError(
                "Unsupported net_type: {}. Wait for DLC 2.3.".format(net_type))
        start_iter = 0

    restorer = tf.compat.v1.train.Saver(variables_to_restore)
    saver = tf.compat.v1.train.Saver(
        max_to_keep=max_to_keep
    )  # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835

    if allow_growth:
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=config)
    else:
        sess = tf.compat.v1.Session()

    coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
    train_writer = tf.compat.v1.summary.FileWriter(cfg["log_dir"], sess.graph)

    if cfg.get("freezeencoder", False):
        if "efficientnet" in net_type:
            print("Freezing ONLY supported MobileNet/ResNet currently!!")
            learning_rate, train_op, tstep = get_optimizer(total_loss, cfg)

        print("Freezing encoder...")
        learning_rate, _, train_op = get_optimizer_with_freeze(total_loss, cfg)
    else:
        learning_rate, train_op, tstep = get_optimizer(total_loss, cfg)

    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(tf.compat.v1.local_variables_initializer())

    # Restore variables from disk.
    restorer.restore(sess, cfg["init_weights"])
    if maxiters is None:
        max_iter = int(cfg["multi_step"][-1][1])
    else:
        max_iter = min(int(cfg["multi_step"][-1][1]), int(maxiters))
        # display_iters = max(1,int(displayiters))
        print("Max_iters overwritten as", max_iter)

    if displayiters is None:
        display_iters = max(1, int(cfg["display_iters"]))
    else:
        display_iters = max(1, int(displayiters))
        print("Display_iters overwritten as", display_iters)

    if saveiters is None:
        save_iters = max(1, int(cfg["save_iters"]))

    else:
        save_iters = max(1, int(saveiters))
        print("Save_iters overwritten as", save_iters)

    cum_loss = 0.0
    lr_gen = LearningRate(cfg)

    stats_path = Path(config_yaml).with_name("learning_stats.csv")
    lrf = open(str(stats_path), "w")

    print("Training parameter:")
    print(cfg)
    print("Starting training....")
    max_iter += start_iter  # max_iter is relative to start_iter
    for it in range(start_iter, max_iter + 1):
        if "efficientnet" in net_type:
            lr_dict = {tstep: it - start_iter}
            current_lr = sess.run(learning_rate, feed_dict=lr_dict)
        else:
            current_lr = lr_gen.get_lr(it - start_iter)
            lr_dict = {learning_rate: current_lr}

        [_, loss_val,
         summary] = sess.run([train_op, total_loss, merged_summaries],
                             feed_dict=lr_dict)
        cum_loss += loss_val
        train_writer.add_summary(summary, it)

        if it % display_iters == 0 and it > start_iter:
            average_loss = cum_loss / display_iters
            cum_loss = 0.0
            logging.info("iteration: {} loss: {} lr: {}".format(
                it, "{0:.4f}".format(average_loss), current_lr))
            lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
            lrf.flush()

        # Save snapshot
        if (it % save_iters == 0 and it != start_iter) or it == max_iter:
            model_name = cfg["snapshot_prefix"]
            saver.save(sess, model_name, global_step=it)

    lrf.close()
    sess.close()
    coord.request_stop()
    coord.join([thread])
    # return to original path.
    os.chdir(str(start_path))
Example #17
def train(config_yaml,
          displayiters,
          saveiters,
          maxiters,
          max_to_keep=5,
          keepdeconvweights=True,
          allow_growth=False):
    start_path = os.getcwd()
    os.chdir(str(Path(config_yaml).parents[0])
             )  #switch to folder of config_yaml (for logging)
    setup_logging()

    cfg = load_config(config_yaml)
    if cfg.dataset_type in ('default', 'tensorpack', 'deterministic'):
        print(
            "Switching batchsize to 1, as default/tensorpack/deterministic loaders do not support batches >1. Use imgaug loader."
        )
        cfg['batch_size'] = 1  # in case this was edited for analysis

    dataset = create_dataset(cfg)
    batch_spec = get_batch_spec(cfg)
    batch, enqueue_op, placeholders = setup_preloading(batch_spec)
    losses = pose_net(cfg).train(batch)
    total_loss = losses['total_loss']

    for k, t in losses.items():
        TF.summary.scalar(k, t)
    merged_summaries = TF.summary.merge_all()

    if 'snapshot' in Path(cfg.init_weights).stem and keepdeconvweights:
        print("Loading already trained DLC with backbone:", cfg.net_type)
        variables_to_restore = slim.get_variables_to_restore()
    else:
        print("Loading ImageNet-pretrained", cfg.net_type)
        #loading backbone from ResNet, MobileNet etc.
        if 'resnet' in cfg.net_type:
            variables_to_restore = slim.get_variables_to_restore(
                include=["resnet_v1"])
        elif 'mobilenet' in cfg.net_type:
            variables_to_restore = slim.get_variables_to_restore(
                include=["MobilenetV2"])
        else:
            raise ValueError(
                "Unsupported net_type: {}. Wait for DLC 2.3.".format(cfg.net_type))

    restorer = TF.train.Saver(variables_to_restore)
    saver = TF.train.Saver(
        max_to_keep=max_to_keep
    )  # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835

    if allow_growth:
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = TF.Session(config=config)
    else:
        sess = TF.Session()

    coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
    train_writer = TF.summary.FileWriter(cfg.log_dir, sess.graph)
    learning_rate, train_op = get_optimizer(total_loss, cfg)

    sess.run(TF.global_variables_initializer())
    sess.run(TF.local_variables_initializer())

    # Restore variables from disk.
    restorer.restore(sess, cfg.init_weights)
    if maxiters is None:
        max_iter = int(cfg.multi_step[-1][1])
    else:
        max_iter = min(int(cfg.multi_step[-1][1]), int(maxiters))
        #display_iters = max(1,int(displayiters))
        print("Max_iters overwritten as", max_iter)

    if displayiters is None:
        display_iters = max(1, int(cfg.display_iters))
    else:
        display_iters = max(1, int(displayiters))
        print("Display_iters overwritten as", display_iters)

    if saveiters is None:
        save_iters = max(1, int(cfg.save_iters))

    else:
        save_iters = max(1, int(saveiters))
        print("Save_iters overwritten as", save_iters)

    cum_loss = 0.0
    lr_gen = LearningRate(cfg)

    stats_path = Path(config_yaml).with_name('learning_stats.csv')
    lrf = open(str(stats_path), 'w')

    print("Training parameter:")
    print(cfg)
    print("Starting training....")
    for it in range(max_iter + 1):
        current_lr = lr_gen.get_lr(it)
        [_, loss_val,
         summary] = sess.run([train_op, total_loss, merged_summaries],
                             feed_dict={learning_rate: current_lr})
        cum_loss += loss_val
        train_writer.add_summary(summary, it)

        if it % display_iters == 0 and it > 0:
            average_loss = cum_loss / display_iters
            cum_loss = 0.0
            logging.info("iteration: {} loss: {} lr: {}".format(
                it, "{0:.4f}".format(average_loss), current_lr))
            lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
            lrf.flush()

        # Save snapshot
        if (it % save_iters == 0 and it != 0) or it == max_iter:
            model_name = cfg.snapshot_prefix
            saver.save(sess, model_name, global_step=it)

    lrf.close()
    sess.close()
    coord.request_stop()
    coord.join([thread])
    #return to original path.
    os.chdir(str(start_path))
Example #18
gpu_config = tf.ConfigProto()
gpu_config.gpu_options.allow_growth = True
gpu_config.gpu_options.per_process_gpu_memory_fraction = 0.8
with tf.Session(config=gpu_config) as sess:
    summary_writer = tf.summary.FileWriter('../logs', sess.graph)
    sess.run(tf.global_variables_initializer())
    train_data.init(sess)

    global_step = 0
    new_checkpoint = None
    if cfg_train_continue:
        new_checkpoint = tf.train.latest_checkpoint(
            '../checkpoints/checkpoints')
    if new_checkpoint:
        exclusions = ['global_step']
        net_except_logits = slim.get_variables_to_restore(exclude=exclusions)
        init_fn = slim.assign_from_checkpoint_fn(new_checkpoint,
                                                 net_except_logits,
                                                 ignore_missing_vars=True)
        init_fn(sess)
        print('load params from {}'.format(new_checkpoint))

    try:
        cur_epoch = 1
        while True:
            t0 = time.time()
            batch_x_img, batch_center_map, batch_scale_map, batch_offset_map, batch_landmark_map = train_data.batch(
                sess)
            t1 = time.time()
            debug_info_, train_loss_, loss_class_, loss_scale_, loss_offset_, loss_landmark_, loss_l2_, summary_, _ = sess.run(
                [
                    # The fetch list is truncated in the source; the tensor names
                    # below are assumed to mirror the values unpacked on the left.
                    debug_info, train_loss, loss_class, loss_scale,
                    loss_offset, loss_landmark, loss_l2, summary_op, train_op
                ])
            summary_writer.add_summary(summary_, global_step)
            global_step += 1
    except tf.errors.OutOfRangeError:
        # Assumed termination: the input pipeline signals the end of data.
        print('training finished')
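
A minimal sketch isolating the partial-restore idiom from this example, assuming the same checkpoint directory; `ignore_missing_vars=True` is what lets the restore tolerate a checkpoint that lacks some graph variables:

import tensorflow.compat.v1 as tf
import tf_slim as slim

tf.disable_eager_execution()

global_step = tf.train.get_or_create_global_step()
weights = tf.get_variable('weights', shape=[4, 4])

# Everything except the step counter is a candidate for restoring.
to_restore = slim.get_variables_to_restore(exclude=['global_step'])

ckpt = tf.train.latest_checkpoint('../checkpoints/checkpoints')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if ckpt:
        # Skips variables absent from the checkpoint instead of raising.
        init_fn = slim.assign_from_checkpoint_fn(ckpt, to_restore,
                                                 ignore_missing_vars=True)
        init_fn(sess)
        print('load params from {}'.format(ckpt))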
Exemplo n.º 19
def get_init_fn(train_dir=None,
                model_checkpoint=None,
                exclude_list=None,
                include_list=None,
                reset_global_step_if_necessary=True,
                ignore_missing_vars=True):
    """Gets model initializer function.

  The initialization logic is as follows:
    1. If a checkpoint is found in `train_dir`, initialize from it.
    2. Otherwise, if `model_checkpoint` is valid, initialize from it, and reset
       global step if necessary.
    3. Otherwise, do not initialize from any checkpoint.

  Args:
    train_dir: A string as the path to an existing training directory to resume.
      Use None to skip.
    model_checkpoint: A string as the path to an existing model checkpoint to
      initialize from. Use None to skip.
    exclude_list: A list of strings for the names of variables not to load.
    include_list: A list of strings for the names of variables to load. Use
      None to load all variables.
    reset_global_step_if_necessary: A boolean for whether to reset global step.
      Only used in the case of initializing from an existing checkpoint
      `model_checkpoint` rather than resuming training from `train_dir`.
    ignore_missing_vars: A boolean for whether to ignore missing variables. If
      False, errors will be raised if there is a missing variable.

  Returns:
    A model initializer function if an existing checkpoint is found, or None
      otherwise.
  """
    # Make sure the exclude list is a list.
    if not exclude_list:
        exclude_list = []

    # Resolve which checkpoint (if any) to initialize from.
    if train_dir:
        train_checkpoint = tf.train.latest_checkpoint(train_dir)
    else:
        train_checkpoint = None

    if train_checkpoint:
        model_checkpoint = train_checkpoint
        logging.info('Resume latest training checkpoint in: %s.', train_dir)
    elif model_checkpoint:
        logging.info('Use initial checkpoint: %s.', model_checkpoint)
        if reset_global_step_if_necessary:
            exclude_list.append('global_step')
            logging.info('Reset global step.')

    if not model_checkpoint:
        logging.info('Do not initialize from a checkpoint.')
        return None

    variables_to_restore = tf_slim.get_variables_to_restore(
        include=include_list, exclude=exclude_list)

    return tf_slim.assign_from_checkpoint_fn(
        model_checkpoint,
        variables_to_restore,
        ignore_missing_vars=ignore_missing_vars)
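
A usage sketch for `get_init_fn`, with placeholder directories: build the graph first so the variables exist, then apply the returned initializer inside a session (it is `None` when no checkpoint is found, so guard the call):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Hypothetical model with a head we do not want to load from the checkpoint.
global_step = tf.train.get_or_create_global_step()
with tf.variable_scope('backbone'):
    w = tf.get_variable('w', shape=[8, 8])
with tf.variable_scope('logits'):
    b = tf.get_variable('b', shape=[10])

init_fn = get_init_fn(
    train_dir='/tmp/train_logs',                                   # assumed dirs
    model_checkpoint=tf.train.latest_checkpoint('/tmp/pretrained'),
    exclude_list=['logits'])                                       # keep the new head random

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if init_fn is not None:   # None when neither location has a checkpoint
        init_fn(sess)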
Exemplo n.º 20
def main(model_root, datasets_dir, model_name, test_image_name):
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        test_image = os.path.join(datasets_dir, test_image_name)

        dataset = convert_data.get_datasets('train', dataset_dir=datasets_dir)

        network_fn = net_select.get_network_fn(model_name,
                                               num_classes=dataset.num_classes,
                                               is_training=False)

        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=20 * batch_size,
            common_queue_min=10 * batch_size)
        [image, label] = provider.get(['image', 'label'])

        image_preprocessing_fn = preprocessing_select.get_preprocessing(
            model_name, is_training=False)

        eval_image_size = network_fn.default_image_size
        image = image_preprocessing_fn(image, eval_image_size, eval_image_size)

        images, labels = tf_v1.train.batch(
            [image, label],
            batch_size=batch_size,
            num_threads=num_preprocessing_threads,
            capacity=5 * batch_size)

        checkpoint_path = os.path.join(model_root, model_name)
        if tf.io.gfile.isdir(checkpoint_path):
            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

        logits_1, end_points_1 = network_fn(images)
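        # Attention-guided second pass: average the attention maps over
        # channels, upsample to the input size, turn each map into a crop box
        # via mask2bbox, and re-classify the cropped region; the two softmax
        # outputs are averaged below.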
        attention_maps = tf.reduce_mean(end_points_1['attention_maps'],
                                        axis=-1,
                                        keepdims=True)
        attention_maps = tf.image.resize(attention_maps,
                                         [eval_image_size, eval_image_size],
                                         method=tf.image.ResizeMethod.BILINEAR)
        bboxes = tf_v1.py_func(mask2bbox, [attention_maps], [tf.float32])
        bboxes = tf.reshape(bboxes, [batch_size, 4])
        box_ind = tf.range(batch_size, dtype=tf.int32)

        images = tf.image.crop_and_resize(
            images,
            bboxes,
            box_ind,
            crop_size=[eval_image_size, eval_image_size])
        logits_2, end_points_2 = network_fn(images, reuse=True)

        logits = tf_v1.log(
            tf.nn.softmax(logits_1) * 0.5 + tf.nn.softmax(logits_2) * 0.5)
        """
        tf_v1.enable_eager_execution()

        # Test a single image.
        image_data = tf.io.read_file(test_image)
        image_data = tf.image.decode_jpeg(image_data,channels= 3)

        # plt.figure(1)
        # plt.imshow(image_data)

        image_data = image_preprocessing_fn(image_data, eval_image_size, eval_image_size)
        image_data = tf.expand_dims(image_data, 0)

        logits_3,end_points_3 = network_fn(image_data,reuse =True)
        attention_map = tf.reduce_mean(end_points_3['attention_maps'], axis=-1, keepdims=True)
        attention_map = tf.image.resize(attention_map, [eval_image_size, eval_image_size],
                                         method=tf.image.ResizeMethod.BILINEAR)
        bboxes = tf_v1.py_func(mask2bbox, [attention_map], [tf.float32])
        bboxes = tf.reshape(bboxes, [batch_size, 4])
        box_ind = tf.range(batch_size, dtype=tf.int32)

        image_data = tf.image.crop_and_resize(images, bboxes, box_ind, crop_size=[eval_image_size, eval_image_size])

        logits_4, end_points_4 = network_fn(image_data, reuse=True)
        logits_0 = tf_v1.log(tf.nn.softmax(logits_3) * 0.5 + tf.nn.softmax(logits_4) * 0.5)
        probabilities = logits_0[0,0:]

        print(probabilities)
        # sorted_inds = [i[0] for i in sorted(enumerate(-probabilities),key= lambda x:x[1])]
        sorted_inds = (np.argsort(probabilities.numpy())[::-1])

        train_info = sio.loadmat(os.path.join(datasets_dir, 'devkit', 'cars_train_annos.mat'))['annotations'][0]
        names = train_info['class']
        print(names)
        for i in range(5):
            index = sorted_inds[i]
            # Print the top-5 predicted classes and their probabilities.
            print('Probability %0.2f => [%s]' % (probabilities[index],names[index+1][0][0]))
        """
        if moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        logits_to_updates = add_eval_summary(logits, labels, scope='/bilinear')
        logits_1_to_updates = add_eval_summary(logits_1,
                                               labels,
                                               scope='/logits_1')
        logits_2_to_updates = add_eval_summary(logits_2,
                                               labels,
                                               scope='/logits_2')

        if max_num_batches:
            num_batches = max_num_batches
        else:
            # This ensures that we make a single pass over all of the data.
            num_batches = math.ceil(dataset.num_samples / float(batch_size))

        config = tf_v1.ConfigProto(allow_soft_placement=True,
                                   log_device_placement=False)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 1.0

        tf_v1.disable_eager_execution()

        while True:
            if tf.io.gfile.isdir(checkpoint_path):
                checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

            print('Evaluating %s' % checkpoint_path)
            # Collect the metric update ops for all three prediction heads.
            eval_op = []
            eval_op.extend(logits_to_updates.values())
            eval_op.extend(logits_1_to_updates.values())
            eval_op.extend(logits_2_to_updates.values())

            test_dir = checkpoint_path
            slim.evaluation.evaluate_once(
                master='',
                checkpoint_path=checkpoint_path,
                logdir=test_dir,
                num_evals=num_batches,
                eval_op=eval_op,
                variables_to_restore=variables_to_restore,
                final_op=None,
                session_config=config)
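
A minimal sketch isolating the moving-average restore used above, with a hypothetical variable and decay value: `variables_to_restore()` returns a dict keyed by the shadow-variable names stored in the checkpoint, mapped to the live variables they should populate; variables without an average (such as the global step) are mapped under their own names automatically:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.variable_scope('net'):
    w = tf.get_variable('w', shape=[2, 2])
global_step = tf.train.get_or_create_global_step()

ema = tf.train.ExponentialMovingAverage(0.999, global_step)  # assumed decay
maintain_averages_op = ema.apply([w])  # a training loop would run this op

# {'net/w/ExponentialMovingAverage': w, 'global_step': global_step, ...}
variables_to_restore = ema.variables_to_restore(tf.trainable_variables())
saver = tf.train.Saver(variables_to_restore)
# saver.restore(sess, '/tmp/ckpt/model.ckpt-5000')  # assumed checkpoint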