Example #1
def main(_):
    # Load environment.
    environment = utils.load_environment(FLAGS.env_name)
    environment_spec = specs.make_environment_spec(environment)

    # Create Rewarder.
    demonstrations = utils.load_demonstrations(demo_dir=FLAGS.demo_dir,
                                               env_name=FLAGS.env_name)
    pwil_rewarder = rewarder.PWILRewarder(
        demonstrations,
        subsampling=FLAGS.subsampling,
        env_specs=environment_spec,
        num_demonstrations=FLAGS.num_demonstrations,
        observation_only=FLAGS.state_only)

    # Define D4PG agent.
    agent_networks = utils.make_d4pg_networks(environment_spec.actions)
    agent = d4pg.D4PG(
        environment_spec=environment_spec,
        policy_network=agent_networks['policy'],
        critic_network=agent_networks['critic'],
        observation_network=agent_networks['observation'],
        samples_per_insert=FLAGS.samples_per_insert,
        sigma=FLAGS.sigma,
    )

    # Prefill the agent's Replay Buffer.
    utils.prefill_rb_with_demonstrations(
        agent=agent,
        demonstrations=pwil_rewarder.demonstrations,
        num_transitions_rb=FLAGS.num_transitions_rb,
        reward=pwil_rewarder.reward_scale)

    # Create the eval policy (without exploration noise).
    eval_policy = snt.Sequential([
        agent_networks['observation'],
        agent_networks['policy'],
    ])
    eval_agent = FeedForwardActor(policy_network=eval_policy)

    # Define train/eval loops.
    logger = csv_logger.CSVLogger(directory=FLAGS.workdir, label='train_logs')
    train_loop = imitation_loop.TrainEnvironmentLoop(environment,
                                                     agent,
                                                     pwil_rewarder,
                                                     logger=logger)

    eval_logger = csv_logger.CSVLogger(directory=FLAGS.workdir,
                                       label='eval_logs')
    eval_loop = imitation_loop.EvalEnvironmentLoop(environment,
                                                   eval_agent,
                                                   pwil_rewarder,
                                                   logger=eval_logger)

    for _ in range(FLAGS.num_iterations):
        train_loop.run(num_steps=FLAGS.num_steps_per_iteration)
        eval_loop.run(num_episodes=FLAGS.num_eval_episodes)
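
For this entry point to run standalone, the absl flags it reads must be defined. A minimal sketch follows; the flag names are taken from the snippet above, but the default values and help strings are illustrative only:

from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string('workdir', '/tmp/pwil', 'Directory for CSV logs.')
flags.DEFINE_string('env_name', 'Hopper-v2', 'Environment name.')
flags.DEFINE_string('demo_dir', None, 'Directory containing demonstrations.')
flags.DEFINE_integer('num_demonstrations', 1, 'Number of demonstrations used.')
flags.DEFINE_integer('subsampling', 20, 'Demonstration subsampling rate.')
flags.DEFINE_boolean('state_only', False, 'Use an observation-only rewarder.')
flags.DEFINE_float('sigma', 0.2, 'Exploration noise scale.')
flags.DEFINE_float('samples_per_insert', 256.0, 'Reverb samples per insert.')
flags.DEFINE_integer('num_transitions_rb', 50000, 'Transitions to prefill.')
flags.DEFINE_integer('num_iterations', 100, 'Number of train/eval iterations.')
flags.DEFINE_integer('num_steps_per_iteration', 10000, 'Train steps per iteration.')
flags.DEFINE_integer('num_eval_episodes', 10, 'Eval episodes per iteration.')

if __name__ == '__main__':
    app.run(main)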
Example #2
  def test_logging(self):
    inputs = [{
        'c': 'foo',
        'a': '1337',
        'b': '42.0001',
    }, {
        'c': 'foo2',
        'a': '1338',
        'b': '43.0001',
    }]
    directory = self.get_tempdir()
    label = 'test'
    logger = csv_logger.CSVLogger(directory=directory, label=label)
    for inp in inputs:
      logger.write(inp)
    outputs = []
    file_path = logger.file_path

    # Make sure logger flushed all pending writes to disk.
    del logger
    gc.collect()

    with open(file_path) as f:
      csv_reader = csv.DictReader(f)
      for row in csv_reader:
        outputs.append(dict(row))
    self.assertEqual(outputs, inputs)
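
The del/gc.collect() dance above forces the logger's destructor to run so that buffered rows reach disk before the file is reopened; the later examples show the more explicit alternatives of calling logger.close() or constructing the logger with flush_every=1.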
Example #3
def make_default_logger(
    label: str,
    save_data: bool = True,
    time_delta: float = 1.0,
) -> base.Logger:
  """Make a default Acme logger.

  Args:
    label: Name to give to the logger.
    save_data: Whether to persist data to a CSV file.
    time_delta: Time (in seconds) between logging events.

  Returns:
    A logger (pipe) object that responds to logger.write(some_dict).
  """
  terminal_logger = terminal.TerminalLogger(label=label, time_delta=time_delta)

  loggers = [terminal_logger]
  if save_data:
    loggers.append(csv.CSVLogger(label))

  logger = aggregators.Dispatcher(loggers)
  logger = filters.NoneFilter(logger)
  logger = filters.TimeFilter(logger, time_delta)
  return logger
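
A minimal usage sketch for this factory; the call assumes the function is imported from wherever it lives in your project, and the argument values are illustrative:

logger = make_default_logger(label='train', time_delta=10.0)

# Each write is printed to the terminal and appended to a CSV file,
# but the TimeFilter skips writes arriving within `time_delta` seconds
# of the previous one.
logger.write({'episode': 1, 'episode_return': 12.5})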
Example #4
    def test_logging_input_is_file(self, add_uid: bool):

        # Set up logger.
        directory = paths.process_path(self.get_tempdir(),
                                       'logs',
                                       'my_label',
                                       add_uid=add_uid)
        file = open(os.path.join(directory, 'logs.csv'), 'a')
        logger = csv_logger.CSVLogger(directory_or_file=file, add_uid=add_uid)

        # Write data and close.
        for inp in _TEST_INPUTS:
            logger.write(inp)
        logger.close()

        # Logger doesn't close the file; caller must do this manually.
        self.assertFalse(file.closed)
        file.close()

        # Read back data.
        outputs = []
        with open(logger.file_path) as f:
            csv_reader = csv.DictReader(f)
            for row in csv_reader:
                outputs.append(dict(row))
        self.assertEqual(outputs, _TEST_INPUTS)
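
The takeaway: when the caller passes an already-open file through directory_or_file, logger.close() flushes the logger's own state but leaves the file open, and the caller who opened the handle remains responsible for closing it.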
Example #5
    def test_logging_input_is_file(self):
        inputs = [{
            'c': 'foo',
            'a': '1337',
            'b': '42.0001',
        }, {
            'c': 'foo2',
            'a': '1338',
            'b': '43.0001',
        }]
        directory = paths.process_path(self.get_tempdir(),
                                       'logs',
                                       'my_label',
                                       add_uid=True)
        file = open(os.path.join(directory, 'logs.csv'), 'a')
        logger = csv_logger.CSVLogger(directory_or_file=file)
        for inp in inputs:
            logger.write(inp)
        outputs = []
        file_path = logger.file_path
        file.close()

        with open(file_path) as f:
            csv_reader = csv.DictReader(f)
            for row in csv_reader:
                outputs.append(dict(row))
        self.assertEqual(outputs, inputs)
Example #6
    def test_flush(self):

        logger = csv_logger.CSVLogger(self.get_tempdir(), flush_every=1)
        for inp in _TEST_INPUTS:
            logger.write(inp)

        # Read back data.
        outputs = []
        with open(logger.file_path) as f:
            csv_reader = csv.DictReader(f)
            for row in csv_reader:
                outputs.append(dict(row))
        self.assertEqual(outputs, _TEST_INPUTS)
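
With flush_every=1, each row is flushed to disk as soon as it is written, which is why this test can read the file back while the logger is still alive, with no close() or del needed.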
Example #7
    def test_logging_input_is_directory(self):

        # Set up logger.
        directory = self.get_tempdir()
        label = 'test'
        logger = csv_logger.CSVLogger(directory_or_file=directory, label=label)

        # Write data and close.
        for inp in _TEST_INPUTS:
            logger.write(inp)
        logger.close()

        # Read back data.
        outputs = []
        with open(logger.file_path) as f:
            csv_reader = csv.DictReader(f)
            for row in csv_reader:
                outputs.append(dict(row))
        self.assertEqual(outputs, _TEST_INPUTS)
Example #8
def make_default_logger(
    label: str,
    save_data: bool = True,
    time_delta: float = 1.0,
    asynchronous: bool = False,
    print_fn: Optional[Callable[[str], None]] = None,
    serialize_fn: Optional[Callable[[Mapping[str, Any]], str]] = base.to_numpy,
    steps_key: str = 'steps',
) -> base.Logger:
    """Makes a default Acme logger.

    Args:
      label: Name to give to the logger.
      save_data: Whether to persist data.
      time_delta: Time (in seconds) between logging events.
      asynchronous: Whether the write function should block or not.
      print_fn: How to print to terminal (defaults to logging.info).
      serialize_fn: An optional function to apply to the write inputs before
        passing them to the various loggers.
      steps_key: Ignored.

    Returns:
      A logger object that responds to logger.write(some_dict).
    """
    del steps_key
    if not print_fn:
        print_fn = logging.info
    terminal_logger = terminal.TerminalLogger(label=label, print_fn=print_fn)

    loggers = [terminal_logger]

    if save_data:
        loggers.append(csv.CSVLogger(label=label))

    # Dispatch to all writers and filter Nones and by time.
    logger = aggregators.Dispatcher(loggers, serialize_fn)
    logger = filters.NoneFilter(logger)
    if asynchronous:
        logger = async_logger.AsyncLogger(logger)
    logger = filters.TimeFilter(logger, time_delta)

    return logger
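
Note the wrapper order: NoneFilter strips None-valued entries before they reach the writers, AsyncLogger (when enabled) moves the actual writing onto a background worker so write() returns immediately, and TimeFilter sits outermost so that rate-limited writes are discarded before any work is enqueued. A minimal call sketch, with illustrative argument values:

# Terminal-only, non-blocking logger with no rate limiting.
logger = make_default_logger('eval', save_data=False, asynchronous=True,
                             time_delta=0.0)
logger.write({'step': 0, 'episode_return': 1.0})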
Example #9
  def test_logging(self):
    inputs = [{
        'c': 'foo',
        'a': '1337',
        'b': '42.0001',
    }, {
        'c': 'foo2',
        'a': '1338',
        'b': '43.0001',
    }]
    directory = self.get_tempdir()
    label = 'test'
    logger = csv_logger.CSVLogger(directory=directory, label=label)
    for inp in inputs:
      logger.write(inp)
    with open(logger.file_path) as f:
      csv_reader = csv.DictReader(f)
      for idx, row in enumerate(csv_reader):
        row = dict(row)
        self.assertEqual(row, inputs[idx])