Exemplo n.º 1
0
 def __init__(self,
              variable_source: core.VariableSource,
              models: Dict[str, Callable[[core.VariableSource],
                                         types.ModelToSnapshot]],
              path: str,
              subdirectory: Optional[str] = None,
              max_to_keep: Optional[int] = None,
              add_uid: bool = False):
   """Stores the snapshot configuration and resolves the output path.

   Args:
     variable_source: Source of variables handed to each model factory.
     models: Mapping from a name to a callable that builds a
       `types.ModelToSnapshot` from a `core.VariableSource`.
     path: Base path, passed to `paths.process_path`.
     subdirectory: Optional extra path component placed after `path`. Only
       `None` omits it; an empty string is still forwarded.
     max_to_keep: Stored on the instance as-is.
     add_uid: Forwarded to `paths.process_path`.
   """
   path_parts = (path,) if subdirectory is None else (path, subdirectory)
   self._variable_source = variable_source
   self._models = models
   self._path = paths.process_path(*path_parts, add_uid=add_uid)
   self._max_to_keep = max_to_keep
   # Starts empty (None); nothing in this constructor fills it in.
   self._snapshot_paths: Optional[List[str]] = None
Exemplo n.º 2
0
    def __init__(
        self,
        objects_to_save: Mapping[str, snt.Module],
        *,
        directory: str = '~/acme/',
        time_delta_minutes: float = 30.0,
        snapshot_ttl_seconds: int = _DEFAULT_SNAPSHOT_TTL,
    ):
        """Builds the saver object.

        Args:
          objects_to_save: Mapping from a name to the module to snapshot. A
            falsy value (e.g. None) is treated as an empty mapping.
          directory: Base directory; the resolved snapshot directory comes
            from `paths.process_path(directory, 'snapshots', ...)`.
          time_delta_minutes: How often to save the snapshot, in minutes.
          snapshot_ttl_seconds: TTL (time to live) in seconds for snapshots.
        """
        to_save = objects_to_save or {}

        self._time_delta_minutes = time_delta_minutes
        self._last_saved = 0.

        # Kept public so callers can refer to the resolved base directory.
        self.directory = paths.process_path(
            directory, 'snapshots', ttl_seconds=snapshot_ttl_seconds)

        # One entry per module, keyed by '<self.directory>/<name>'.
        self._snapshots = {
            os.path.join(self.directory, name): make_snapshot(module)
            for name, module in to_save.items()
        }
Exemplo n.º 3
0
    def test_logging_input_is_file(self, add_uid: bool):
        """CSVLogger given an open file writes rows to it but leaves it open."""

        # Set up logger.
        directory = paths.process_path(self.get_tempdir(),
                                       'logs',
                                       'my_label',
                                       add_uid=add_uid)
        # `with` guarantees the handle is released even if a write or an
        # assertion below fails.
        with open(os.path.join(directory, 'logs.csv'), 'a') as file:
            logger = csv_logger.CSVLogger(directory_or_file=file,
                                          add_uid=add_uid)

            # Write data and close.
            for inp in _TEST_INPUTS:
                logger.write(inp)
            logger.close()

            # Logger doesn't close the file; caller must do this manually.
            self.assertFalse(file.closed)

        # Read back data.
        with open(logger.file_path) as f:
            outputs = [dict(row) for row in csv.DictReader(f)]
        self.assertEqual(outputs, _TEST_INPUTS)
Exemplo n.º 4
0
    def test_logging_input_is_file(self):
        """CSVLogger writes rows into a caller-supplied file handle."""
        inputs = [{
            'c': 'foo',
            'a': '1337',
            'b': '42.0001',
        }, {
            'c': 'foo2',
            'a': '1338',
            'b': '43.0001',
        }]
        directory = paths.process_path(self.get_tempdir(),
                                       'logs',
                                       'my_label',
                                       add_uid=True)
        # The handle is closed (on leaving the `with`) before reading back, so
        # buffered writes reach disk; `with` also releases it if a write fails.
        with open(os.path.join(directory, 'logs.csv'), 'a') as file:
            logger = csv_logger.CSVLogger(directory_or_file=file)
            for inp in inputs:
                logger.write(inp)
            file_path = logger.file_path

        with open(file_path) as f:
            outputs = [dict(row) for row in csv.DictReader(f)]
        self.assertEqual(outputs, inputs)
