def run_training(
        config=None,
        tuner=None,
        logdir=None,
        trial_name=None,  # pylint: disable=unused-argument
        is_chief=True):
    """Do all training runs.

    This is the top level training function for genetic algorithm and random
    search based models. Run this from the main function.

    Args:
        config: config_lib.Config instance containing global config (agent and
            environment hparams). If None, config will be parsed from
            FLAGS.config.
        tuner: (unused) A tuner instance. Leave as None if not tuning.
        logdir: Parent directory where all data from all runs will be written.
            If None, FLAGS.logdir will be used.
        trial_name: (unused) If tuning, set this to a unique string that
            identifies this trial. If `tuner` is not None, this also must be
            set.
        is_chief: True if this worker is the chief.

    Returns:
        List of results dicts which were written to disk. Each training run
        gets a results dict. A results dict contains metrics, i.e.
        (name, value) pairs which give information about the training run.

    Raises:
        ValueError: If FLAGS.num_workers does not divide FLAGS.num_repetitions.
        ValueError: If results dicts read from disk contain invalid data.
    """
    if not config:
        # If custom config is not given, get it from flags.
        config = defaults.default_config_with_updates(FLAGS.config)
    if not logdir:
        logdir = FLAGS.logdir

    if FLAGS.num_repetitions % FLAGS.num_workers != 0:
        raise ValueError('Number of workers must divide number of repetitions')
    num_local_reps = FLAGS.num_repetitions // FLAGS.num_workers
    logging.info('Running %d reps globally.', FLAGS.num_repetitions)
    logging.info('This worker will run %d local reps.', num_local_reps)
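    # FLAGS.max_npe bounds the total number of programs evaluated per rep;
    # since each generation evaluates one batch, it also bounds generations.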
    if FLAGS.max_npe:
        max_generations = FLAGS.max_npe // config.batch_size
        logging.info('Max samples per rep: %d', FLAGS.max_npe)
        logging.info('Max generations per rep: %d', max_generations)
    else:
        max_generations = sys.maxsize
        logging.info('Running unlimited generations.')

    assert FLAGS.num_workers > 0
    logging.info('Starting experiment. Directory: "%s"', logdir)
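    # Read back any results this shard already wrote so that an interrupted
    # experiment can resume from the next unfinished repetition.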
    results = results_lib.Results(logdir, FLAGS.task_id)
    local_results_list = results.read_this_shard()
    if local_results_list:
        if local_results_list[0]['max_npe'] != FLAGS.max_npe:
            raise ValueError(
                'Cannot resume training. Max-NPE changed. Was %s, now %s'
                % (local_results_list[0]['max_npe'], FLAGS.max_npe))
        if local_results_list[0][
                'max_global_repetitions'] != FLAGS.num_repetitions:
            raise ValueError(
                'Cannot resume training. Number of repetitions changed. '
                'Was %s, now %s'
                % (local_results_list[0]['max_global_repetitions'],
                   FLAGS.num_repetitions))
    start_rep = len(local_results_list)

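    # Each repetition is an independent training run. This worker owns the
    # contiguous block of global reps
    # [num_local_reps * task_id, num_local_reps * (task_id + 1)).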
    for rep in range(start_rep, num_local_reps):
        global_rep = num_local_reps * FLAGS.task_id + rep
        logging.info('Starting repetition: Rep = %d. (global rep = %d)', rep,
                     global_rep)

        # Saved data for each rep (e.g. checkpoints) goes into a separate folder.
        run_dir = os.path.join(logdir, 'run_%d' % global_rep)

        if not tf.gfile.IsDirectory(run_dir):
            tf.gfile.MakeDirs(run_dir)
        checkpoint_writer = CheckpointWriter(run_dir,
                                             population_size=config.batch_size)

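        # Build the task for this run and wrap it in an eval function that
        # scores candidate programs.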
        data_manager = data.DataManager(config, run_number=global_rep)
        task_eval_fn = ga_lib.make_task_eval_fn(data_manager.rl_task)

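        # Two search strategies are supported: pure random search ('rand')
        # and a genetic algorithm ('ga').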
        if config.agent.algorithm == 'rand':
            logging.info('Running random search.')
            assert FLAGS.max_npe
            result = run_random_search(FLAGS.max_npe, run_dir, task_eval_fn,
                                       config.timestep_limit)
        else:
            assert config.agent.algorithm == 'ga'
            logging.info('Running genetic algorithm.')
            pop = ga_lib.make_population(ga_lib.random_individual(
                config.timestep_limit),
                                         n=config.batch_size)
            hof = utils.MaxUniquePriorityQueue(2)  # Hall of fame.
            result = ga_lib.ga_loop(pop,
                                    cxpb=config.agent.crossover_rate,
                                    mutpb=config.agent.mutation_rate,
                                    task_eval_fn=task_eval_fn,
                                    ngen=max_generations,
                                    halloffame=hof,
                                    checkpoint_writer=checkpoint_writer)

        logging.info('Finished rep. Num gens: %d', result.generations)

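        # Persist this repetition's metrics; these dicts are also what
        # read_this_shard() returns when resuming.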
        results_dict = {
            'max_npe': FLAGS.max_npe,
            'batch_size': config.batch_size,
            'max_batches': FLAGS.max_npe // config.batch_size,
            'npe': result.num_programs,
            'max_global_repetitions': FLAGS.num_repetitions,
            'max_local_repetitions': num_local_reps,
            'code_solution': result.best_code if result.solution_found else '',
            'best_reward': result.reward,
            'num_batches': result.generations,
            'found_solution': result.solution_found,
            'task': data_manager.task_name,
            'global_rep': global_rep
        }
        logging.info('results_dict: %s', results_dict)
        results.append(results_dict)

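    # Only the chief worker blocks until every shard has finished, then
    # returns the aggregated results (e.g. for reporting to a tuner).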
    if is_chief:
        logging.info(
            'Worker is chief. Waiting for all workers to finish so that results '
            'can be reported to the tuner.')

        global_results_list, shard_stats = results.read_all(
            num_shards=FLAGS.num_workers)
        while not all(s.finished for s in shard_stats):
            logging.info(
                'Still waiting on these workers: %s', ', '.join([
                    '%d (%d reps left)' %
                    (i, s.max_local_reps - s.num_local_reps_completed)
                    for i, s in enumerate(shard_stats) if not s.finished
                ]))
            sleep(60)
            global_results_list, shard_stats = results.read_all(
                num_shards=FLAGS.num_workers)

        logging.info(
            '%d results obtained. Chief worker is exiting the experiment.',
            len(global_results_list))

        return global_results_list
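
A minimal usage sketch (not part of the original source): it assumes the enclosing module already imports tensorflow as tf, uses the same logging module as above, and defines the flags that run_training reads (config, logdir, max_npe, num_workers, num_repetitions, task_id). The main() wrapper below is illustrative only.

def main(argv):
    del argv  # Unused.
    # With config=None and logdir=None, run_training falls back to
    # FLAGS.config and FLAGS.logdir.
    results = run_training(is_chief=True)
    logging.info('Collected %d results dicts.', len(results))


if __name__ == '__main__':
    tf.app.run(main)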