def build_pathnet_eval_graph(
        task_names, batch_size, num_classes_for_tasks, router_fn):
    """Assemble the PathNet graph used for evaluation.

    The network is the Omniglot Keras-layer stack wrapped as PathNet layers,
    followed by a layer of task-specific linear heads and a final layer of
    per-task output (model head) components.

    Args:
      task_names: (list of strings) task names; only the number of entries is
        used here, to size the per-task layers.
      batch_size: (int) batch size stored in the PathNet hyperparameters.
      num_classes_for_tasks: (list of ints) class count for every task.
      router_fn: callable taking a single `num_components` argument and
        returning a router (see routers in `pathnet/pathnet_lib.py`) for a
        layer with that many components.

    Returns:
      A `(p_inputs, p_task_id, out_logits)` tuple: the input image
      placeholder, the scalar task id placeholder, and the classification
      logits produced by the network.
    """
    task_count = len(task_names)

    # Shared trunk: Keras layers for the Omniglot experiment turned into
    # PathNet layers, each routed by whatever `router_fn` provides.
    omniglot_keras_layers = models.get_keras_layers_for_omniglot_experiment()
    layers = models.build_model_from_keras_layers(
        _OMNIGLOT_INPUT_SHAPE, task_count, omniglot_keras_layers, router_fn)

    # One linear head per task, sized by that task's number of classes.
    linear_heads_layer = utils.create_layer_with_task_specific_linear_heads(
        num_classes_for_tasks)
    layers.append(linear_heads_layer)

    def _make_model_head():
        # `loss_fn` is not a local; it is looked up outside this function.
        return pn_components.ModelHeadComponent(loss_fn=loss_fn)

    # Output layer: one model-head component per task, combined with
    # `SelectCombiner`; the router factory for this layer yields None.
    output_layer = create_uniform_layer(
        num_components=task_count,
        component_fn=_make_model_head,
        combiner_fn=pn.SelectCombiner,
        router_fn=lambda: None)
    layers.append(output_layer)

    hparams = tf.contrib.training.HParams(batch_size=batch_size)
    network = pn.PathNet(layers, hparams)

    # The graph builder hands back five values; the second and fourth are
    # not part of the eval interface and are discarded.
    inputs_ph, _, task_id_ph, _, logits = utils.build_pathnet_graph(
        network, _OMNIGLOT_INPUT_SHAPE, training=False)

    return inputs_ph, task_id_ph, logits
def component_fn():
    """Create a fresh model-head component.

    Takes no arguments. `loss_fn` and `auxiliary_loss_fn` are not defined in
    this function; they are resolved from the surrounding scope at call time.

    Returns:
      A new `pn_components.ModelHeadComponent` constructed with `loss_fn` and
      `auxiliary_loss_fn`.
    """
    head_component = pn_components.ModelHeadComponent(
        loss_fn=loss_fn,
        auxiliary_loss_fn=auxiliary_loss_fn,
    )
    return head_component