Exemplo n.º 5
0
def train_and_evaluate(train_ds, eval_ds, test_ds, key, test_cutoffs=None):
    """Main training loop for distance learning.

    Restores (or initializes) the training state from a checkpoint directory.
    When starting at epoch 0 it first evaluates the initialization. It then
    trains epoch by epoch up to `FLAGS.num_epochs`, evaluating on `eval_ds`
    and `test_ds` and saving a checkpoint after every epoch.

    Args:
      train_ds: Mapping from name to an iterable of batches; one batch from
        each entry is combined into a single training step.
      eval_ds: Validation data passed to `eval_on_dataset`.
      test_ds: Test data passed to `eval_on_dataset`.
      key: Random key forwarded to `train_utils.restore_or_initialize`.
      test_cutoffs: Optional cutoffs forwarded when evaluating `test_ds`.
    """
    ckpt_dir = paths.process_path(FLAGS.logdir, 'ckpts', add_uid=True)
    state = train_utils.restore_or_initialize(
        FLAGS.encoder_conv_filters, FLAGS.encoder_conv_size, key, ckpt_dir,
        FLAGS.learning_rate)

    summary_writer = tf.summary.create_file_writer(
        paths.process_path(FLAGS.logdir, 'tb', add_uid=True))

    best_epoch = {}
    if state.epoch == 0:
        # Fresh start: record how the untrained model performs.
        val_results = eval_on_dataset(state, eval_ds)
        test_results = eval_on_dataset(state, test_ds, cutoffs=test_cutoffs)
        logging.log_metrics(val_results, summary_writer, state.step, 'val')
        logging.log_metrics_to_stdout(val_results, state.epoch, 'val epoch  ')
        logging.log_metrics(test_results, summary_writer, state.step, 'test')
        update_best_loss(val_results, state, ckpt_dir, best_epoch, 'val')
        update_best_loss(test_results, state, ckpt_dir, best_epoch, 'test')

    for epoch_idx in range(state.epoch, FLAGS.num_epochs):
        finished = epoch_idx + 1
        running = collections.defaultdict(float)
        seen = collections.defaultdict(int)

        # Advance all named datasets in lockstep; the epoch ends as soon as
        # any iterator is exhausted.
        iterators = {name: iter(ds) for name, ds in train_ds.items()}
        for parts in zip(*iterators.values()):
            batch = dict(zip(iterators, parts))
            dist_grads, domain_grads, batch_metrics, _ = apply_model(
                state, batch)
            state = update_model(state, dist_grads, domain_grads)
            update_incremental_mean(running, seen, batch_metrics, batch)
        state = state.replace(epoch=finished)

        # Note: train 'epoch' repeats short datasets, validation epoch does not.
        running['total_distance_model_loss'] = (
            running['distance_loss'] + running['paired_loss'] +
            running['adversarial_domain_loss'])
        logging.log_metrics(running, summary_writer, finished, 'train')
        logging.log_metrics_to_stdout(running, finished, 'train epoch')

        val_results = eval_on_dataset(state, eval_ds)
        logging.log_metrics(val_results, summary_writer, finished, 'val')
        logging.log_metrics_to_stdout(val_results, finished, 'val epoch  ')
        update_best_loss(val_results, state, ckpt_dir, best_epoch, 'val')

        test_results = eval_on_dataset(state, test_ds, cutoffs=test_cutoffs)
        logging.log_metrics(test_results, summary_writer, finished, 'test')
        update_best_loss(test_results, state, ckpt_dir, best_epoch, 'test')

        train_utils.save_checkpoint(ckpt_dir, state)
Exemplo n.º 6
0
 def __init__(self,
              variable_source: core.VariableSource,
              models: Dict[str, Callable[[core.VariableSource],
                                         types.ModelToSnapshot]],
              path: str,
              add_uid: bool = False):
   """Records the variable source and models and resolves the output path.

   Args:
     variable_source: Source of variables handed to each model factory.
     models: Mapping from a name to a callable that builds a
       `types.ModelToSnapshot` from a `core.VariableSource`.
     path: Base path, passed to `paths.process_path`.
     add_uid: Forwarded to `paths.process_path`.
   """
   self._path = paths.process_path(path, add_uid=add_uid)
   self._models = models
   self._variable_source = variable_source
Exemplo n.º 7
0
    def _path(self, subdir: Optional[str] = None) -> str:
        """Returns the processed path for this instance's label.

        Components are joined as `directory / time_stamp [/ subdir] / label`
        with the `/` operator starting from `self._directory`. The result is
        stringified and passed through `paths.process_path`.

        Args:
          subdir: Optional component placed between the time stamp and the
            label; skipped when falsy (None or '').
        """
        joined = self._directory / self._time_stamp
        if subdir:
            joined = joined / subdir
        joined = joined / self._label

        # paths.process_path is relied on to replace "~" in the result.
        return paths.process_path(str(joined))
