Пример #1
0
def build_model_from_keras_layers(input_data_shape, num_tasks, keras_layers,
                                  router_fn):
    """Wraps groups of Keras layers into a sequence of PathNet layers.

    Args:
      input_data_shape: (sequence of ints) shape of the data fed to the model.
      num_tasks: (int) number of tasks.
      keras_layers: (list of lists of `keras.layer.Layer`) each entry holds
        the keras layers that are wrapped into one routed layer. The keras
        layers for a specific model can be obtained by calling
        `get_keras_layers_for_mnist_experiment`.
      router_fn: function taking a single argument `num_components` and
        returning a router for a layer containing `num_components`
        components. It is passed through `wrap_router_fn` before use.

    Returns:
      A list of PathNet layers. The first element is the input layer built by
      `create_identity_input_layer`; the rest are the layers returned by
      `utils.compute_output_shape_and_create_routed_layer` for every entry of
      `keras_layers`, with the output shape of one entry used as the input
      shape of the next. Every entry routes its output with a
      `SinglePathRouter` except the final one, which uses an
      `IndependentTaskBasedRouter` because the layer after it is expected to
      hold task-specific heads.
    """
    # The model always starts with an identity layer that feeds the input in.
    pathnet_layers = [
        create_identity_input_layer(num_tasks,
                                    input_data_shape,
                                    router_out=pn.SinglePathRouter())
    ]

    wrapped_router_fn = wrap_router_fn(router_fn)

    current_shape = input_data_shape
    for position, components in enumerate(keras_layers):
        has_successor = position < len(keras_layers) - 1
        if has_successor:
            out_router = pn.SinglePathRouter()
        else:
            # Final entry: whatever follows will have task-specific heads.
            out_router = pn.IndependentTaskBasedRouter(num_tasks=num_tasks)

        # Build the routed layer(s) and carry the resulting shape forward.
        routed, current_shape = (
            utils.compute_output_shape_and_create_routed_layer(
                keras_components=components,
                in_shape=current_shape,
                router_fn=wrapped_router_fn,
                router_out=out_router))
        pathnet_layers.extend(routed)

    return pathnet_layers
Пример #2
0
 def wrapped_router_fn(num_components):
     """Returns the router for a layer with `num_components` components.

     A layer with exactly one component gets a `pn.SinglePathRouter`; any
     other count is delegated to the `router_fn` that is in scope where this
     function is defined.
     """
     single = num_components == 1
     return pn.SinglePathRouter() if single else router_fn(num_components)
Пример #3
0
def create_wrapped_routed_layer(
    components,
    router,
    router_out,
    combiner,
    in_shape,
    out_shape,
    sparse=True,
    record_summaries_from_components=False):
  """Builds a components layer fronted by a single input router.

  The components are sandwiched between two identity layers. This gives a view
  of the routing process in which one input is routed through a subset of the
  components and the results are then aggregated, as opposed to the default
  behavior of the pathnet library, where routing happens independently at the
  end of each component.

  Args:
    components: (list) components to route through.
    router: router that selects among the components. It should have a
      `__call__` method taking the same arguments as the routers defined in
      pathnet_lib.
    router_out: router applied to the final, aggregated output. It should
      have a `__call__` method taking the same arguments as the routers
      defined in pathnet_lib.
    combiner: combiner that aggregates the outputs of the components.
    in_shape: (sequence of ints) input shape.
    out_shape: (sequence of ints) output shape.
    sparse: (bool) value of the `sparse` flag for the components layer.
    record_summaries_from_components: (bool) whether summaries coming from
      `components` are recorded.

  Returns:
    A list of three `ComponentsLayer`s: the identity layer carrying `router`,
    the layer holding `components`, and the identity layer carrying
    `router_out`.
  """

  def single_identity_layer(shape, layer_router):
    # One identity component of the given shape, routed by `layer_router`.
    return pn_lib.ComponentsLayer(
        components=[pn_lib.RoutedComponent(
            pn_components.IdentityComponent(out_shape=shape),
            layer_router,
            record_summaries=False)],
        combiner=pn_lib.SumCombiner())

  # Each component forwards its output along a single path; the actual
  # selection is done by `router` in the identity layer placed before them.
  wrapped_components = [
      pn_lib.RoutedComponent(
          component,
          pn_lib.SinglePathRouter(),
          record_summaries=record_summaries_from_components)
      for component in components
  ]

  entry_layer = single_identity_layer(in_shape, router)

  middle_layer = pn_lib.ComponentsLayer(
      components=wrapped_components,
      combiner=combiner,
      sparse=sparse)

  exit_layer = single_identity_layer(out_shape, router_out)

  return [entry_layer, middle_layer, exit_layer]