Example #1
0
def create_identity_input_layer(num_tasks, data_shape, router_out):
  """Builds the PathNet input layer out of per-task identity components.

  Each component is an `IdentityComponent` of shape `data_shape`, so the
  layer forwards the PathNet input unchanged. Components are combined with
  `pn_lib.SelectCombiner`.

  Args:
    num_tasks: (int) number of tasks; one identity component is created per
      task.
    data_shape: (sequence of ints) input data shape, given to every identity
      component.
    router_out: the router used to route into the next layer. It should have
      a `__call__` method with the same arguments as the routers defined
      in pathnet_lib.

  Returns:
    A `ComponentsLayer` with one `IdentityComponent` per task.
  """

  def make_identity():
    return pn_components.IdentityComponent(data_shape)

  def reuse_router():
    # Deliberately hands back the very same `router_out` object on every call,
    # so all components in this layer share one router instance.
    return router_out

  return create_uniform_layer(
      num_components=num_tasks,
      component_fn=make_identity,
      combiner_fn=pn_lib.SelectCombiner,
      router_fn=reuse_router)
Example #2
0
def create_wrapped_routed_layer(
    components,
    router,
    router_out,
    combiner,
    in_shape,
    out_shape,
    sparse=True,
    record_summaries_from_components=False):
  """Create a layer of components with a single router at the input.

  This wraps a layer of components by adding appropriately placed identities.
  It allows for a view of the routing process where a single input is routed
  through a subset of components, and then aggregated. This is in contrast to
  the default behavior of the pathnet library, where routing happens at the end
  of each component independently.

  Concretely, three layers are built:
    1. router layer: a single `IdentityComponent(out_shape=in_shape)` routed
       by `router`, in a layer using `pn_lib.SumCombiner`.
    2. components layer: every entry of `components`, each wrapped in a
       `RoutedComponent` with its own fresh `pn_lib.SinglePathRouter`, in a
       layer using `combiner` and the given `sparse` flag.
    3. aggregation layer: a single `IdentityComponent(out_shape=out_shape)`
       routed by `router_out`, in a layer using `pn_lib.SumCombiner`.

  Args:
    components: (list) components to route through.
    router: the router used to select the components. It should have
      a `__call__` method with the same arguments as the routers defined
      in pathnet_lib.
    router_out: the router used to route the final output. It should have
      a `__call__` method with the same arguments as the routers defined
      in pathnet_lib.
    combiner: the combiner used to aggregate the outputs; it is passed to
      the components layer.
    in_shape: (sequence of ints) input shape, used for the input identity.
    out_shape: (sequence of ints) output shape, used for the output identity.
    sparse: (bool) whether to set the `sparse` flag for the components layer.
    record_summaries_from_components: (bool) whether to record summaries
      coming from `components`. Summaries are never recorded for the two
      wrapping identity components.

  Returns:
    A list `[router_layer, components_layer, aggregation_layer]` of
    `ComponentsLayer`s with the desired behavior.
  """

  def _identity_layer(shape, identity_router):
    # One identity component routed by `identity_router`. Used for both the
    # input end and the output end of the wrapped layer.
    return pn_lib.ComponentsLayer(
        components=[pn_lib.RoutedComponent(
            pn_components.IdentityComponent(out_shape=shape),
            identity_router,
            record_summaries=False)],
        combiner=pn_lib.SumCombiner())

  # Objects are constructed in the same order as before: wrapped components,
  # then the input identity, then the components layer, then the output
  # identity. This matters if any constructor has ordering-dependent side
  # effects.
  # Each component gets its own SinglePathRouter instance rather than a
  # shared one.
  routed_components = [
      pn_lib.RoutedComponent(
          component,
          pn_lib.SinglePathRouter(),
          record_summaries=record_summaries_from_components)
      for component in components
  ]

  router_layer = _identity_layer(in_shape, router)

  components_layer = pn_lib.ComponentsLayer(
      components=routed_components,
      combiner=combiner,
      sparse=sparse)

  aggregation_layer = _identity_layer(out_shape, router_out)

  return [router_layer, components_layer, aggregation_layer]