Exemplo n.º 8
0
 def __init__(self,
              directory: str = '~/acme',
              label: str = '',
              time_delta: float = 0.):
     """Resolves the CSV output location and sets up write bookkeeping.

     Args:
       directory: Base directory. The file is `logs.csv` inside the directory
         returned by `paths.process_path(directory, 'logs', label,
         add_uid=True)`.
       label: Path component placed after 'logs'.
       time_delta: Stored on the instance; its use is outside this
         constructor.
     """
     log_dir = paths.process_path(directory, 'logs', label, add_uid=True)
     self._file_path = os.path.join(log_dir, 'logs.csv')
     logging.info('Logging to %s', self._file_path)
     # Header flag starts False; presumably flipped once a header is written.
     self._header_exists = False
     self._time_delta = time_delta
     self._time = time.time()
Exemplo n.º 9
0
    def __init__(
        self,
        environment: dm_env.Environment,
        executor: mava.core.Executor,
        filename: str = "agents",
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        should_update: bool = True,
        label: str = "parallel_environment_loop",
        record_every: int = 1000,
        path: str = "~/mava",
        fps: int = 15,
        counter_str: str = "evaluator_episodes",
        format: str = "video",
        figsize: Union[float, Tuple[int, int]] = (360, 640),
    ):
        """Wraps an environment loop so that episodes can be recorded.

        Args:
          environment: Environment to run. Its `step` and `reset` attributes
            are replaced by this object's `step` and `reset`; the originals
            are kept on `self._parent_environment_step/_reset`.
          executor: Forwarded to the parent loop's constructor.
          filename: Stored as `self._filename`.
          counter: Forwarded to the parent loop's constructor.
          logger: Forwarded to the parent loop's constructor.
          should_update: Forwarded to the parent loop's constructor.
          label: Forwarded to the parent loop's constructor.
          record_every: Stored as `self._record_every`.
          path: Base path; the recordings path is
            `paths.process_path(path, "recordings", add_uid=False)`.
          fps: Stored as `self._fps`.
          counter_str: Stored as `self._counter_str`.
          format: Output format, either "gif" or "video".
          figsize: Stored as `self._figsize`.

        Raises:
          ValueError: If `format` is neither "gif" nor "video".
        """
        # Explicit check instead of `assert`, which `python -O` strips.
        if format not in ("gif", "video"):
            raise ValueError("Only gif and video format are supported.")

        # Internalize agent and environment.
        super().__init__(
            environment=environment,
            executor=executor,
            counter=counter,
            logger=logger,
            should_update=should_update,
            label=label,
        )
        self._record_every = record_every
        self._path = paths.process_path(path, "recordings", add_uid=False)
        self._filename = filename
        self._record_current_episode = False
        self._fps = fps
        self._frames: List = []

        # Route the environment's step/reset through this object, keeping the
        # originals (presumably so the wrappers can delegate to them).
        self._parent_environment_step = self._environment.step
        self._environment.step = self.step

        self._parent_environment_reset = self._environment.reset
        self._environment.reset = self.reset

        self._counter_str = counter_str

        self._format = format
        self._figsize = figsize
Exemplo n.º 10
0
 def _create_file(self, directory_or_file: Union[str, TextIO],
                  label: str) -> TextIO:
     """Opens `<dir>/logs.csv` for a directory, or validates a given file.

     Args:
       directory_or_file: Either a base directory (str) or an already-open
         file object.
       label: Path component used when a directory is given. When a file
         object is given it is unused, and an info message says so.

     Returns:
       The file opened via `self._open(..., mode='a')`, or the given handle.

     Raises:
       ValueError: If a file object is given whose `mode` does not start
         with 'a'.
     """
     if not isinstance(directory_or_file, str):
         handle = directory_or_file
         if label:
             logging.info(
                 'File, not directory, passed to CSVLogger; label not '
                 'used.')
         if handle.mode.startswith('a'):
             return handle
         raise ValueError(
             'File must be open in append mode; instead got '
             f'{handle.mode}.')

     log_dir = paths.process_path(directory_or_file,
                                  'logs',
                                  label,
                                  add_uid=True)
     return self._open(os.path.join(log_dir, 'logs.csv'), mode='a')
