def run_omniglot_experiment(
    pathnet_layers, training_hparams, task_names, task_data,
    resume_checkpoint_dir=None):
  """Runs the Omniglot experiment.

  Args:
    pathnet_layers: (list of `pn.ComponentsLayer`s) layers that make up
      the PathNet model.
    training_hparams: (tf.contrib.training.HParams) training hyperparameters.
    task_names: (list of strings) names of tasks.
    task_data: (list of dicts) list of dictionaries, one per task. Each
      dictionary should map strings 'train', 'validation' and 'test' into
      `tf.data.Dataset`s for training, validation, and testing, respectively.
    resume_checkpoint_dir: (string or None) directory for the checkpoint to
      reload, or None if should start from scratch.
  """
  for task_id in range(len(task_data)):
    task_data[task_id] = data.batch_all(
        task_data[task_id], training_hparams.batch_size)

  utils.run_pathnet_training_and_evaluation(
      task_names=task_names,
      task_data=task_data,
      input_data_shape=_OMNIGLOT_INPUT_SHAPE,
      training_hparams=training_hparams,
      components_layers=pathnet_layers,
      evaluate_on=['train', 'validation', 'test'],
      resume_checkpoint_dir=resume_checkpoint_dir,
      summary_dir=FLAGS.logdir)
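
# A minimal usage sketch for `run_omniglot_experiment`. This is illustrative
# only: `pathnet_layers` must come from a model-building helper, and the
# hyperparameter values below are assumptions, not the experiment's actual
# configuration.
#
#   task_data, num_classes = data.get_data_for_multitask_omniglot_setup(20)
#   task_names = ['Omniglot-%d' % i for i in range(20)]
#   hparams = tf.contrib.training.HParams(
#       num_steps=10000, batch_size=16, learning_rate=0.005)
#   run_omniglot_experiment(pathnet_layers, hparams, task_names, task_data)
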
def run_routing_experiment(
    pathnet_layers, training_hparams, task_names, task_data,
    resume_checkpoint_dir=None, intermediate_eval_steps=[]):
  """Runs the routing experiment with three task clusters.

  Args:
    pathnet_layers: (list of `pn.ComponentsLayer`s) layers that make up
      the PathNet model.
    training_hparams: (tf.contrib.training.HParams) training hyperparameters.
    task_names: (list of strings) names of tasks.
    task_data: (list of dicts) list of dictionaries, one per task. Each
      dictionary should map strings 'train', 'validation' and 'test' into
      `tf.data.Dataset`s for training, validation, and testing, respectively.
    resume_checkpoint_dir: (string or None) directory for the checkpoint to
      reload, or None if should start from scratch.
    intermediate_eval_steps: (list of ints) training step numbers at which
      accuracy should be evaluated.
  """
  # Keep an unaugmented copy of the training split for evaluation, and
  # augment the split used for training with random crops.
  for dataset in task_data:
    dataset['train_clean'] = dataset['train']
    dataset['train'] = dataset['train'].map(data.augment_with_random_crop)

  for task_id in range(len(task_data)):
    task_data[task_id] = data.batch_all(
        task_data[task_id], training_hparams.batch_size)

  # `save_checkpoint_every_n_steps=sys.maxsize` effectively disables
  # intermediate checkpointing.
  utils.run_pathnet_training_and_evaluation(
      task_names=task_names,
      task_data=task_data,
      input_data_shape=_INPUT_SHAPE,
      training_hparams=training_hparams,
      components_layers=pathnet_layers,
      evaluate_on=['train_clean', 'test'],
      resume_checkpoint_dir=resume_checkpoint_dir,
      summary_dir=FLAGS.logdir,
      intermediate_eval_steps=intermediate_eval_steps,
      save_checkpoint_every_n_steps=sys.maxsize)
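
# A sketch of how the routing experiment might be launched. The step counts
# below are placeholders (assumptions), not the paper's settings; passing
# `intermediate_eval_steps` makes training report accuracy at those steps.
#
#   hparams = tf.contrib.training.HParams(
#       num_steps=5000, batch_size=16, learning_rate=0.005)
#   run_routing_experiment(
#       pathnet_layers, hparams, task_names, task_data,
#       intermediate_eval_steps=[1000, 2500, 5000])
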
def construct_pathnet_and_run_mnist_experiment(
    task_names, task_data, num_classes_for_tasks, router_fn):
  """Runs the MNIST experiment.

  Args:
    task_names: (list of strings) names of tasks.
    task_data: (list of dicts) list of dictionaries, one per task. Each
      dictionary should map strings 'train' and 'test' into
      `tf.data.Dataset`s for training and testing, respectively.
    num_classes_for_tasks: (list of ints) number of classes for each task.
    router_fn: function that, given a single argument `num_components`,
      returns a router (see routers in `pathnet/pathnet_lib.py`) for
      a layer containing `num_components` components.
  """
  num_tasks = len(task_names)

  input_data_shape = [28, 28, 1]
  batch_size = 16

  for task_id in range(num_tasks):
    task_data[task_id] = data.batch_all(task_data[task_id], batch_size)

  # Train each task for 10 epochs (the MNIST training split has 60000 images).
  n_epochs = 10
  training_hparams = tf.contrib.training.HParams(
      num_steps=n_epochs * 60000 // batch_size,
      batch_size=batch_size,
      learning_rate=0.005)

  # Collect the routers as they are created, so the auxiliary loss below can
  # be applied to all of them.
  routers = []

  def get_router(num_components):
    routers.append(router_fn(num_components))
    return routers[-1]

  # PathNet layers.
  keras_layers = models.get_keras_layers_for_mnist_experiment(
      num_components=num_tasks)
  pathnet_layers = models.build_model_from_keras_layers(
      input_data_shape, num_tasks, keras_layers, get_router)

  # Task-specific linear heads.
  pathnet_layers.append(
      utils.create_layer_with_task_specific_linear_heads(
          num_classes_for_tasks))

  # Output components to compute the task loss.
  auxiliary_loss_fn = utils.create_auxiliary_loss_function(
      routers=routers,
      num_total_components=12,
      num_total_steps=training_hparams.num_steps * num_tasks,
      budget=FLAGS.budget,
      budget_penalty=FLAGS.budget_penalty,
      entropy_penalty=FLAGS.entropy_penalty,
      entropy_penalty_alpha=FLAGS.entropy_penalty_alpha)

  def component_fn():
    return pn_components.ModelHeadComponent(
        loss_fn=loss_fn, auxiliary_loss_fn=auxiliary_loss_fn)

  pathnet_layers.append(
      create_uniform_layer(
          num_components=num_tasks,
          component_fn=component_fn,
          combiner_fn=pn.SelectCombiner,
          router_fn=lambda: None))

  utils.run_pathnet_training_and_evaluation(
      task_names=task_names,
      task_data=task_data,
      input_data_shape=input_data_shape,
      training_hparams=training_hparams,
      components_layers=pathnet_layers,
      evaluate_on=['train', 'test'],
      summary_dir=FLAGS.logdir)
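
# A minimal sketch of wiring up the MNIST experiment. It reuses
# `utils.get_router_fn_by_name` (used in `main` below) to obtain a
# `router_fn`; the task setup itself is assumed to come from the data module.
#
#   router_fn = utils.get_router_fn_by_name(num_tasks, FLAGS.method)
#   construct_pathnet_and_run_mnist_experiment(
#       task_names, task_data, num_classes_for_tasks, router_fn)
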
def main(_):
  num_alphabets = 20
  task_names = ['Omniglot-%d' % task_id for task_id in range(num_alphabets)]

  task_data, num_classes = data.get_data_for_multitask_omniglot_setup(
      num_alphabets)

  batch_size = 16
  for task_id in range(num_alphabets):
    task_data[task_id] = data.batch_all(task_data[task_id], batch_size)

  router_fn = utils.get_router_fn_by_name(num_alphabets, FLAGS.method)

  session = tf.Session(graph=tf.get_default_graph())

  tf.train.get_or_create_global_step()
  summary_writer = tf.contrib.summary.create_file_writer(FLAGS.logdir)
  summary_writer.set_as_default()
  tf.contrib.summary.initialize(session=session)

  p_inputs, p_task_id, out_logits = build_pathnet_eval_graph(
      task_names, batch_size, num_classes, router_fn)

  evaluate_on = ['train', 'validation', 'test']

  # Placeholders for per-task accuracies and the summary ops that record
  # them, keyed by data split.
  p_task_accuracies = {}
  accuracy_summary_op = {}

  for data_split in evaluate_on:
    (p_task_accuracies[data_split],
     accuracy_summary_op[data_split]) = utils.create_accuracy_summary_ops(
         task_names, summary_name_prefix='eval_%s' % data_split)

  # This `Saver` is not used to save variables, only to restore them from
  # the checkpoints.
  saver = tf.train.Saver(tf.global_variables())

  previous_checkpoint_path = ''

  # Poll for new checkpoints, and stop once no new checkpoint has appeared
  # for `_MAX_WAIT_FOR_NEW_CHECKPOINTS` seconds.
  time_waited_for_checkpoints = 0
  while time_waited_for_checkpoints < _MAX_WAIT_FOR_NEW_CHECKPOINTS:
    latest_checkpoint_path = tf.train.latest_checkpoint(FLAGS.logdir)

    if latest_checkpoint_path in [None, previous_checkpoint_path]:
      print('Found no new checkpoints')

      time_waited_for_checkpoints += _CHECK_FOR_CHECKPOINTS_FREQUENCY
      time.sleep(_CHECK_FOR_CHECKPOINTS_FREQUENCY)

      continue

    # A new checkpoint appeared: reset the wait counter and evaluate it.
    time_waited_for_checkpoints = 0

    print('Reloading checkpoint: %s' % latest_checkpoint_path)

    previous_checkpoint_path = latest_checkpoint_path
    saver.restore(session, latest_checkpoint_path)

    for data_split in evaluate_on:
      eval_data = [
          dataset[data_split].make_one_shot_iterator().get_next()
          for dataset in task_data
      ]

      print('Evaluating on: %s' % data_split)

      task_accuracies = utils.run_pathnet_evaluation(
          session, p_inputs, p_task_id, out_logits, task_names, eval_data)

      utils.run_accuracy_summary_ops(
          session,
          p_task_accuracies[data_split],
          task_accuracies,
          accuracy_summary_op[data_split])
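
# The polling loop in `main` relies on two module-level constants defined
# elsewhere in this file. Plausible definitions (the values below are
# assumptions, not the file's actual settings) would be:
#
#   _CHECK_FOR_CHECKPOINTS_FREQUENCY = 60    # seconds between polls
#   _MAX_WAIT_FOR_NEW_CHECKPOINTS = 60 * 60  # give up after an hour with no
#                                            # new checkpoints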