Example #1
    def test_device_ctx(self):
        with alf.device("cpu"):
            self.assertEqual(alf.get_default_device(), "cpu")
            self.assertEqual(torch.tensor([1]).device.type, "cpu")
            if torch.cuda.is_available():
                with alf.device("cuda"):
                    self.assertEqual(alf.get_default_device(), "cuda")
                    self.assertEqual(torch.tensor([1]).device.type, "cuda")
            self.assertEqual(alf.get_default_device(), "cpu")
            self.assertEqual(torch.tensor([1]).device.type, "cpu")
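The test above uses ``alf.device`` as a context manager: tensors created inside the block land on the context's device, and the previous default is restored on exit. A minimal sketch of the same pattern outside a test, assuming a CUDA default device and an illustrative variable name:

with alf.device("cpu"):
    # Tensors created here stay on CPU even when the default device is "cuda".
    big_cpu_table = torch.zeros(1000000, 128)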
Example #2
    def _stack_time_steps(self, time_steps):
        """Given a list of ``TimeStep``, combine them into one with a batch dimension."""
        if self._flatten:
            stacked = nest.fast_map_structure_flatten(
                lambda *arrays: torch.stack(arrays),
                self._time_step_with_env_info_spec, *time_steps)
        else:
            stacked = nest.fast_map_structure(
                lambda *arrays: torch.stack(arrays), *time_steps)
        if alf.get_default_device() == "cuda":
            # Keep a reference to the CPU version attached to the CUDA result.
            cpu = stacked
            stacked = nest.map_structure(lambda x: x.cuda(), cpu)
            stacked._cpu = cpu
        return stacked
Example #3
    def train_iter(self, num_particles=None, state=None):
        """Perform one epoch (iteration) of training.

        Args:
            num_particles (int): number of sampled particles. Default is None.
            state: not used

        Returns:
            int: the number of mini-batches trained in this iteration
        """

        assert self._train_loader is not None, "Must set data_loader first."
        alf.summary.increment_global_counter()
        with record_time("time/train"):
            loss = 0.
            if self._loss_type == 'classification':
                avg_acc = []
            for batch_idx, (data, target) in enumerate(self._train_loader):
                data = data.to(alf.get_default_device())
                target = target.to(alf.get_default_device())
                alg_step = self.train_step((data, target),
                                           num_particles=num_particles,
                                           state=state)
                loss_info, params = self.update_with_gradient(alg_step.info)
                loss += loss_info.extra.generator.loss
                if self._loss_type == 'classification':
                    avg_acc.append(alg_step.info.extra.generator.extra)
        acc = None
        if self._loss_type == 'classification':
            acc = torch.as_tensor(avg_acc).mean() * 100
        if self._logging_training:
            if self._loss_type == 'classification':
                logging.info("Avg acc: {}".format(acc))
            logging.info("Cum loss: {}".format(loss))
        self.summarize_train(loss_info, params, cum_loss=loss, avg_acc=acc)

        return batch_idx + 1
Example #4
    def evaluate(self, num_particles=None):
        """Evaluate on a randomly drawn ensemble. 

        Args:
            num_particles (int): number of sampled particles. Default is None.
        """

        assert self._test_loader is not None, "Must set test_loader first."
        logging.info("==> Begin testing")
        if self._use_fc_bn:
            self._generator.eval()
        params = self.sample_parameters(num_particles=num_particles)
        self._param_net.set_parameters(params)
        if self._use_fc_bn:
            self._generator.train()
        with record_time("time/test"):
            if self._loss_type == 'classification':
                test_acc = 0.
            test_loss = 0.
            for i, (data, target) in enumerate(self._test_loader):
                data = data.to(alf.get_default_device())
                target = target.to(alf.get_default_device())
                output, _ = self._param_net(data)  # [B, N, D]
                loss, extra = self._vote(output, target)
                if self._loss_type == 'classification':
                    test_acc += extra.item()
                test_loss += loss.loss.item()

        if self._loss_type == 'classification':
            test_acc /= len(self._test_loader.dataset)
            alf.summary.scalar(name='eval/test_acc', data=test_acc * 100)
        if self._logging_evaluate:
            if self._loss_type == 'classification':
                logging.info("Test acc: {}".format(test_acc * 100))
            logging.info("Test loss: {}".format(test_loss))
        alf.summary.scalar(name='eval/test_loss', data=test_loss)
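A hedged usage sketch tying the two methods above together, assuming both belong to the same already-configured algorithm instance with its train and test loaders set; ``algorithm``, ``num_epochs``, and ``eval_interval`` are illustrative names, not part of the original code.

# Illustrative outer loop: alternate training epochs with periodic evaluation.
for epoch in range(num_epochs):
    num_batches = algorithm.train_iter()      # number of mini-batches trained
    if (epoch + 1) % eval_interval == 0:
        algorithm.evaluate(num_particles=10)  # num_particles is illustrative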
Example #5
def convert_device(nests, device=None):
    """Convert the device of the tensors in nests to the specified
        or to the default device.
    Args:
        nests (nested Tensors): Nested list/tuple/dict of Tensors.
        device (None|str): the target device, should either be `cuda` or `cpu`.
            If None, then the default device will be used as the target device.
    Returns:
        nests (nested Tensors): Nested list/tuple/dict of Tensors after device
            conversion.

    Raises:
        NotImplementedError if the target device is not one of
            None, `cpu` or `cuda` when cuda is available, or AssertionError
            if target device is `cuda` but cuda is unavailable.


    """
    def _convert_cuda(tensor):
        if tensor.device.type != 'cuda':
            return tensor.cuda()
        else:
            return tensor

    def _convert_cpu(tensor):
        if tensor.device.type != 'cpu':
            return tensor.cpu()
        else:
            return tensor

    if device is None:
        d = alf.get_default_device()
    else:
        d = device

    if d == 'cpu':
        return nest.map_structure(_convert_cpu, nests)
    elif d == 'cuda':
        assert torch.cuda.is_available(), "cuda is unavailable"
        return nest.map_structure(_convert_cuda, nests)
    else:
        raise NotImplementedError("Unknown device %s" % d)
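A minimal usage sketch for ``convert_device`` above, assuming the function and ``torch`` are in scope; the nested structure below is illustrative.

nests = {
    "obs": torch.zeros(2, 3),
    "extra": [torch.ones(2), torch.arange(2)],
}
on_default = convert_device(nests)            # moves to alf.get_default_device()
on_cpu = convert_device(nests, device="cpu")  # explicit target device
assert on_cpu["obs"].device.type == "cpu"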
Example #6
    def __init__(self,
                 tensor_spec: TensorSpec,
                 window_size,
                 name="WindowAverager"):
        """
        WindowAverager calculates the average of the past ``window_size`` samples.

        Args:
            tensor_spec (nested TensorSpec): the ``TensorSpec`` for the value to be
                averaged
            window_size (int): the size of the window
            name (str): name of this averager
        """
        super().__init__()
        self._name = name
        self._buf = alf.nest.map_structure(
            # Should put data on the default device instead of "cpu"
            lambda spec: DataBuffer(spec, window_size, alf.get_default_device()),
            tensor_spec)
        self._tensor_spec = tensor_spec
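A minimal construction sketch, assuming ``WindowAverager`` and ``alf.TensorSpec`` are importable; the scalar spec and window size are illustrative.

# Average a scalar value over the most recent 10 samples.
averager = WindowAverager(tensor_spec=alf.TensorSpec(()), window_size=10)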
Example #7
def convert_device(nests):
    """Convert the device of the tensors in nests to default device."""
    def _convert_cuda(tensor):
        if tensor.device.type != 'cuda':
            return tensor.cuda()
        else:
            return tensor

    def _convert_cpu(tensor):
        if tensor.device.type != 'cpu':
            return tensor.cpu()
        else:
            return tensor

    d = alf.get_default_device()
    if d == 'cpu':
        return nest.map_structure(_convert_cpu, nests)
    elif d == 'cuda':
        return nest.map_structure(_convert_cuda, nests)
    else:
        raise NotImplementedError("Unknown device %s" % d)
Example #8
    def get_batch(self, batch_size, batch_length):
        """Randomly get ``batch_size`` trajectories from the buffer.

        It may hindsight-relabel the experience via ``postprocess_exp_fn``.

        Note: the environments the samples come from are ordered in the
            returned batch.

        Args:
            batch_size (int): the number of trajectories to get
            batch_length (int): the length of each trajectory
        Returns:
            tuple:
                - nested Tensors: The samples. Its shapes are [batch_size, batch_length, ...]
                - BatchInfo: Information about the batch. Its shapes are [batch_size].
                    - env_ids: environment id for each sequence
                    - positions: starting position in the replay buffer for each sequence.
                    - importance_weights: priority divided by the average of all
                        non-zero priorities in the buffer.
        """
        with alf.device(self._device):
            recent_batch_size = 0
            if self._recent_data_ratio > 0:
                d = batch_length - 1 + self._num_earliest_frames_ignored
                avg_size = self.total_size / float(self._num_envs) - d
                if (avg_size * self._recent_data_ratio >
                        self._recent_data_steps):
                    # If this condition is False, regular sampling without considering
                    # recent data will get enough samples from recent data. So
                    # we don't need to have a separate step just for sampling from
                    # the recent data.
                    recent_batch_size = math.ceil(batch_size *
                                                  self._recent_data_ratio)

            normal_batch_size = batch_size - recent_batch_size
            if self._prioritized_sampling:
                info = self._prioritized_sample(normal_batch_size,
                                                batch_length)
            else:
                info = self._uniform_sample(normal_batch_size, batch_length)

            if recent_batch_size > 0:
                # Note that _uniform_sample() may get samples that duplicate
                # those from _recent_sample().
                recent_info = self._recent_sample(recent_batch_size,
                                                  batch_length)
                info = alf.nest.map_structure(lambda *x: torch.cat(x),
                                              recent_info, info)

            start_pos = info.positions
            env_ids = info.env_ids

            idx = start_pos.reshape(-1, 1)  # [B, 1]
            idx = self.circular(
                idx + torch.arange(batch_length).unsqueeze(0))  # [B, T]
            out_env_ids = env_ids.reshape(-1,
                                          1).expand(batch_size,
                                                    batch_length)  # [B, T]
            result = alf.nest.map_structure(lambda b: b[(out_env_ids, idx)],
                                            self._buffer)

            if alf.summary.should_record_summaries():
                alf.summary.scalar(
                    "replayer/" + self._name + ".original_reward_mean",
                    torch.mean(result.reward[:-1]))

            if self._postprocess_exp_fn:
                result, info = self._postprocess_exp_fn(self, result, info)

        if alf.get_default_device() == self._device:
            return result, info
        else:
            return convert_device(result), convert_device(info)
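A hedged usage sketch of ``get_batch``; ``replay_buffer`` stands for an already-populated buffer instance and the sizes are illustrative.

# Sample 64 trajectories of length 2: tensors in ``experience`` have shape
# [64, 2, ...] and the fields of ``batch_info`` have shape [64].
experience, batch_info = replay_buffer.get_batch(batch_size=64, batch_length=2)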
Example #9
    def __init__(self,
                 observation_spec,
                 action_spec,
                 model_ctor,
                 mcts_algorithm_ctor,
                 num_unroll_steps,
                 td_steps,
                 recurrent_gradient_scaling_factor=0.5,
                 reward_normalizer=None,
                 reward_clip_value=-1.,
                 train_reward_function=True,
                 train_game_over_function=True,
                 reanalyze_ratio=0.,
                 reanalyze_td_steps=5,
                 reanalyze_batch_size=None,
                 data_transformer_ctor=None,
                 target_update_tau=1.,
                 target_update_period=1000,
                 debug_summaries=False,
                 name="MuZero"):
        """
        Args:
            observation_spec (TensorSpec): representing the observations.
            action_spec (BoundedTensorSpec): representing the actions.
            model_ctor (Callable): will be called as
                ``model_ctor(observation_spec=?, action_spec=?, debug_summaries=?)``
                to construct the model. The model should follow the interface
                ``alf.algorithms.mcts_models.MCTSModel``.
            mcts_algorithm_ctor (Callable): will be called as
                ``mcts_algorithm_ctor(observation_spec=?, action_spec=?, debug_summaries=?)``
                to construct an ``MCTSAlgorithm`` instance.
            num_unroll_steps (int): steps for unrolling the model during training.
            td_steps (int): bootstrap this many steps into the future for calculating
                the discounted return. -1 means bootstrapping to the end of the game.
                This can only be used for environments whose rewards are zero except
                at the last step, as the current implementation only uses the reward
                at the last step to calculate the return.
            recurrent_gradient_scaling_factor (float): the gradient going through
                ``model.recurrent_inference`` is scaled by this factor. This
                is suggested in Appendix G.
            reward_normalizer (Normalizer|None): if provided, will be used to
                normalize reward.
            train_reward_function (bool): whether to train the reward function. If
                False, reward should only be given at the last step of an episode.
            train_game_over_function (bool): whether to train the game-over function.
            reanalyze_ratio (float): a number in [0., 1.], the portion of data
                retrieved from the replay buffer to reanalyze. Reanalyzing means
                using the recent model to calculate the value and policy targets.
            reanalyze_td_steps (int): the n for the n-step return used for reanalyzing.
            reanalyze_batch_size (int|None): the memory usage may be too high when
                reanalyzing all the data for one training iteration. If so, provide
                a number so that the data is reanalyzed in several batches.
            data_transformer_ctor (Callable|list[Callable]): should be same as
                ``TrainerConfig.data_transformer_ctor``.
            target_update_tau (float): Factor for soft update of the target
                networks used for reanalyzing.
            target_update_period (int): Period for soft update of the target
                networks used for reanalyzing.
            debug_summaries (bool):
            name (str):
        """
        model = model_ctor(observation_spec,
                           action_spec,
                           debug_summaries=debug_summaries)
        mcts = mcts_algorithm_ctor(observation_spec=observation_spec,
                                   action_spec=action_spec,
                                   debug_summaries=debug_summaries)
        mcts.set_model(model)
        self._device = alf.get_default_device()
        super().__init__(observation_spec=observation_spec,
                         action_spec=action_spec,
                         train_state_spec=mcts.predict_state_spec,
                         predict_state_spec=mcts.predict_state_spec,
                         rollout_state_spec=mcts.predict_state_spec,
                         debug_summaries=debug_summaries,
                         name=name)

        self._mcts = mcts
        self._model = model
        self._num_unroll_steps = num_unroll_steps
        self._td_steps = td_steps
        self._discount = mcts.discount
        self._recurrent_gradient_scaling_factor = recurrent_gradient_scaling_factor
        self._reward_normalizer = reward_normalizer
        self._reward_clip_value = reward_clip_value
        self._train_reward_function = train_reward_function
        self._train_game_over_function = train_game_over_function
        self._reanalyze_ratio = reanalyze_ratio
        self._reanalyze_td_steps = reanalyze_td_steps
        self._reanalyze_batch_size = reanalyze_batch_size
        self._data_transformer = None
        self._data_transformer_ctor = data_transformer_ctor

        self._update_target = None
        if reanalyze_ratio > 0:
            self._target_model = model_ctor(observation_spec,
                                            action_spec,
                                            debug_summaries=debug_summaries)
            self._update_target = common.get_target_updater(
                models=[self._model],
                target_models=[self._target_model],
                tau=target_update_tau,
                period=target_update_period)