Exemplo n.º 11
0
  def __init__(
      self,
      object_to_save: core.Saveable,
      directory: str = '~/acme/',
      subdirectory: str = 'default',
      time_delta_minutes: float = 10.,
      add_uid: bool = True,
      checkpoint_ttl_seconds: int = _DEFAULT_CHECKPOINT_TTL,
  ):
    """Builds the saver object and immediately calls `self.restore()`.

    Args:
      object_to_save: The object to save in this checkpoint; it must have a
        save and restore method.
      directory: Base directory for the checkpoint.
      subdirectory: Sub-directory to use (e.g. if multiple checkpoints are
        being saved); placed after a 'checkpoints' component.
      time_delta_minutes: How often to save the checkpoint, in minutes.
      add_uid: If True adds a UID to the checkpoint path, see
        `paths.get_unique_id()` for how this UID is generated.
      checkpoint_ttl_seconds: TTL (time to live) in seconds for checkpoints.
    """
    # TODO(tamaranorman) accept a Union[Saveable, Mapping[str, Saveable]] here
    self._object_to_save = object_to_save
    self._time_delta_minutes = time_delta_minutes

    self._last_saved = 0.
    self._lock = threading.Lock()

    checkpoint_parts = (directory, 'checkpoints', subdirectory)
    self._checkpoint_dir = paths.process_path(
        *checkpoint_parts,
        add_uid=add_uid,
        backups=False,
        ttl_seconds=checkpoint_ttl_seconds)

    # Pick up the most recent checkpoint (if it exists).
    self.restore()
Exemplo n.º 12
0
    def __init__(
        self,
        objects_to_save: Mapping[str, Union[Checkpointable, core.Saveable]],
        *,
        directory: str = '~/acme/',
        subdirectory: str = 'default',
        time_delta_minutes: float = 10.0,
        enable_checkpointing: bool = True,
        add_uid: bool = True,
        max_to_keep: int = 1,
        checkpoint_ttl_seconds: int = _DEFAULT_CHECKPOINT_TTL,
        keep_checkpoint_every_n_hours: Optional[int] = None,
    ):
        """Builds the saver object.

        When checkpointing is enabled, the objects are bundled into a
        `tf.train.Checkpoint` managed by a `tf.train.CheckpointManager`, and
        `self.restore()` is called right away. Otherwise only the
        bookkeeping attributes are set and `self._checkpoint_manager` stays
        None.

        Args:
          objects_to_save: Mapping specifying what to checkpoint.
          directory: Which directory to put the checkpoint in.
          subdirectory: Sub-directory to use (e.g. if multiple checkpoints are
            being saved).
          time_delta_minutes: How often to save the checkpoint, in minutes.
          enable_checkpointing: whether to checkpoint or not.
          add_uid: If True adds a UID to the checkpoint path, see
            `paths.get_unique_id()` for how this UID is generated.
          max_to_keep: The maximum number of checkpoints to keep.
          checkpoint_ttl_seconds: TTL (time to live) in seconds for
            checkpoints.
          keep_checkpoint_every_n_hours: keep_checkpoint_every_n_hours passed
            to tf.train.CheckpointManager.
        """
        # Wrap `core.Saveable` values in a SaveableAdapter so they become TF
        # checkpointables; every other value passes through unchanged.
        checkpointables = {
            name: (SaveableAdapter(obj) if isinstance(obj, core.Saveable)
                   else obj)
            for name, obj in objects_to_save.items()
        }

        self._time_delta_minutes = time_delta_minutes
        self._last_saved = 0.
        self._enable_checkpointing = enable_checkpointing
        self._checkpoint_manager = None

        if not enable_checkpointing:
            return

        # Checkpoint object that handles saving/restoring.
        self._checkpoint = tf.train.Checkpoint(**checkpointables)
        self._checkpoint_dir = paths.process_path(
            directory,
            'checkpoints',
            subdirectory,
            add_uid=add_uid,
            backups=False,
            ttl_seconds=checkpoint_ttl_seconds)

        # The manager decides how many checkpoints are retained on disk.
        self._checkpoint_manager = tf.train.CheckpointManager(
            self._checkpoint,
            directory=self._checkpoint_dir,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            max_to_keep=max_to_keep)

        # Restore right away (see `restore`).
        self.restore()
