def build_pathnet_eval_graph(
    task_names, batch_size, num_classes_for_tasks, router_fn):
  """Builds the evaluation-mode PathNet graph for a set of tasks.

  The network is assembled in three stages: the routed layers derived from the
  Omniglot-experiment Keras layers, a layer of task-specific linear heads, and
  a final output layer with one `ModelHeadComponent` per task. The graph is
  built with `training=False`.

  Args:
    task_names: (list of strings) names of tasks. Only the number of names is
      used here.
    batch_size: (int) batch size, stored in the PathNet hyperparameters.
    num_classes_for_tasks: (list of ints) number of classes for each task.
    router_fn: function that, given a single argument `num_components`, returns
      a router (see routers in `pathnet/pathnet_lib.py`) for a layer containing
      `num_components` components.

  Returns:
    A 3-tuple (`p_inputs`, `p_task_id`, `out_logits`), where `p_inputs` is the
    placeholder for the input image, `p_task_id` is the placeholder for the
    scalar task id, and `out_logits` are the classification logits produced by
    the network.
  """
  task_count = len(task_names)

  # NOTE(review): `loss_fn` and `create_uniform_layer` are not defined in this
  # function; they are presumably module-level names -- confirm they are in
  # scope wherever this function lives.
  def make_head_component():
    return pn_components.ModelHeadComponent(loss_fn=loss_fn)

  def no_router():
    return None

  # Stage 1: routed PathNet layers built from the Keras layer definitions.
  omniglot_keras_layers = models.get_keras_layers_for_omniglot_experiment()
  layers = models.build_model_from_keras_layers(
      _OMNIGLOT_INPUT_SHAPE, task_count, omniglot_keras_layers, router_fn)

  # Stage 2: one linear head per task.
  linear_heads_layer = utils.create_layer_with_task_specific_linear_heads(
      num_classes_for_tasks)
  layers.append(linear_heads_layer)

  # Stage 3: output components, one per task, combined with `SelectCombiner`
  # and without a router.
  output_layer = create_uniform_layer(
      num_components=task_count,
      component_fn=make_head_component,
      combiner_fn=pn.SelectCombiner,
      router_fn=no_router)
  layers.append(output_layer)

  hparams = tf.contrib.training.HParams(batch_size=batch_size)
  eval_pathnet = pn.PathNet(layers, hparams)

  # The graph builder yields five values; the second and fourth are not needed
  # for evaluation and are discarded.
  (p_inputs, _unused_second, p_task_id, _unused_fourth,
   out_logits) = utils.build_pathnet_graph(
       eval_pathnet, _OMNIGLOT_INPUT_SHAPE, training=False)

  return p_inputs, p_task_id, out_logits
# Example no. 2
# 0
 def component_fn():
     """Creates a fresh `ModelHeadComponent`.

     `loss_fn` and `auxiliary_loss_fn` are free variables here; they are
     presumably supplied by the enclosing scope, which is not visible in this
     fragment -- verify against the surrounding function.
     """
     head_component = pn_components.ModelHeadComponent(
         loss_fn=loss_fn, auxiliary_loss_fn=auxiliary_loss_fn)
     return head_component