# Note: these functions are excerpts from the PathNet example scripts. The
# helper modules referenced below (`data`, `models`, `utils`, `pn`,
# `pn_components`), the `FLAGS` object, module-level constants such as
# `_INPUT_SHAPE`, `_OMNIGLOT_INPUT_SHAPE`, `_MAX_WAIT_FOR_NEW_CHECKPOINTS` and
# `_CHECK_FOR_CHECKPOINTS_FREQUENCY`, and helpers such as `loss_fn`,
# `create_uniform_layer` and `build_pathnet_eval_graph` are imported or
# defined elsewhere in the original files.
import sys
import time

import tensorflow as tf


def run_omniglot_experiment(
    pathnet_layers,
    training_hparams,
    task_names,
    task_data,
    resume_checkpoint_dir=None):
  """Runs the Omniglot experiment.

  Args:
    pathnet_layers: (list of `pn.ComponentsLayer`s) layers that make up
      the PathNet model.
    training_hparams: (tf.contrib.training.HParams) training hyperparameters.
    task_names: (list of strings) names of tasks.
    task_data: (list of dicts) list of dictionaries, one per task.
      Each dictionary should map strings 'train', 'validation' and 'test' into
      `tf.data.Dataset`s for training, validation, and testing, respectively.
    resume_checkpoint_dir: (string or None) directory for the checkpoint
      to reload, or None if training should start from scratch.
  """
  for task_id in range(len(task_data)):
    task_data[task_id] = data.batch_all(
        task_data[task_id], training_hparams.batch_size)
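
  # Train on all tasks jointly and evaluate on the train, validation and test
  # splits. Training can optionally resume from `resume_checkpoint_dir`.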

  utils.run_pathnet_training_and_evaluation(
      task_names=task_names,
      task_data=task_data,
      input_data_shape=_OMNIGLOT_INPUT_SHAPE,
      training_hparams=training_hparams,
      components_layers=pathnet_layers,
      evaluate_on=['train', 'validation', 'test'],
      resume_checkpoint_dir=resume_checkpoint_dir,
      summary_dir=FLAGS.logdir)


def run_routing_experiment(
    pathnet_layers,
    training_hparams,
    task_names,
    task_data,
    resume_checkpoint_dir=None,
    intermediate_eval_steps=[]):
  """Runs the three task clusters routing experiment.

  Args:
    pathnet_layers: (list of `pn.ComponentsLayer`s) layers that make up
      the PathNet model.
    training_hparams: (tf.contrib.training.HParams) training hyperparameters.
    task_names: (list of strings) names of tasks.
    task_data: (list of dicts) list of dictionaries, one per task.
      Each dictionary should map strings 'train', 'validation' and 'test' into
      `tf.data.Dataset`s for training, validation, and testing, respectively.
    resume_checkpoint_dir: (string or None) directory for the checkpoint
      to reload, or None if training should start from scratch.
    intermediate_eval_steps: (list of ints) training step numbers at which
      accuracy should be evaluated.
  """
  for dataset in task_data:
    dataset['train_clean'] = dataset['train']
    dataset['train'] = dataset['train'].map(data.augment_with_random_crop)

  for task_id in range(len(task_data)):
    task_data[task_id] = data.batch_all(
        task_data[task_id], training_hparams.batch_size)
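
  # Train on all tasks, evaluating on the clean training data and the test
  # data, with extra evaluations at the steps listed in
  # `intermediate_eval_steps`. Periodic checkpointing is effectively disabled
  # by passing `sys.maxsize` as the save frequency.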

  utils.run_pathnet_training_and_evaluation(
      task_names=task_names,
      task_data=task_data,
      input_data_shape=_INPUT_SHAPE,
      training_hparams=training_hparams,
      components_layers=pathnet_layers,
      evaluate_on=['train_clean', 'test'],
      resume_checkpoint_dir=resume_checkpoint_dir,
      summary_dir=FLAGS.logdir,
      intermediate_eval_steps=intermediate_eval_steps,
      save_checkpoint_every_n_steps=sys.maxsize)


def construct_pathnet_and_run_mnist_experiment(task_names, task_data,
                                               num_classes_for_tasks,
                                               router_fn):
  """Runs the MNIST experiment.

  Args:
    task_names: (list of strings) names of tasks.
    task_data: (list of dicts) list of dictionaries, one per task.
      Each dictionary should map strings 'train' and 'test' into
      `tf.data.Dataset`s for training and testing, respectively.
    num_classes_for_tasks: (list of ints) number of classes for each task.
    router_fn: function that, given a single argument `num_components`, returns
      a router (see routers in `pathnet/pathnet_lib.py`) for a layer containing
      `num_components` components.
  """
  num_tasks = len(task_names)

  input_data_shape = [28, 28, 1]
  batch_size = 16

  for task_id in range(num_tasks):
    task_data[task_id] = data.batch_all(task_data[task_id], batch_size)

  # Train each task for 10 epochs
  n_epochs = 10

  training_hparams = tf.contrib.training.HParams(
      num_steps=n_epochs * 60000 // batch_size,
      batch_size=batch_size,
      learning_rate=0.005)
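  # Note: 60000 is the size of the MNIST training set, so `num_steps`
  # corresponds to `n_epochs` passes over each task's training data.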

  routers = []
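
  # Each call to `get_router` below creates a router via `router_fn` and
  # records it in `routers`, so that the full list can be handed to the
  # auxiliary loss constructed further down.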

  def get_router(num_components):
    routers.append(router_fn(num_components))
    return routers[-1]

  # PathNet layers

  keras_layers = models.get_keras_layers_for_mnist_experiment(
      num_components=num_tasks)

  pathnet_layers = models.build_model_from_keras_layers(
      input_data_shape, num_tasks, keras_layers, get_router)

  # Task-specific linear heads

  pathnet_layers.append(
      utils.create_layer_with_task_specific_linear_heads(
          num_classes_for_tasks))

  # Output components to compute task loss

  auxiliary_loss_fn = utils.create_auxiliary_loss_function(
      routers=routers,
      num_total_components=12,
      num_total_steps=training_hparams.num_steps * num_tasks,
      budget=FLAGS.budget,
      budget_penalty=FLAGS.budget_penalty,
      entropy_penalty=FLAGS.entropy_penalty,
      entropy_penalty_alpha=FLAGS.entropy_penalty_alpha)
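
  # The auxiliary loss above is attached to the per-task head components
  # created below. As the flag names suggest, it regularizes routing by
  # penalizing component usage beyond the given budget and by adding an
  # entropy penalty on the routers' decisions.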

  def component_fn():
    return pn_components.ModelHeadComponent(
        loss_fn=loss_fn, auxiliary_loss_fn=auxiliary_loss_fn)

  pathnet_layers.append(
      create_uniform_layer(num_components=num_tasks,
                           component_fn=component_fn,
                           combiner_fn=pn.SelectCombiner,
                           router_fn=lambda: None))
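
  # The layer above holds one loss-computing head component per task; with
  # `router_fn=lambda: None` it is not routed dynamically, and
  # `pn.SelectCombiner` presumably picks out the output of the active task's
  # head.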

  utils.run_pathnet_training_and_evaluation(
      task_names=task_names,
      task_data=task_data,
      input_data_shape=input_data_shape,
      training_hparams=training_hparams,
      components_layers=pathnet_layers,
      evaluate_on=['train', 'test'],
      summary_dir=FLAGS.logdir)
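
# A minimal, hypothetical launcher for the function above: `task_names`,
# `task_data` and `num_classes` are assumed to come from a data helper not
# shown in this excerpt, and the router factory can be obtained as in `main`
# below, e.g.:
#
#   router_fn = utils.get_router_fn_by_name(num_tasks, FLAGS.method)
#   construct_pathnet_and_run_mnist_experiment(
#       task_names, task_data, num_classes, router_fn)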


def main(_):
  num_alphabets = 20
  task_names = ['Omniglot-%d' % task_id for task_id in range(num_alphabets)]

  task_data, num_classes = data.get_data_for_multitask_omniglot_setup(
      num_alphabets)

  batch_size = 16
  for task_id in range(num_alphabets):
    task_data[task_id] = data.batch_all(task_data[task_id], batch_size)

  router_fn = utils.get_router_fn_by_name(num_alphabets, FLAGS.method)

  session = tf.Session(graph=tf.get_default_graph())

  tf.train.get_or_create_global_step()

  summary_writer = tf.contrib.summary.create_file_writer(FLAGS.logdir)
  summary_writer.set_as_default()

  tf.contrib.summary.initialize(session=session)
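
  # Build the evaluation graph: `p_inputs` and `p_task_id` are fed with
  # evaluation batches below, and `out_logits` holds the resulting per-task
  # logits. The graph's variables are restored from the training checkpoints
  # further down.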

  p_inputs, p_task_id, out_logits = build_pathnet_eval_graph(
      task_names, batch_size, num_classes, router_fn)

  evaluate_on = ['train', 'validation', 'test']

  p_task_accuracies = {}
  accuracy_summary_op = {}

  for data_split in evaluate_on:
    (p_task_accuracies[data_split],
     accuracy_summary_op[data_split]) = utils.create_accuracy_summary_ops(
         task_names, summary_name_prefix='eval_%s' % data_split)

  # This `Saver` is not used to save variables, only to restore them from
  # the checkpoints.
  saver = tf.train.Saver(tf.global_variables())

  previous_checkpoint_path = ''
  time_waited_for_checkpoints = 0
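
  # Poll `FLAGS.logdir` for checkpoints written by the training job and
  # re-evaluate whenever a new one appears; give up after
  # `_MAX_WAIT_FOR_NEW_CHECKPOINTS` seconds without seeing a new checkpoint.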

  while time_waited_for_checkpoints < _MAX_WAIT_FOR_NEW_CHECKPOINTS:
    latest_checkpoint_path = tf.train.latest_checkpoint(FLAGS.logdir)

    if latest_checkpoint_path in [None, previous_checkpoint_path]:
      print('Found no new checkpoints')

      time_waited_for_checkpoints += _CHECK_FOR_CHECKPOINTS_FREQUENCY
      time.sleep(_CHECK_FOR_CHECKPOINTS_FREQUENCY)

      continue
    else:
      time_waited_for_checkpoints = 0

    print('Reloading checkpoint: %s' % latest_checkpoint_path)
    previous_checkpoint_path = latest_checkpoint_path

    saver.restore(session, latest_checkpoint_path)
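
    # Evaluate the restored model on every split: build fresh one-shot
    # iterators over each task's dataset, compute per-task accuracies, and
    # write them out as accuracy summaries.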

    for data_split in evaluate_on:
      eval_data = [
          dataset[data_split].make_one_shot_iterator().get_next()
          for dataset in task_data
      ]

      print('Evaluating on: %s' % data_split)

      task_accuracies = utils.run_pathnet_evaluation(
          session, p_inputs, p_task_id, out_logits, task_names, eval_data)

      utils.run_accuracy_summary_ops(
          session,
          p_task_accuracies[data_split],
          task_accuracies,
          accuracy_summary_op[data_split])
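
# The original script presumably launches `main` in the standard TF1 fashion,
# e.g. (not part of this excerpt):
#
#   if __name__ == '__main__':
#     tf.app.run(main)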