Exemplo n.º 13
0
def train_and_evaluate(distance_fn, rng):
    """Train a policy on the learned distance function and evaluate task success.

    Builds a training environment and a separate evaluation environment via
    `make_environment`, and trains a SAC agent on the former. It alternates
    `FLAGS.num_eval_episodes` evaluation episodes with `FLAGS.eval_every`
    training steps, then finishes with one more round of evaluation.

    Args:
      distance_fn: function mapping a (state, goal)-pair to a state embedding and
        a distance estimate used for policy learning.
      rng: random key used to initialize evaluation actor.

    Raises:
      ValueError: If `FLAGS.eval_every` is not positive or does not evenly
        divide `FLAGS.num_steps`.
    """
    # Validate the train/eval schedule before any expensive setup. The check
    # is explicit (not `assert`) so it still runs under `python -O`, and it
    # rejects eval_every <= 0, which would otherwise hit a ZeroDivisionError.
    if FLAGS.eval_every <= 0 or FLAGS.num_steps % FLAGS.eval_every != 0:
        raise ValueError(
            f'FLAGS.num_steps ({FLAGS.num_steps}) must be a multiple of '
            f'FLAGS.eval_every ({FLAGS.eval_every}), which must be positive.')
    num_rounds = FLAGS.num_steps // FLAGS.eval_every

    goal_image = load_goal_image(FLAGS.robot_data_path)
    logdir = FLAGS.logdir
    video_dir = paths.process_path(logdir, 'videos')
    print('Writing videos to', video_dir)
    counter = counting.Counter()
    eval_counter = counting.Counter(counter, prefix='eval', time_delta=0.0)
    # Include training episodes and steps and walltime in the first eval logs.
    counter.increment(episodes=0, steps=0, walltime=0)

    environment = make_environment(
        task=FLAGS.task,
        end_on_success=FLAGS.end_on_success,
        max_episode_steps=FLAGS.max_episode_steps,
        distance_fn=distance_fn,
        goal_image=goal_image,
        baseline_distance=FLAGS.baseline_distance,
        logdir=video_dir,
        counter=counter,
        record_every=FLAGS.record_episodes_frequency,
        num_episodes_to_record=FLAGS.num_episodes_to_record)
    environment_spec = specs.make_environment_spec(environment)
    print('Environment spec')
    print(environment_spec)
    agent_networks = sac.make_networks(environment_spec)

    config = sac.SACConfig(
        target_entropy=sac.target_entropy_from_env_spec(environment_spec),
        num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
        min_replay_size=FLAGS.min_replay_size)
    agent = deprecated_sac.SAC(environment_spec,
                               agent_networks,
                               config=config,
                               counter=counter,
                               seed=FLAGS.seed)

    env_logger = loggers.CSVLogger(logdir, 'env_loop', flush_every=5)
    eval_env_logger = loggers.CSVLogger(logdir, 'eval_env_loop', flush_every=1)
    train_loop = acme.EnvironmentLoop(environment,
                                      agent,
                                      label='train_loop',
                                      logger=env_logger,
                                      counter=counter)

    eval_actor = agent.builder.make_actor(random_key=rng,
                                          policy=sac.apply_policy_and_sample(
                                              agent_networks, eval_mode=True),
                                          environment_spec=environment_spec,
                                          variable_source=agent)

    eval_video_dir = paths.process_path(logdir, 'eval_videos')
    print('Writing eval videos to', eval_video_dir)
    if FLAGS.baseline_distance_from_goal_to_goal:
        # Use the distance the model predicts from the goal image to itself.
        state = goal_image
        if distance_fn.history_length > 1:
            # One copy of the goal image per history step, stacked on a new
            # last axis.
            state = np.stack([goal_image] * distance_fn.history_length,
                             axis=-1)
        _, baseline_distance = distance_fn(state, goal_image)
        print('Baseline prediction', baseline_distance)
    else:
        baseline_distance = FLAGS.baseline_distance
    eval_env = make_environment(task=FLAGS.task,
                                end_on_success=False,
                                max_episode_steps=FLAGS.max_episode_steps,
                                distance_fn=distance_fn,
                                goal_image=goal_image,
                                eval_mode=True,
                                logdir=eval_video_dir,
                                counter=eval_counter,
                                record_every=FLAGS.num_eval_episodes,
                                num_episodes_to_record=FLAGS.num_eval_episodes,
                                baseline_distance=baseline_distance)

    eval_loop = acme.EnvironmentLoop(eval_env,
                                     eval_actor,
                                     label='eval_loop',
                                     logger=eval_env_logger,
                                     counter=eval_counter)

    try:
        for _ in range(num_rounds):
            eval_loop.run(num_episodes=FLAGS.num_eval_episodes)
            train_loop.run(num_steps=FLAGS.eval_every)
        eval_loop.run(num_episodes=FLAGS.num_eval_episodes)
    finally:
        # Close the CSV loggers so rows buffered between `flush_every`
        # flushes are not lost, even if a loop raises.
        env_logger.close()
        eval_env_logger.close()
Exemplo n.º 14
0
 def test_process_path(self):
     """process_path places the unique id between the root and the parts."""
     root = self.get_tempdir()
     expected = f'{root}/test/foo/bar'
     with mock.patch.object(paths, 'get_unique_id',
                            return_value=('test',)):
         result = paths.process_path(root, 'foo', 'bar')
     self.assertEqual(result, expected)