示例#1
0
def create_layer_with_task_specific_linear_heads(num_classes_for_tasks):
    """Returns a `pathnet_lib.ComponentsLayer` with linear task specific layers.

    This is a small helper function to create a layer of fully connected
    components for multiple classification tasks (with possibly different
    numbers of classes). The constructed layer assumes that the next layer
    contains task heads to compute the task loss. Each component routes into
    the next layer through its own `pathnet_lib.IndependentTaskBasedRouter`,
    and the layer uses a `pathnet_lib.SelectCombiner` as its combiner.

    Args:
      num_classes_for_tasks: (list of ints) number of classes for each task.
        Component `i` is a single FC layer with `num_classes_for_tasks[i]`
        units.

    Returns:
      A `pathnet_lib.ComponentsLayer` containing one FC layer per task.
    """
    # Every component's router is built with the total number of tasks,
    # not just the task its own FC layer corresponds to.
    num_tasks = len(num_classes_for_tasks)

    components = [
        pn.RoutedComponent(
            pn_components.FCLComponent(numbers_of_units=[num_classes]),
            pn.IndependentTaskBasedRouter(num_tasks=num_tasks))
        for num_classes in num_classes_for_tasks
    ]

    return pn.ComponentsLayer(components=components,
                              combiner=pn.SelectCombiner())
示例#2
0
def create_uniform_layer(num_components, component_fn, combiner_fn, router_fn):
    """Builds a layer whose components all share one architecture and router.

    For each of the `num_components` slots, a fresh component (from
    `component_fn`) is paired with a fresh router (from `router_fn`) inside a
    `pn_lib.RoutedComponent`. A single combiner from `combiner_fn` is then
    attached to the whole layer.

    Args:
      num_components: (int) how many components the layer holds.
      component_fn: zero-argument callable returning a new component.
      combiner_fn: zero-argument callable returning the combiner used to
        aggregate the component outputs.
      router_fn: zero-argument callable returning the router used to route
        into the next layer. The router should have a `__call__` method with
        the same arguments as the routers defined in pathnet_lib.

    Returns:
      A `ComponentsLayer` with `num_components` components.
    """
    routed = []
    for _ in range(num_components):
        # Component is built before its router, once per slot.
        new_component = component_fn()
        new_router = router_fn()
        routed.append(pn_lib.RoutedComponent(new_component, new_router))

    layer_combiner = combiner_fn()
    return pn_lib.ComponentsLayer(routed, layer_combiner)
示例#3
0
def create_wrapped_routed_layer(
    components,
    router,
    router_out,
    combiner,
    in_shape,
    out_shape,
    sparse=True,
    record_summaries_from_components=False):
  """Wraps `components` so that a single router at the input selects them.

  Identity components are placed before and after the given components, so
  that one input is routed through a subset of `components` and the results
  are then aggregated. This is in contrast to the default behavior of the
  pathnet library, where routing happens at the end of each component
  independently.

  Args:
    components: (list) components to route through. Each one is wrapped with
      its own `pn_lib.SinglePathRouter`.
    router: the router used to select the components. It should have
      a `__call__` method with the same arguments as the routers defined
      in pathnet_lib.
    router_out: the router used to route the final output. It should have
      a `__call__` method with the same arguments as the routers defined
      in pathnet_lib.
    combiner: the combiner used to aggregate the outputs of `components`.
    in_shape: (sequence of ints) input shape.
    out_shape: (sequence of ints) output shape.
    sparse: (bool) whether to set the `sparse` flag for the components layer.
    record_summaries_from_components: (bool) whether to record summaries
      coming from `components`.

  Returns:
    A list of three `ComponentsLayer`s: the input identity layer routed by
    `router`, the layer holding `components`, and the output identity layer
    routed by `router_out`.
  """

  def _routed_identity_layer(shape, identity_router):
    # Single-component layer: an identity of `shape` routed by
    # `identity_router`, summed, with summaries disabled.
    return pn_lib.ComponentsLayer(
        components=[pn_lib.RoutedComponent(
            pn_components.IdentityComponent(out_shape=shape),
            identity_router,
            record_summaries=False)],
        combiner=pn_lib.SumCombiner())

  wrapped = [
      pn_lib.RoutedComponent(
          comp,
          pn_lib.SinglePathRouter(),
          record_summaries=record_summaries_from_components)
      for comp in components
  ]

  # Entry: identity on the input shape; `router` picks which components run.
  entry_layer = _routed_identity_layer(in_shape, router)

  middle_layer = pn_lib.ComponentsLayer(
      components=wrapped, combiner=combiner, sparse=sparse)

  # Exit: identity on the output shape; `router_out` routes the result onward.
  exit_layer = _routed_identity_layer(out_shape, router_out)

  return [entry_layer, middle_layer, exit_layer]