def build_pathnet_eval_graph(
    task_names, batch_size, num_classes_for_tasks, router_fn):
  """Constructs the PathNet eval graph.

  Args:
    task_names: (list of strings) names of tasks.
    batch_size: (int) batch size to use.
    num_classes_for_tasks: (list of ints) number of classes for each task.
    router_fn: function that, given a single argument `num_components`,
      returns a router (see routers in `pathnet/pathnet_lib.py`) for a layer
      containing `num_components` components.

  Returns:
    A tuple of (`p_inputs`, `p_task_id`, `out_logits`). `p_inputs` and
    `p_task_id` are placeholders for the input image and scalar task id,
    respectively. `out_logits` are the final network outputs
    (classification logits).
  """
  num_tasks = len(task_names)

  # PathNet layers.
  keras_layers = models.get_keras_layers_for_omniglot_experiment()
  pathnet_layers = models.build_model_from_keras_layers(
      _OMNIGLOT_INPUT_SHAPE, num_tasks, keras_layers, router_fn)

  # Task-specific linear heads.
  pathnet_layers.append(
      utils.create_layer_with_task_specific_linear_heads(
          num_classes_for_tasks))

  # Output components.
  pathnet_layers.append(
      create_uniform_layer(
          num_components=num_tasks,
          component_fn=lambda: pn_components.ModelHeadComponent(
              loss_fn=loss_fn),
          combiner_fn=pn.SelectCombiner,
          router_fn=lambda: None))

  pathnet = pn.PathNet(
      pathnet_layers, tf.contrib.training.HParams(batch_size=batch_size))

  p_inputs, _, p_task_id, _, out_logits = utils.build_pathnet_graph(
      pathnet, _OMNIGLOT_INPUT_SHAPE, training=False)

  return p_inputs, p_task_id, out_logits
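# NOTE: The snippet below is an illustrative sketch of how the eval graph
# might be used; it is not part of the original module. `make_router`,
# `checkpoint_path`, and the task names / class counts are hypothetical
# placeholders: any router factory compatible with the routers in
# `pathnet/pathnet_lib.py`, and a checkpoint saved by the matching training
# graph, would do.
def _example_run_eval(make_router, checkpoint_path, images):
  """Sketch: builds the eval graph for two tasks and classifies `images`."""
  p_inputs, p_task_id, out_logits = build_pathnet_eval_graph(
      task_names=['alphabet_0', 'alphabet_1'],
      batch_size=16,
      num_classes_for_tasks=[20, 20],
      router_fn=make_router)

  with tf.Session() as session:
    # Restore variables trained with the corresponding training graph.
    tf.train.Saver().restore(session, checkpoint_path)
    # Compute classification logits for task 0 on a batch of input images.
    return session.run(
        out_logits, feed_dict={p_inputs: images, p_task_id: 0})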
def run_pathnet_training_and_evaluation(
    task_names,
    task_data,
    input_data_shape,
    training_hparams,
    components_layers,
    evaluate_on,
    summary_dir,
    resume_checkpoint_dir=None,
    save_checkpoint_every_n_steps=250,
    intermediate_eval_steps=[]):
  """Trains and evaluates a PathNet multitask image classification model.

  Args:
    task_names: (list of strings) names of tasks.
    task_data: (list of dicts) list of dictionaries, one per task. Each
      dictionary should map strings into `tf.data.Dataset`s. The `i`-th
      dictionary should contain all dataset splits (such as 'train', 'test',
      'eval', etc.) for the `i`-th task. The splits can be arbitrary, but to
      run training, every dictionary must contain a 'train' split.
    input_data_shape: (sequence of ints) expected shape of input images
      (excluding the batch dimension). For example, for the MNIST dataset
      `input_data_shape=[28, 28, 1]`.
    training_hparams: (tf.contrib.training.HParams) training hyperparameters.
    components_layers: (list of `pn.ComponentsLayer`s) layers that make up
      the PathNet model.
    evaluate_on: (list of strings) dataset splits on which the trained
      PathNet should be evaluated. These keys should be present in every
      dictionary in `task_data`.
    summary_dir: (string) directory for the summary writer.
    resume_checkpoint_dir: (string or None) directory for the checkpoint to
      reload, or None to start training from scratch.
    save_checkpoint_every_n_steps: (int) frequency for saving model
      checkpoints.
    intermediate_eval_steps: (list of ints) training step numbers at which
      accuracy should be evaluated. An evaluation after the last step is
      always performed.
  """
  session = tf.Session(graph=tf.get_default_graph())

  summary_writer = tf.contrib.summary.create_file_writer(summary_dir)
  summary_writer.set_as_default()

  num_tasks = len(task_names)

  # Every `num_tasks` consecutive steps contain exactly one step for each
  # task, always in the order in which the tasks appear in `task_data`.
  # Setting the logging frequency to `num_tasks + 1` (or any other number
  # coprime with `num_tasks`) guarantees that each task gets to record
  # summaries with the same frequency.
  with tf.contrib.summary.record_summaries_every_n_global_steps(
      num_tasks + 1):
    pathnet = pn.PathNet(components_layers, training_hparams)

    num_steps = training_hparams.num_steps
    eval_steps = intermediate_eval_steps + [num_steps]

    # Loop over each training dataset forever.
    train_data = [
        dataset['train'].repeat().make_one_shot_iterator().get_next()
        for dataset in task_data
    ]

    # Attach the task id to each dataset.
    train_data = list(enumerate(train_data))

    p_inputs = tf.placeholder(tf.float32, shape=[None] + input_data_shape)
    p_labels = tf.placeholder(tf.int64, shape=[None])
    p_task_id = tf.placeholder(tf.int32, shape=[], name='task_id')

    train_step_op, _ = build_pathnet_graph(
        p_inputs, p_labels, p_task_id, pathnet, training=True)

    # Reuse the training variables for the eval graph.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      _, out_logits_eval = build_pathnet_graph(
          p_inputs, p_labels, p_task_id, pathnet, training=False)

    session.run(tf.global_variables_initializer())
    tf.contrib.summary.initialize(session=session)

    saver = tf.train.Saver(tf.global_variables())

    start_step = 0

    p_task_accuracies = {}
    accuracy_summary_op = {}

    for data_split in evaluate_on:
      p_task_accuracies[data_split], accuracy_summary_op[data_split] = (
          create_accuracy_summary_ops(
              task_names, summary_name_prefix='final_eval_%s' % data_split))

    if resume_checkpoint_dir is not None:
      print('Resuming from checkpoint: %s' % resume_checkpoint_dir)

      last_global_step = int(resume_checkpoint_dir.split('-')[-1])
      assert last_global_step % num_tasks == 0
      start_step = last_global_step // num_tasks

      saver.restore(session, resume_checkpoint_dir)

    # Count the batches in every eval split, then replace each split with a
    # `(get_next_op, num_batches)` pair over the repeated dataset.
    for dataset in task_data:
      for data_split in evaluate_on:
        num_batches = count_batches(
            session, dataset[data_split].make_one_shot_iterator().get_next())

        dataset[data_split] = dataset[data_split].repeat()
        dataset[data_split] = (
            dataset[data_split].make_one_shot_iterator().get_next())
        dataset[data_split] = (dataset[data_split], num_batches)

    for step in tqdm(range(start_step, num_steps)):
      random.shuffle(train_data)
      run_pathnet_training_step(
          session, p_inputs, p_labels, p_task_id, train_step_op, train_data)

      if step + 1 in eval_steps:
        for data_split in evaluate_on:
          eval_data = [dataset[data_split] for dataset in task_data]

          print('Running evaluation on: %s' % data_split)
          task_accuracies = run_pathnet_evaluation(
              session=session,
              p_inputs=p_inputs,
              p_task_id=p_task_id,
              out_logits=out_logits_eval,
              task_names=task_names,
              eval_data=eval_data)

          run_accuracy_summary_ops(
              session,
              p_task_accuracies[data_split],
              task_accuracies,
              accuracy_summary_op[data_split])

      if (step + 1) % save_checkpoint_every_n_steps == 0:
        path = summary_dir + '/chkpt'
        saver.save(
            session, path, global_step=tf.train.get_or_create_global_step())
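# NOTE: The snippet below is an illustrative sketch of a minimal call to
# `run_pathnet_training_and_evaluation`; it is not part of the original
# module. The tiny random datasets and the particular `HParams` fields are
# hypothetical: the exact set of hyperparameters required depends on the
# model and on `build_pathnet_graph`. Real callers would pass
# `pn.ComponentsLayer`s built for their architecture.
def _example_training_run(components_layers):
  """Sketch: trains a two-task PathNet on tiny in-memory datasets."""
  import numpy as np  # Assumed available; used only for the toy data.

  def make_dataset(images, labels):
    return tf.data.Dataset.from_tensor_slices((images, labels)).batch(16)

  # Two tasks, each with 64 random 28x28x1 images and 10 classes.
  task_data = []
  for _ in range(2):
    images = np.random.rand(64, 28, 28, 1).astype(np.float32)
    labels = np.random.randint(0, 10, size=[64]).astype(np.int64)
    task_data.append({
        'train': make_dataset(images, labels),
        'test': make_dataset(images, labels),
    })

  training_hparams = tf.contrib.training.HParams(
      num_steps=200, batch_size=16, learning_rate=0.01)

  run_pathnet_training_and_evaluation(
      task_names=['task_a', 'task_b'],
      task_data=task_data,
      input_data_shape=[28, 28, 1],
      training_hparams=training_hparams,
      components_layers=components_layers,
      evaluate_on=['test'],
      summary_dir='/tmp/pathnet_summaries',
      intermediate_eval_steps=[100])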