Example 1
def build_pathnet_eval_graph(task_names, batch_size, num_classes_for_tasks,
                             router_fn):
    """Constructs the PathNet eval graph.

    Args:
      task_names: (list of strings) names of tasks.
      batch_size: (int) batch size to use.
      num_classes_for_tasks: (list of ints) number of classes for each task.
      router_fn: function that, given a single argument `num_components`,
        returns a router (see routers in `pathnet/pathnet_lib.py`) for a layer
        containing `num_components` components.

    Returns:
      A tuple of (`p_inputs`, `p_task_id`, `out_logits`). `p_inputs` and
      `p_task_id` are placeholders for the input images and the scalar task id,
      respectively. `out_logits` is the final network output (classification
      logits).
    """
    num_tasks = len(task_names)

    # PathNet layers

    keras_layers = models.get_keras_layers_for_omniglot_experiment()

    pathnet_layers = models.build_model_from_keras_layers(
        _OMNIGLOT_INPUT_SHAPE, num_tasks, keras_layers, router_fn)

    # Task-specific linear heads

    pathnet_layers.append(
        utils.create_layer_with_task_specific_linear_heads(
            num_classes_for_tasks))

    # Output components

    pathnet_layers.append(
        create_uniform_layer(
            num_components=num_tasks,
            component_fn=lambda: pn_components.ModelHeadComponent(
                loss_fn=loss_fn),
            combiner_fn=pn.SelectCombiner,
            router_fn=lambda: None))

    pathnet = pn.PathNet(pathnet_layers,
                         tf.contrib.training.HParams(batch_size=batch_size))

    p_inputs, _, p_task_id, _, out_logits = utils.build_pathnet_graph(
        pathnet, _OMNIGLOT_INPUT_SHAPE, training=False)

    return p_inputs, p_task_id, out_logits
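
A minimal usage sketch for the graph above (not part of the source): `router_fn`, `session`, `checkpoint_path`, and the NumPy array `images` are assumed to exist, and the task names and class counts are illustrative. Feeding only `p_inputs` and `p_task_id` mirrors how the evaluation logits are computed in Example 2.

p_inputs, p_task_id, out_logits = build_pathnet_eval_graph(
    task_names=['alphabet_0', 'alphabet_1'],
    batch_size=16,
    num_classes_for_tasks=[20, 20],
    router_fn=router_fn)

# Restore the trained weights into the eval graph.
saver = tf.train.Saver(tf.global_variables())
saver.restore(session, checkpoint_path)

# Compute classification logits for a batch of task-0 images, assuming
# `out_logits` is a single [batch, num_classes] tensor as the docstring
# describes.
logits = session.run(out_logits, feed_dict={p_inputs: images, p_task_id: 0})
predictions = logits.argmax(axis=-1)
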
Example 2
def run_pathnet_training_and_evaluation(
    task_names,
    task_data,
    input_data_shape,
    training_hparams,
    components_layers,
    evaluate_on,
    summary_dir,
    resume_checkpoint_dir=None,
    save_checkpoint_every_n_steps=250,
    intermediate_eval_steps=[]):
  """Trains and evaluates a PathNet multitask image classification model.

  Args:
    task_names: (list of strings) names of tasks.
    task_data: (list of dicts) list of dictionaries, one per task.
      Each dictionary should map strings into `tf.data.Dataset`s.
      The `i`-th dictionary should contain all dataset splits (such as 'train',
      'test', 'eval', etc.) for the `i`-th task. The splits can be arbitrary,
      but to run the training, every dataset should contain a 'train' split.
    input_data_shape: (sequence of ints) expected shape of input images
      (excluding batch dimension). For example, for the MNIST dataset
      `input_data_shape=[28, 28, 1]`.
    training_hparams: (tf.contrib.training.HParams) training hyperparameters.
    components_layers: (list of `pn.ComponentsLayer`s) layers that make up
      the PathNet model.
    evaluate_on: (list of strings) dataset splits on which the trained PathNet
      should be evaluated. These keys should be present in every dictionary
      in `task_data`.
    summary_dir: (string) directory for the summary writer.
    resume_checkpoint_dir: (string or None) directory for the checkpoint
      to reload, or None to start training from scratch.
    save_checkpoint_every_n_steps: (int) frequency for saving model checkpoints.
    intermediate_eval_steps: (list of ints) training step numbers at which
      accuracy should be evaluated. An evaluation after the last step is
      always performed.
  """
  session = tf.Session(graph=tf.get_default_graph())

  summary_writer = tf.contrib.summary.create_file_writer(summary_dir)
  summary_writer.set_as_default()

  num_tasks = len(task_names)

  # Every `num_tasks` consecutive global steps contain exactly one training
  # step for each task (the order within each group is shuffled at every
  # iteration). Setting the summary-recording frequency to `num_tasks + 1` (or
  # any other number coprime with `num_tasks`) guarantees that each task will
  # get to record summaries with the same frequency.
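  # For example, with `num_tasks = 3` summaries are recorded every 4 global
  # steps (0, 4, 8, ...); since 4 is coprime with 3, these steps cycle through
  # all three positions within a group rather than repeatedly hitting the
  # same one.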
  with tf.contrib.summary.record_summaries_every_n_global_steps(num_tasks + 1):
    pathnet = pn.PathNet(components_layers, training_hparams)
    num_steps = training_hparams.num_steps

    eval_steps = intermediate_eval_steps + [num_steps]

    # Repeat each training dataset indefinitely.
    train_data = [
        dataset['train'].repeat().make_one_shot_iterator().get_next()
        for dataset in task_data
    ]

    # Attach the task id to each dataset.
    train_data = list(enumerate(train_data))

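    # Placeholders shared by the training and evaluation graphs built below.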
    p_inputs = tf.placeholder(tf.float32, shape=[None] + input_data_shape)
    p_labels = tf.placeholder(tf.int64, shape=[None])
    p_task_id = tf.placeholder(tf.int32, shape=[], name='task_id')

    train_step_op, _ = build_pathnet_graph(
        p_inputs, p_labels, p_task_id, pathnet, training=True)

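  # Build the evaluation graph, reusing the variables created by the training
  # graph above.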
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    _, out_logits_eval = build_pathnet_graph(
        p_inputs, p_labels, p_task_id, pathnet, training=False)

  session.run(tf.global_variables_initializer())
  tf.contrib.summary.initialize(session=session)

  saver = tf.train.Saver(tf.global_variables())

  start_step = 0

  p_task_accuracies = {}
  accuracy_summary_op = {}

  for data_split in evaluate_on:
    p_task_accuracies[data_split], accuracy_summary_op[data_split] = \
      create_accuracy_summary_ops(
          task_names, summary_name_prefix='final_eval_%s' % data_split)

  if resume_checkpoint_dir is not None:
    print('Resuming from checkpoint: %s' % resume_checkpoint_dir)

    last_global_step = int(resume_checkpoint_dir.split('-')[-1])

    assert last_global_step % num_tasks == 0
    start_step = last_global_step // num_tasks

    saver.restore(session, resume_checkpoint_dir)

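  # For each evaluation split, count its batches once, then replace the
  # dataset with a (next-element op, num_batches) pair consumed by the
  # evaluation code below.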
  for dataset in task_data:
    for data_split in evaluate_on:
      num_batches = count_batches(
          session, dataset[data_split].make_one_shot_iterator().get_next())

      dataset[data_split] = dataset[data_split].repeat()
      dataset[data_split] = (
          dataset[data_split].make_one_shot_iterator().get_next())

      dataset[data_split] = (dataset[data_split], num_batches)

  for step in tqdm(range(start_step, num_steps)):
    random.shuffle(train_data)

    run_pathnet_training_step(
        session, p_inputs, p_labels, p_task_id, train_step_op, train_data)

    if step + 1 in eval_steps:
      for data_split in evaluate_on:
        eval_data = [dataset[data_split] for dataset in task_data]

        print('Running evaluation on: %s' % data_split)

        task_accuracies = run_pathnet_evaluation(
            session=session,
            p_inputs=p_inputs,
            p_task_id=p_task_id,
            out_logits=out_logits_eval,
            task_names=task_names,
            eval_data=eval_data)

        run_accuracy_summary_ops(
            session,
            p_task_accuracies[data_split],
            task_accuracies,
            accuracy_summary_op[data_split])

    if (step + 1) % save_checkpoint_every_n_steps == 0:
      path = summary_dir + '/chkpt'
      saver.save(
          session, path, global_step=tf.train.get_or_create_global_step())
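
A minimal call sketch for the function above (not part of the source): the `tf.data.Dataset` objects and the `components_layers` list are placeholders assumed to be built elsewhere (e.g. as in Example 1), and `pn.PathNet` may require fields in `training_hparams` beyond the `batch_size` and `num_steps` used here.

task_names = ['mnist', 'fashion_mnist']

# One dictionary of splits per task; every task needs at least a 'train'
# split, and every split listed in `evaluate_on` must be present for every
# task.
task_data = [
    {'train': mnist_train, 'test': mnist_test},
    {'train': fashion_train, 'test': fashion_test},
]

training_hparams = tf.contrib.training.HParams(batch_size=16, num_steps=1000)

run_pathnet_training_and_evaluation(
    task_names=task_names,
    task_data=task_data,
    input_data_shape=[28, 28, 1],
    training_hparams=training_hparams,
    components_layers=components_layers,
    evaluate_on=['test'],
    summary_dir='/tmp/pathnet_summaries',
    intermediate_eval_steps=[500])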