Example #1
 def _test_forward(self):
     """Generate a dummy input according to `nested_input_tensor_spec` and
     forward. Can be used to calculate output spec or testing the network.
     """
     inputs = common.zero_tensor_from_nested_spec(self._input_tensor_spec,
                                                  batch_size=1)
     states = common.zero_tensor_from_nested_spec(self.state_spec,
                                                  batch_size=1)
     return self.forward(inputs, states)
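For readers new to the helper, here is a minimal sketch of the behavior these examples rely on: map every spec in a (possibly nested) structure to a zero tensor with a prepended batch dimension. `SimpleSpec` and `zeros_from_nested_spec` are hypothetical stand-ins for illustration, not ALF's actual API.

    from collections import namedtuple

    import torch

    # Hypothetical stand-in for alf.TensorSpec: just a shape and a dtype.
    SimpleSpec = namedtuple('SimpleSpec', ['shape', 'dtype'])

    def zeros_from_nested_spec(nested_spec, batch_size):
        """Replace each spec with torch.zeros((batch_size, ) + shape)."""
        if isinstance(nested_spec, SimpleSpec):
            return torch.zeros((batch_size, ) + nested_spec.shape,
                               dtype=nested_spec.dtype)
        if isinstance(nested_spec, dict):
            return {
                k: zeros_from_nested_spec(v, batch_size)
                for k, v in nested_spec.items()
            }
        # list or tuple of (possibly nested) specs
        return type(nested_spec)(
            zeros_from_nested_spec(s, batch_size) for s in nested_spec)

    specs = dict(img=SimpleSpec((3, 20, 20), torch.float32),
                 label=SimpleSpec((), torch.int64))
    batch = zeros_from_nested_spec(specs, batch_size=1)
    assert batch['img'].shape == (1, 3, 20, 20)
    assert batch['label'].shape == (1, )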
Example #2
 def test_encoding_network_input_preprocessor(self):
     input_spec = TensorSpec((1, ))
     inputs = common.zero_tensor_from_nested_spec(input_spec, batch_size=1)
     network = EncodingNetwork(input_tensor_spec=input_spec,
                               input_preprocessors=torch.tanh)
     output, _ = network(inputs)
     self.assertEqual(output.size()[1], 1)
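Since `torch.tanh` is elementwise and shape-preserving, the zero input of shape (1, 1) passes through the preprocessor unchanged, which is what the size assertion verifies:

    import torch

    assert torch.tanh(torch.zeros(1, 1)).shape == (1, 1)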
Example #3
    def _calc_change(self, exp_array):
        """Calculate the distance between old/new action distributions.

        The distance is:
        ||logits_1 - logits_2||^2 for Categorical distributions
        ||loc_1 - loc_2||^2 for Deterministic distributions
        KL(d1||d2) + KL(d2||d1) for others
        """

        def _dist(d1, d2):
            if isinstance(d1, tfp.distributions.Categorical):
                return tf.reduce_sum(tf.square(d1.logits - d2.logits), axis=-1)
            elif isinstance(d1, tfp.distributions.Deterministic):
                return tf.reduce_sum(tf.square(d1.loc - d2.loc), axis=-1)
            else:
                if isinstance(d1, SquashToSpecNormal):
                    # TODO: `SquashToSpecNormal.kl_divergence` checks that two
                    # distributions have the same action mean and magnitude,
                    # but this check fails in graph mode
                    d1 = d1.input_distribution
                    d2 = d2.input_distribution
                return tf.reduce_sum(
                    d1.kl_divergence(d2) + d2.kl_divergence(d1), axis=-1)

        def _update_total_dists(new_action, exp, total_dists):
            old_action = nested_distributions_from_specs(
                common.to_distribution_spec(self.action_distribution_spec),
                exp.action_param)
            dists = nest_map(_dist, old_action, new_action)
            valid_masks = tf.cast(
                tf.not_equal(exp.step_type, StepType.LAST), tf.float32)
            dists = nest_map(lambda kl: tf.reduce_sum(kl * valid_masks), dists)
            return nest_map(lambda x, y: x + y, total_dists, dists)

        num_steps = exp_array.step_type.size()
        # element_shape for `TensorArray` can be (None, ...)
        batch_size = tf.shape(exp_array.step_type.read(0))[0]
        state = tf.nest.map_structure(lambda x: x.read(0), exp_array.state)
        # exp_array.state is no longer needed
        exp_array = exp_array._replace(state=())
        initial_state = common.zero_tensor_from_nested_spec(
            self.predict_state_spec, batch_size)
        total_dists = nest_map(lambda _: tf.zeros(()), self.action_spec)
        for t in tf.range(num_steps):
            exp = tf.nest.map_structure(lambda x: x.read(t), exp_array)
            state = common.reset_state_if_necessary(
                state, initial_state, exp.step_type == StepType.FIRST)
            time_step = ActionTimeStep(
                observation=exp.observation, step_type=exp.step_type)
            policy_step = self._ac_algorithm.predict(
                time_step=time_step, state=state)
            new_action, state = policy_step.action, policy_step.state
            new_action = common.to_distribution(new_action)
            total_dists = _update_total_dists(new_action, exp, total_dists)

        size = tf.cast(num_steps * batch_size, tf.float32)
        total_dists = nest_map(lambda d: d / size, total_dists)
        return total_dists
Example #4
 def setUp(self):
     self._input_spec = [
         TensorSpec((3, 20, 20), torch.float32),
         TensorSpec((1, 20, 20), torch.float32)
     ]
     self._image = zero_tensor_from_nested_spec(self._input_spec,
                                                batch_size=1)
     self._conv_layer_params = ((8, 3, 1), (16, 3, 2, 1))
     self._fc_layer_params = (100, )
     self._input_preprocessors = [torch.tanh, None]
     self._preprocessing_combiner = NestConcat(dim=1)
Example #5
    def test_encoding_network_preprocessing_combiner(self):
        input_spec = dict(a=TensorSpec((3, 80, 80)),
                          b=[TensorSpec((80, 80)),
                             TensorSpec(())])
        imgs = common.zero_tensor_from_nested_spec(input_spec, batch_size=1)
        network = EncodingNetwork(input_tensor_spec=input_spec,
                                  preprocessing_combiner=NestSum(average=True),
                                  conv_layer_params=((1, 2, 2, 0), ))

        self.assertEqual(network._processed_input_tensor_spec,
                         TensorSpec((3, 80, 80)))

        output, _ = network(imgs)
        self.assertTensorEqual(output, torch.zeros((40 * 40, )))
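The expected flat size of 40 * 40 follows from standard convolution arithmetic, assuming the conv_layer_params tuple above is (channels, kernel_size, stride, padding):

    # Conv output size: floor((W + 2 * padding - kernel) / stride) + 1.
    # With W=80 and the (1, 2, 2, 0) params above:
    W, kernel, stride, padding = 80, 2, 2, 0
    out = (W + 2 * padding - kernel) // stride + 1
    assert out == 40  # one channel, flattened to 40 * 40 elements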
Example #6
    def test_q_value_network(self, lstm_hidden_size):
        input_spec = [TensorSpec((3, 20, 20), torch.float32)]
        conv_layer_params = ((8, 3, 1), (16, 3, 2, 1))
        image = common.zero_tensor_from_nested_spec(input_spec, batch_size=1)

        network_ctor, state = self._init(lstm_hidden_size)

        q_net = network_ctor(input_spec,
                             self._action_spec,
                             input_preprocessors=[torch.relu],
                             preprocessing_combiner=NestSum(),
                             conv_layer_params=conv_layer_params)
        q_value, state = q_net(image, state)

        # (batch_size, num_actions)
        self.assertEqual(q_value.shape, (1, self._num_actions))
Example #7
    def test_encoding_network_nested_input(self, lstm):
        input_spec = dict(a=TensorSpec((3, 80, 80)),
                          b=[
                              TensorSpec((80, )),
                              BoundedTensorSpec((), dtype="int64"),
                              dict(x=TensorSpec((100, )),
                                   y=TensorSpec((200, )))
                          ])
        imgs = common.zero_tensor_from_nested_spec(input_spec, batch_size=1)
        input_preprocessors = dict(
            a=EmbeddingPreprocessor(input_spec["a"],
                                    conv_layer_params=((1, 2, 2, 0), ),
                                    embedding_dim=100),
            b=[
                EmbeddingPreprocessor(input_spec["b"][0], embedding_dim=50),
                EmbeddingPreprocessor(input_spec["b"][1], embedding_dim=50),
                dict(x=None, y=torch.relu)
            ])

        if lstm:
            network_ctor = functools.partial(LSTMEncodingNetwork,
                                             hidden_size=(100, ))
        else:
            network_ctor = EncodingNetwork

        network = network_ctor(input_tensor_spec=input_spec,
                               input_preprocessors=input_preprocessors,
                               preprocessing_combiner=NestConcat())
        output, _ = network(imgs, state=[(torch.zeros((1, 100)), ) * 2])

        if lstm:
            self.assertEqual(network.output_spec, TensorSpec((100, )))
            self.assertEqual(output.size()[-1], 100)
        else:
            self.assertEqual(len(list(network.parameters())), 4 + 2 + 1)
            self.assertEqual(network.output_spec, TensorSpec((500, )))
            self.assertEqual(output.size()[-1], 500)
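In case the non-LSTM expectation of 500 looks opaque: it is just the concatenated width of the preprocessed leaves, reading the preprocessors above (a embedded to 100, b[0] and b[1] embedded to 50 each, x passed through at 100, and relu applied elementwise to y at 200):

    assert 100 + 50 + 50 + 100 + 200 == 500  # NestConcat output width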
Example #8
    def test_frame_stacker(self, stack_axis=0):
        data_spec = DataItem(step_type=alf.TensorSpec((), dtype=torch.int32),
                             observation=dict(scalar=alf.TensorSpec(()),
                                              vector=alf.TensorSpec((7, )),
                                              matrix=alf.TensorSpec((5, 6)),
                                              tensor=alf.TensorSpec(
                                                  (2, 3, 4))))
        replay_buffer = ReplayBuffer(data_spec=data_spec,
                                     num_environments=2,
                                     max_length=1024,
                                     num_earliest_frames_ignored=2)
        frame_stacker = FrameStacker(
            data_spec.observation,
            stack_size=3,
            stack_axis=stack_axis,
            fields=['scalar', 'vector', 'matrix', 'tensor'])

        new_spec = frame_stacker.transformed_observation_spec
        self.assertEqual(new_spec['scalar'].shape, (3, ))
        self.assertEqual(new_spec['vector'].shape, (21, ))
        if stack_axis == -1:
            self.assertEqual(new_spec['matrix'].shape, (5, 18))
            self.assertEqual(new_spec['tensor'].shape, (2, 3, 12))
        elif stack_axis == 0:
            self.assertEqual(new_spec['matrix'].shape, (15, 6))
            self.assertEqual(new_spec['tensor'].shape, (6, 3, 4))

        def _step_type(t, period):
            if t % period == 0:
                return StepType.FIRST
            if t % period == period - 1:
                return StepType.LAST
            return StepType.MID

        observation = alf.nest.map_structure(
            lambda spec: spec.randn((1000, 2)), data_spec.observation)
        state = common.zero_tensor_from_nested_spec(frame_stacker.state_spec,
                                                    2)

        def _get_stacked_data(t, b):
            if stack_axis == -1:
                return dict(scalar=observation['scalar'][t, b],
                            vector=observation['vector'][t, b].reshape(-1),
                            matrix=observation['matrix'][t, b].transpose(
                                0, 1).reshape(5, 18),
                            tensor=observation['tensor'][t, b].permute(
                                1, 2, 0, 3).reshape(2, 3, 12))
            elif stack_axis == 0:
                return dict(scalar=observation['scalar'][t, b],
                            vector=observation['vector'][t, b].reshape(-1),
                            matrix=observation['matrix'][t, b].reshape(15, 6),
                            tensor=observation['tensor'][t,
                                                         b].reshape(6, 3, 4))

        def _check_equal(stacked, expected, b):
            self.assertEqual(stacked['scalar'][b], expected['scalar'])
            self.assertEqual(stacked['vector'][b], expected['vector'])
            self.assertEqual(stacked['matrix'][b], expected['matrix'])
            self.assertEqual(stacked['tensor'][b], expected['tensor'])

        for t in range(1000):
            batch = DataItem(
                step_type=torch.tensor([_step_type(t, 17),
                                        _step_type(t, 22)]),
                observation=alf.nest.map_structure(lambda x: x[t],
                                                   observation))
            replay_buffer.add_batch(batch)
            timestep, state = frame_stacker.transform_timestep(batch, state)
            if t == 0:
                for b in (0, 1):
                    expected = _get_stacked_data([0, 0, 0], b)
                    _check_equal(timestep.observation, expected, b)
            if t == 1:
                for b in (0, 1):
                    expected = _get_stacked_data([0, 0, 1], b)
                    _check_equal(timestep.observation, expected, b)
            if t == 2:
                for b in (0, 1):
                    expected = _get_stacked_data([0, 1, 2], b)
                    _check_equal(timestep.observation, expected, b)
            if t == 16:
                for b in (0, 1):
                    expected = _get_stacked_data([14, 15, 16], b)
                    _check_equal(timestep.observation, expected, b)
            if t == 17:
                # use a separate name so the outer loop variable `t` is
                # not shadowed
                for b, ts in ((0, [17, 17, 17]), (1, [15, 16, 17])):
                    expected = _get_stacked_data(ts, b)
                    _check_equal(timestep.observation, expected, b)
            if t == 18:
                for b, ts in ((0, [17, 17, 18]), (1, [16, 17, 18])):
                    expected = _get_stacked_data(ts, b)
                    _check_equal(timestep.observation, expected, b)
            if t == 22:
                for b, ts in ((0, [20, 21, 22]), (1, [22, 22, 22])):
                    expected = _get_stacked_data(ts, b)
                    _check_equal(timestep.observation, expected, b)

        batch_info = BatchInfo(env_ids=torch.tensor([0, 1, 0, 1],
                                                    dtype=torch.int64),
                               positions=torch.tensor([0, 1, 18, 22],
                                                      dtype=torch.int64))

        # [4, 2, ...]
        experience = replay_buffer.get_field(
            '', batch_info.env_ids.unsqueeze(-1),
            batch_info.positions.unsqueeze(-1) + torch.arange(2))
        experience = experience._replace(batch_info=batch_info,
                                         replay_buffer=replay_buffer)
        experience = frame_stacker.transform_experience(experience)
        expected = _get_stacked_data([0, 0, 0], 0)
        _check_equal(experience.observation, expected, (0, 0))
        expected = _get_stacked_data([0, 0, 1], 0)
        _check_equal(experience.observation, expected, (0, 1))

        expected = _get_stacked_data([0, 0, 1], 1)
        _check_equal(experience.observation, expected, (1, 0))
        expected = _get_stacked_data([0, 1, 2], 1)
        _check_equal(experience.observation, expected, (1, 1))

        expected = _get_stacked_data([17, 17, 18], 0)
        _check_equal(experience.observation, expected, (2, 0))
        expected = _get_stacked_data([17, 18, 19], 0)
        _check_equal(experience.observation, expected, (2, 1))

        expected = _get_stacked_data([22, 22, 22], 1)
        _check_equal(experience.observation, expected, (3, 0))
        expected = _get_stacked_data([22, 22, 23], 1)
        _check_equal(experience.observation, expected, (3, 1))
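A minimal illustration of the stacking arithmetic the test above checks, using plain torch.cat rather than FrameStacker itself: stacking 3 frames of shape (5, 6) along axis 0 yields (15, 6), while axis -1 yields (5, 18).

    import torch

    frames = [torch.zeros(5, 6) for _ in range(3)]  # stack_size=3
    assert torch.cat(frames, dim=0).shape == (15, 6)   # stack_axis=0
    assert torch.cat(frames, dim=-1).shape == (5, 18)  # stack_axis=-1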
Example #9
    def _prepare_specs(self, algorithm):
        """Prepare various tensor specs."""
        def extract_spec(nest):
            return tf.nest.map_structure(
                lambda t: tf.TensorSpec(t.shape[1:], t.dtype), nest)

        time_step = self.get_initial_time_step()
        self._time_step_spec = extract_spec(time_step)
        self._action_spec = self._env.action_spec()

        policy_step = algorithm.predict(time_step, self._initial_state)
        info_spec = extract_spec(policy_step.info)
        self._pred_policy_step_spec = PolicyStep(
            action=self._action_spec,
            state=algorithm.predict_state_spec,
            info=info_spec)

        def _to_distribution_spec(spec):
            if isinstance(spec, tf.TensorSpec):
                return DistributionSpec(tfp.distributions.Deterministic,
                                        input_params_spec={"loc": spec},
                                        sample_spec=spec)
            return spec

        self._action_distribution_spec = tf.nest.map_structure(
            _to_distribution_spec, algorithm.action_distribution_spec)
        self._action_dist_param_spec = tf.nest.map_structure(
            lambda spec: spec.input_params_spec,
            self._action_distribution_spec)

        self._experience_spec = Experience(
            step_type=self._time_step_spec.step_type,
            reward=self._time_step_spec.reward,
            discount=self._time_step_spec.discount,
            observation=self._time_step_spec.observation,
            prev_action=self._action_spec,
            action=self._action_spec,
            info=info_spec,
            action_distribution=self._action_dist_param_spec)

        action_dist_params = common.zero_tensor_from_nested_spec(
            self._experience_spec.action_distribution, self._env.batch_size)
        action_dist = nested_distributions_from_specs(
            self._action_distribution_spec, action_dist_params)
        exp = Experience(step_type=time_step.step_type,
                         reward=time_step.reward,
                         discount=time_step.discount,
                         observation=time_step.observation,
                         prev_action=time_step.prev_action,
                         action=time_step.prev_action,
                         info=policy_step.info,
                         action_distribution=action_dist)

        processed_exp = algorithm.preprocess_experience(exp)
        self._processed_experience_spec = self._experience_spec._replace(
            info=extract_spec(processed_exp.info))

        policy_step = common.algorithm_step(
            algorithm,
            ob_transformer=self._observation_transformer,
            time_step=exp,
            state=common.get_initial_policy_state(self._env.batch_size,
                                                  algorithm.train_state_spec),
            training=True)
        info_spec = extract_spec(policy_step.info)
        self._training_info_spec = make_training_info(
            action=self._action_spec,
            action_distribution=self._action_dist_param_spec,
            step_type=self._time_step_spec.step_type,
            reward=self._time_step_spec.reward,
            discount=self._time_step_spec.discount,
            info=info_spec,
            collect_info=self._processed_experience_spec.info,
            collect_action_distribution=self._action_dist_param_spec)
Example #10
 def get_initial_train_state(self, batch_size):
     """Always return the training state spec."""
     return common.zero_tensor_from_nested_spec(
         self._algorithm.train_state_spec, batch_size)
Example #11
    def _test_preprocess_experience(self, train_reward_function, td_steps,
                                    reanalyze_ratio, expected):
        """
        The following summarizes how the data is generated:

        .. code-block:: python

            # position:   01234567890123
            step_type0 = 'FMMMLFMMLFMMMM'
            step_type1 = 'FMMMMMLFMMMMLF'
            scale = 1. for current model
                    2. for target model
            observation = [position] * 3
            reward = position if train_reward_function and td_steps!=-1
                     else position * (step_type == LAST)
            value = 0.5 * position * scale
            action_probs = scale * [position, position+1, position] for env 0
                           scale * [position+1, position, position] for env 1
            action = 1 for env 0
                     0 for env 1

        """
        reanalyze_td_steps = 2

        num_unroll_steps = 4
        batch_size = 2
        obs_dim = 3

        observation_spec = alf.TensorSpec([obs_dim])
        action_spec = alf.BoundedTensorSpec((),
                                            minimum=0,
                                            maximum=1,
                                            dtype=torch.int32)
        reward_spec = alf.TensorSpec(())
        time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                           reward_spec)

        global _mcts_model_id
        _mcts_model_id = 0
        muzero = MuzeroAlgorithm(observation_spec,
                                 action_spec,
                                 model_ctor=_create_mcts_model,
                                 mcts_algorithm_ctor=MockMCTSAlgorithm,
                                 num_unroll_steps=num_unroll_steps,
                                 td_steps=td_steps,
                                 train_game_over_function=True,
                                 train_reward_function=train_reward_function,
                                 reanalyze_ratio=reanalyze_ratio,
                                 reanalyze_td_steps=reanalyze_td_steps,
                                 data_transformer_ctor=partial(FrameStacker,
                                                               stack_size=2))

        data_transformer = FrameStacker(observation_spec, stack_size=2)
        time_step = common.zero_tensor_from_nested_spec(
            time_step_spec, batch_size)
        dt_state = common.zero_tensor_from_nested_spec(
            data_transformer.state_spec, batch_size)
        state = muzero.get_initial_predict_state(batch_size)
        transformed_time_step, dt_state = data_transformer.transform_timestep(
            time_step, dt_state)
        alg_step = muzero.rollout_step(transformed_time_step, state)
        alg_step_spec = dist_utils.extract_spec(alg_step)

        experience = ds.make_experience(time_step, alg_step, state)
        experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                             muzero.train_state_spec)
        replay_buffer = ReplayBuffer(data_spec=experience_spec,
                                     num_environments=batch_size,
                                     max_length=16,
                                     keep_episodic_info=True)

        #             01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'

        dt_state = common.zero_tensor_from_nested_spec(
            data_transformer.state_spec, batch_size)
        for i in range(len(step_type0)):
            step_type = [step_type0[i], step_type1[i]]
            step_type = [
                ds.StepType.MID if c == 'M' else
                (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
                for c in step_type
            ]
            step_type = torch.tensor(step_type, dtype=torch.int32)
            reward = torch.full([batch_size], float(i))
            if not train_reward_function or td_steps == -1:
                reward = reward * (step_type == ds.StepType.LAST).to(
                    torch.float32)
            time_step = time_step._replace(
                discount=(step_type != ds.StepType.LAST).to(torch.float32),
                step_type=step_type,
                observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                         dtype=torch.float32),
                reward=reward,
                env_id=torch.arange(batch_size, dtype=torch.int32))
            transformed_time_step, dt_state = data_transformer.transform_timestep(
                time_step, dt_state)
            alg_step = muzero.rollout_step(transformed_time_step, state)
            experience = ds.make_experience(time_step, alg_step, state)
            replay_buffer.add_batch(experience)
            state = alg_step.state

        env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
        positions = torch.tensor(list(range(14)) + list(range(14)),
                                 dtype=torch.int64)
        experience = replay_buffer.get_field(None,
                                             env_ids.unsqueeze(-1).cpu(),
                                             positions.unsqueeze(-1).cpu())
        experience = experience._replace(replay_buffer=replay_buffer,
                                         batch_info=BatchInfo(
                                             env_ids=env_ids,
                                             positions=positions),
                                         rollout_info_field='rollout_info')
        processed_experience = muzero.preprocess_experience(experience)
        import pprint
        pprint.pprint(processed_experience.rollout_info)
        alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                               processed_experience.rollout_info, expected)
Example #12
    def prepare_off_policy_specs(self, time_step: ActionTimeStep):
        """Prepare various tensor specs for off_policy training.

        prepare_off_policy_specs is called by OffPolicyDriver._prepare_spec().

        """

        self._env_batch_size = time_step.step_type.shape[0]
        self._time_step_spec = common.extract_spec(time_step)
        initial_state = common.get_initial_policy_state(
            self._env_batch_size, self.train_state_spec)
        transformed_timestep = self.transform_timestep(time_step)
        policy_step = self.rollout(transformed_timestep, initial_state)
        info_spec = common.extract_spec(policy_step.info)

        self._action_distribution_spec = tf.nest.map_structure(
            common.to_distribution_spec, self.action_distribution_spec)
        self._action_dist_param_spec = tf.nest.map_structure(
            lambda spec: spec.input_params_spec,
            self._action_distribution_spec)

        self._experience_spec = Experience(
            step_type=self._time_step_spec.step_type,
            reward=self._time_step_spec.reward,
            discount=self._time_step_spec.discount,
            observation=self._time_step_spec.observation,
            prev_action=self._action_spec,
            action=self._action_spec,
            info=info_spec,
            action_distribution=self._action_dist_param_spec,
            state=self.train_state_spec if self._use_rollout_state else ())

        action_dist_params = common.zero_tensor_from_nested_spec(
            self._experience_spec.action_distribution, self._env_batch_size)
        action_dist = nested_distributions_from_specs(
            self._action_distribution_spec, action_dist_params)

        exp = Experience(step_type=time_step.step_type,
                         reward=time_step.reward,
                         discount=time_step.discount,
                         observation=time_step.observation,
                         prev_action=time_step.prev_action,
                         action=time_step.prev_action,
                         info=policy_step.info,
                         action_distribution=action_dist,
                         state=initial_state if self._use_rollout_state else
                         ())

        transformed_exp = self.transform_timestep(exp)
        processed_exp = self.preprocess_experience(transformed_exp)
        self._processed_experience_spec = self._experience_spec._replace(
            observation=common.extract_spec(processed_exp.observation),
            info=common.extract_spec(processed_exp.info))

        policy_step = common.algorithm_step(
            algorithm_step_func=self.train_step,
            time_step=processed_exp,
            state=initial_state)
        info_spec = common.extract_spec(policy_step.info)
        self._training_info_spec = TrainingInfo(
            action_distribution=self._action_dist_param_spec, info=info_spec)
Example #13
    def _calc_change(self, exp_array):
        """Calculate the distance between old/new action distributions.

        The squared distance is:
        ||logits_1 - logits_2||^2 for Categorical distribution
        ||loc_1 - loc_2||^2 for Deterministic distribution
        KL(d1||d2) + KL(d2||d1) for others
        """
        def _dist(d1, d2):
            if isinstance(d1, tfp.distributions.Categorical):
                dist = tf.square(d1.logits - d2.logits)
            elif isinstance(d1, tf.Tensor):
                dist = tf.square(d1 - d2)
            else:
                if isinstance(d1, SquashToSpecNormal):
                    # TODO: `SquashToSpecNormal.kl_divergence` checks that two
                    # distributions have the same action mean and magnitude,
                    # but this check fails in graph mode
                    d1 = d1.input_distribution
                    d2 = d2.input_distribution
                dist = d1.kl_divergence(d2) + d2.kl_divergence(d1)

            if len(dist.shape) > 1:
                # reduce to shape [B]
                reduce_dims = list(range(1, len(dist.shape)))
                dist = tf.reduce_sum(dist, axis=reduce_dims)
            return dist

        def _update_total_dists(new_action, exp, total_dists):
            old_action = nest_utils.params_to_distributions(
                exp.action_param, self._action_distribution_spec)
            dists = nest_map(_dist, old_action, new_action)
            valid_masks = tf.cast(tf.not_equal(exp.step_type, StepType.LAST),
                                  tf.float32)
            dists = nest_map(lambda kl: tf.reduce_sum(kl * valid_masks), dists)
            return nest_map(lambda x, y: x + y, total_dists, dists)

        num_steps = exp_array.step_type.size()
        # element_shape for `TensorArray` can be (None, ...)
        batch_size = tf.shape(exp_array.step_type.read(0))[0]
        state = tf.nest.map_structure(lambda x: x.read(0), exp_array.state)
        # exp_array.state is no longer needed
        exp_array = exp_array._replace(state=())
        initial_state = common.zero_tensor_from_nested_spec(
            self.predict_state_spec, batch_size)
        total_dists = nest_map(lambda _: tf.zeros(()), self.action_spec)
        for t in tf.range(num_steps):
            exp = tf.nest.map_structure(lambda x: x.read(t), exp_array)
            state = common.reset_state_if_necessary(
                state, initial_state, exp.step_type == StepType.FIRST)
            time_step = ActionTimeStep(observation=exp.observation,
                                       step_type=exp.step_type)
            policy_step = self._ac_algorithm.predict(time_step=time_step,
                                                     state=state,
                                                     epsilon_greedy=1.0)
            assert (
                common.is_namedtuple(policy_step.info)
                and "action_distribution" in policy_step.info._fields
            ), ("PolicyStep.info from ac_algorithm.predict() should be "
                "a namedtuple containing `action_distribution` in order to "
                "use TracAlgorithm.")
            new_action = policy_step.info.action_distribution
            state = policy_step.state
            total_dists = _update_total_dists(new_action, exp, total_dists)

        size = tf.cast(num_steps * batch_size, tf.float32)
        total_dists = nest_map(lambda d: tf.sqrt(d / size), total_dists)
        return total_dists
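The same distances are straightforward to reproduce with torch.distributions; a sketch (not the TF code above) under the docstring's definitions:

    import torch
    import torch.distributions as td

    # Categorical: squared distance between logits, reduced over the last axis.
    d1 = td.Categorical(logits=torch.tensor([[0.1, 0.9], [0.5, 0.5]]))
    d2 = td.Categorical(logits=torch.tensor([[0.2, 0.8], [0.4, 0.6]]))
    cat_dist = torch.sum((d1.logits - d2.logits)**2, dim=-1)  # shape [B]

    # Others: symmetric KL, i.e. KL(d1||d2) + KL(d2||d1).
    n1 = td.Normal(torch.zeros(2), torch.ones(2))
    n2 = td.Normal(torch.ones(2), torch.ones(2))
    sym_kl = td.kl_divergence(n1, n2) + td.kl_divergence(n2, n1)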
Example #14
    def __init__(self, num_expansions, model_output, known_value_bounds):
        batch_size, branch_factor = model_output.action_probs.shape
        action_spec = dist_utils.extract_spec(model_output.actions, from_dim=2)
        state_spec = dist_utils.extract_spec(model_output.state, from_dim=1)
        if known_value_bounds:
            self.fixed_bounds = True
            self.minimum, self.maximum = known_value_bounds
        else:
            self.fixed_bounds = False
            self.minimum, self.maximum = MAXIMUM_FLOAT_VALUE, -MAXIMUM_FLOAT_VALUE
        self.minimum = torch.full((batch_size, ),
                                  self.minimum,
                                  dtype=torch.float32)
        self.maximum = torch.full((batch_size, ),
                                  self.maximum,
                                  dtype=torch.float32)
        if known_value_bounds:
            self.normalize_scale = 1 / (self.maximum - self.minimum + 1e-30)
            self.normalize_base = self.minimum
        else:
            self.normalize_scale = torch.ones((batch_size, ))
            self.normalize_base = torch.zeros((batch_size, ))

        self.B = torch.arange(batch_size)
        self.root_indices = torch.zeros((batch_size, ), dtype=torch.int64)
        self.branch_factor = branch_factor

        parent_shape = (batch_size, num_expansions)
        children_shape = (batch_size, num_expansions, branch_factor)

        self.visit_count = torch.zeros(parent_shape, dtype=torch.int32)

        # the player who will take action from the current state
        self.to_play = torch.zeros(parent_shape, dtype=torch.int64)
        self.prior = torch.zeros(children_shape)

        # value for player 0
        self.value_sum = torch.zeros(parent_shape)

        # 0 for not expanded, value in range [0, num_expansions)
        self.children_index = torch.zeros(children_shape, dtype=torch.int64)
        self.model_state = common.zero_tensor_from_nested_spec(
            state_spec, parent_shape)

        # reward for player 0
        self.reward = None
        if isinstance(model_output.reward, torch.Tensor):
            self.reward = torch.zeros(parent_shape)

        self.action = None
        if isinstance(model_output.actions, torch.Tensor):
            # candidate actions for this state
            self.action = torch.zeros(
                children_shape + action_spec.shape, dtype=action_spec.dtype)

        self.game_over = None
        if isinstance(model_output.game_over, torch.Tensor):
            self.game_over = torch.zeros(parent_shape, dtype=torch.bool)

        # value in range [0, branch_factor)
        self.best_child_index = torch.zeros(parent_shape, dtype=torch.int64)
        self.ucb_score = torch.zeros(children_shape)
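If I read normalize_scale and normalize_base correctly, they implement min-max normalization of node values into [0, 1]; a sketch of that mapping under this assumption (how the bounds are updated when known_value_bounds is None is not shown here):

    import torch

    minimum = torch.tensor([0.0, -1.0])
    maximum = torch.tensor([10.0, 1.0])
    scale = 1 / (maximum - minimum + 1e-30)
    base = minimum

    value = torch.tensor([5.0, 0.0])
    normalized = (value - base) * scale  # -> tensor([0.5, 0.5])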
Example #15
    def test_preprocess_experience(self):
        """
        The following summarizes how the data is generated:

        .. code-block:: python

            # position:   01234567890123
            step_type0 = 'FMMMLFMMLFMMMM'
            step_type1 = 'FMMMMMLFMMMMLF'
            reward = position if train_reward_function and td_steps!=-1
                     else position * (step_type == LAST)
            action = t + 1 for env 0
                     t for env 1

        """
        num_unroll_steps = 4
        batch_size = 2
        obs_dim = 3
        observation_spec = alf.TensorSpec([obs_dim])
        action_spec = alf.BoundedTensorSpec((1, ),
                                            minimum=0,
                                            maximum=1,
                                            dtype=torch.float32)
        reward_spec = alf.TensorSpec(())
        time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                           reward_spec)

        repr_learner = PredictiveRepresentationLearner(
            observation_spec,
            action_spec,
            num_unroll_steps=num_unroll_steps,
            decoder_ctor=partial(SimpleDecoder,
                                 target_field='reward',
                                 decoder_net_ctor=partial(
                                     EncodingNetwork, fc_layer_params=(4, ))),
            encoding_net_ctor=LSTMEncodingNetwork,
            dynamics_net_ctor=LSTMEncodingNetwork)

        time_step = common.zero_tensor_from_nested_spec(
            time_step_spec, batch_size)
        state = repr_learner.get_initial_predict_state(batch_size)
        alg_step = repr_learner.rollout_step(time_step, state)
        alg_step = alg_step._replace(output=torch.tensor([[1.], [0.]]))
        alg_step_spec = dist_utils.extract_spec(alg_step)

        experience = ds.make_experience(time_step, alg_step, state)
        experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                             repr_learner.train_state_spec)
        replay_buffer = ReplayBuffer(data_spec=experience_spec,
                                     num_environments=batch_size,
                                     max_length=16,
                                     keep_episodic_info=True)

        #             01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'

        for i in range(len(step_type0)):
            step_type = [step_type0[i], step_type1[i]]
            step_type = [
                ds.StepType.MID if c == 'M' else
                (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
                for c in step_type
            ]
            step_type = torch.tensor(step_type, dtype=torch.int32)
            reward = torch.full([batch_size], float(i))
            time_step = time_step._replace(
                discount=(step_type != ds.StepType.LAST).to(torch.float32),
                step_type=step_type,
                observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                         dtype=torch.float32),
                reward=reward,
                env_id=torch.arange(batch_size, dtype=torch.int32))
            alg_step = repr_learner.rollout_step(time_step, state)
            alg_step = alg_step._replace(output=i + torch.tensor([[1.], [0.]]))
            experience = ds.make_experience(time_step, alg_step, state)
            replay_buffer.add_batch(experience)
            state = alg_step.state

        env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
        positions = torch.tensor(list(range(14)) + list(range(14)),
                                 dtype=torch.int64)
        experience = replay_buffer.get_field(None,
                                             env_ids.unsqueeze(-1).cpu(),
                                             positions.unsqueeze(-1).cpu())
        experience = experience._replace(replay_buffer=replay_buffer,
                                         batch_info=BatchInfo(
                                             env_ids=env_ids,
                                             positions=positions),
                                         rollout_info_field='rollout_info')
        processed_experience = repr_learner.preprocess_experience(experience)
        pprint.pprint(processed_experience.rollout_info)

        # yapf: disable
        expected = PredictiveRepresentationLearnerInfo(
            action=torch.tensor(
               [[[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  5.]],
                [[ 3.,  4.,  5.,  5.,  5.]],
                [[ 4.,  5.,  5.,  5.,  5.]],
                [[ 5.,  5.,  5.,  5.,  5.]],
                [[ 6.,  7.,  8.,  9.,  9.]],
                [[ 7.,  8.,  9.,  9.,  9.]],
                [[ 8.,  9.,  9.,  9.,  9.]],
                [[ 9.,  9.,  9.,  9.,  9.]],
                [[10., 11., 12., 13., 14.]],
                [[11., 12., 13., 14., 14.]],
                [[12., 13., 14., 14., 14.]],
                [[13., 14., 14., 14., 14.]],
                [[14., 14., 14., 14., 14.]],
                [[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  6.]],
                [[ 3.,  4.,  5.,  6.,  6.]],
                [[ 4.,  5.,  6.,  6.,  6.]],
                [[ 5.,  6.,  6.,  6.,  6.]],
                [[ 6.,  6.,  6.,  6.,  6.]],
                [[ 7.,  8.,  9., 10., 11.]],
                [[ 8.,  9., 10., 11., 12.]],
                [[ 9., 10., 11., 12., 12.]],
                [[10., 11., 12., 12., 12.]],
                [[11., 12., 12., 12., 12.]],
                [[12., 12., 12., 12., 12.]],
                [[13., 13., 13., 13., 13.]]]).unsqueeze(-1),
            mask=torch.tensor(
               [[[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True, False, False, False, False]]]),
            target=torch.tensor(
               [[[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  4.]],
                [[ 2.,  3.,  4.,  4.,  4.]],
                [[ 3.,  4.,  4.,  4.,  4.]],
                [[ 4.,  4.,  4.,  4.,  4.]],
                [[ 5.,  6.,  7.,  8.,  8.]],
                [[ 6.,  7.,  8.,  8.,  8.]],
                [[ 7.,  8.,  8.,  8.,  8.]],
                [[ 8.,  8.,  8.,  8.,  8.]],
                [[ 9., 10., 11., 12., 13.]],
                [[10., 11., 12., 13., 13.]],
                [[11., 12., 13., 13., 13.]],
                [[12., 13., 13., 13., 13.]],
                [[13., 13., 13., 13., 13.]],
                [[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  6.]],
                [[ 3.,  4.,  5.,  6.,  6.]],
                [[ 4.,  5.,  6.,  6.,  6.]],
                [[ 5.,  6.,  6.,  6.,  6.]],
                [[ 6.,  6.,  6.,  6.,  6.]],
                [[ 7.,  8.,  9., 10., 11.]],
                [[ 8.,  9., 10., 11., 12.]],
                [[ 9., 10., 11., 12., 12.]],
                [[10., 11., 12., 12., 12.]],
                [[11., 12., 12., 12., 12.]],
                [[12., 12., 12., 12., 12.]],
                [[13., 13., 13., 13., 13.]]]))
        # yapf: enable

        alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                               processed_experience.rollout_info, expected)