def step(self, action):
        # Action is the suggested slate (indices into the currently
        # suggested docs).

        scores = [
            np.dot(self.current_user, doc)
            for doc in self.currently_suggested_docs
        ]
        best_reward = np.max(scores)

        # User choice model: User picks a doc stochastically,
        # where probs are dot products between user- and doc feature
        # (categories) vectors (rewards).
        # There is also a no-click doc whose weight is 0.0.
        user_doc_overlaps = np.array([scores[a] for a in action] + [0.0])
        which_clicked = np.random.choice(np.arange(self.slate_size + 1),
                                         p=softmax(user_doc_overlaps))

        reward = 0.0
        if which_clicked < self.slate_size:
            # Reward is 1.0 - regret if clicked. 0.0 if not clicked.
            regret = best_reward - user_doc_overlaps[which_clicked]
            reward = 1 - regret
            # If anything clicked, deduct from the current user's time budget.
            self.current_user_budget -= 1.0
        done = self.current_user_budget <= 0.0

        # Compile response.
        response = tuple({
            "click": int(idx == which_clicked),
            "engagement": reward if idx == which_clicked else 0.0,
        } for idx in range(len(user_doc_overlaps) - 1))

        return self._get_obs(response=response), reward, done, {}
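
For reference, here is a minimal standalone sketch of the user-choice model used in step() above, assuming plain NumPy and a hypothetical stable-softmax helper (the feature values are made up for illustration, not taken from the environment):

import numpy as np

def stable_softmax(x):
    # Shift by the max before exponentiating for numerical stability.
    z = x - np.max(x)
    e = np.exp(z)
    return e / e.sum()

# Hypothetical user and slate-doc feature (category) vectors.
user = np.array([0.8, 0.1, 0.1])
slate_docs = np.array([[0.9, 0.0, 0.1],
                       [0.1, 0.8, 0.1]])
scores = slate_docs @ user                       # dot products = per-doc scores
probs = stable_softmax(np.append(scores, 0.0))   # extra slot = "no click"
which_clicked = np.random.choice(len(probs), p=probs)
# which_clicked == len(probs) - 1 means the user clicked nothing.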
Example #2
    def test_gumbel_softmax(self):
        """Tests the GumbelSoftmax ActionDistribution (tf + eager only)."""
        for fw, sess in framework_iterator(frameworks=["tf", "tfe"],
                                           session=True):
            batch_size = 1000
            num_categories = 5
            input_space = Box(-1.0, 1.0, shape=(batch_size, num_categories))

            # Batch of size=n and deterministic.
            inputs = input_space.sample()
            gumbel_softmax = GumbelSoftmax(inputs, {}, temperature=1.0)

            expected = softmax(inputs)
            # Deterministic sample -> expect the distribution's mean
            # (softmax over the inputs).
            out = gumbel_softmax.deterministic_sample()
            check(out, expected)

            # Batch of size=n and non-deterministic -> expect roughly that
            # the max-likelihood (argmax) ints are output (most of the time).
            inputs = input_space.sample()
            gumbel_softmax = GumbelSoftmax(inputs, {}, temperature=1.0)
            expected_mean = np.mean(np.argmax(inputs, -1)).astype(np.float32)
            outs = gumbel_softmax.sample()
            if sess:
                outs = sess.run(outs)
            check(np.mean(np.argmax(outs, -1)), expected_mean, rtol=0.08)
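
The GumbelSoftmax distribution tested above relies on the reparameterization y = softmax((logits + g) / temperature), with g drawn from a standard Gumbel distribution. A minimal NumPy sketch of that sampling step (an illustration, not RLlib's implementation):

import numpy as np
from scipy.special import softmax

def gumbel_softmax_sample(logits, temperature=1.0):
    # g ~ Gumbel(0, 1) via inverse-transform sampling of uniform noise.
    u = np.random.uniform(1e-10, 1.0, size=logits.shape)
    g = -np.log(-np.log(u))
    # Relaxed (differentiable) one-hot sample; argmax gives the hard category.
    return softmax((logits + g) / temperature, axis=-1)

soft_sample = gumbel_softmax_sample(np.random.randn(4, 5))
hard_sample = np.argmax(soft_sample, axis=-1)

Example #3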
    def test_log_probs_from_logits_and_actions(self, batch_size):
        """Tests log_probs_from_logits_and_actions."""
        seq_len = 7
        num_actions = 3

        policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10
        actions = np.random.randint(0,
                                    num_actions - 1,
                                    size=(seq_len, batch_size),
                                    dtype=np.int32)

        action_log_probs_tensor = vtrace.log_probs_from_logits_and_actions(
            policy_logits, actions)

        # Ground Truth
        # Using broadcasting to create a mask that indexes action logits
        action_index_mask = actions[..., None] == np.arange(num_actions)

        def index_with_mask(array, mask):
            return array[mask].reshape(*array.shape[:-1])

        # Note: Normally log(softmax) is not a good idea because it's not
        # numerically stable. However, in this test we have well-behaved
        # values.
        ground_truth_v = index_with_mask(np.log(softmax(policy_logits)),
                                         action_index_mask)

        with self.test_session() as session:
            self.assertAllClose(ground_truth_v,
                                session.run(action_log_probs_tensor))
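
As the comment above notes, np.log(softmax(x)) can underflow for extreme logits. A numerically stable alternative is the standard log-sum-exp formulation (a sketch, not the library code):

import numpy as np

def log_softmax(logits, axis=-1):
    # log softmax(x) = (x - max(x)) - log(sum(exp(x - max(x))))
    z = logits - np.max(logits, axis=axis, keepdims=True)
    return z - np.log(np.sum(np.exp(z), axis=axis, keepdims=True))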
Example #4
    def test_multi_categorical(self):
        batch_size = 100
        num_categories = 3
        num_sub_distributions = 5
        # Create 5 categorical distributions of 3 categories each.
        inputs_space = Box(-1.0,
                           2.0,
                           shape=(batch_size,
                                  num_sub_distributions * num_categories))
        values_space = Box(0,
                           num_categories - 1,
                           shape=(num_sub_distributions, batch_size),
                           dtype=np.int32)

        inputs = inputs_space.sample()
        input_lengths = [num_categories] * num_sub_distributions
        inputs_split = np.split(inputs, num_sub_distributions, axis=1)

        for fw in framework_iterator():
            # Create the correct distribution object.
            cls = MultiCategorical if fw != "torch" else TorchMultiCategorical
            multi_categorical = cls(inputs, None, input_lengths)

            # Batch of size=100 and deterministic.
            expected = np.transpose(np.argmax(inputs_split, axis=-1))
            # Sample, expect always max value
            # (max likelihood for deterministic draw).
            out = multi_categorical.deterministic_sample()
            check(out, expected)

            # Batch of size=100 and non-deterministic -> expect roughly the mean.
            out = multi_categorical.sample()
            check(tf.reduce_mean(out)
                  if fw != "torch" else torch.mean(out.float()),
                  1.0,
                  decimals=0)

            # Test log-likelihood outputs.
            probs = softmax(inputs_split)
            values = values_space.sample()

            out = multi_categorical.logp(values if fw != "torch" else [
                torch.Tensor(values[i]) for i in range(num_sub_distributions)
            ])  # v in np.stack(values, 1)])
            expected = []
            for i in range(batch_size):
                expected.append(
                    np.sum(
                        np.log(
                            np.array([
                                probs[j][i][values[j][i]]
                                for j in range(num_sub_distributions)
                            ]))))
            check(out, expected, decimals=4)

            # Test entropy outputs.
            out = multi_categorical.entropy()
            expected_entropy = -np.sum(np.sum(probs * np.log(probs), 0), -1)
            check(out, expected_entropy)
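
The entropy check above follows the usual categorical entropy H(p) = -sum_k p_k * log(p_k); for a MultiCategorical, the entropies of the independent sub-distributions add up. A tiny NumPy illustration with arbitrary logits:

import numpy as np
from scipy.special import softmax

logits = np.array([0.2, 1.5, -0.3])
p = softmax(logits)
entropy = -np.sum(p * np.log(p))   # H(p) = -sum_k p_k * log(p_k)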
Example #5
    def test_categorical(self):
        batch_size = 10000
        num_categories = 4
        # Create categorical distribution with n categories.
        inputs_space = Box(-1.0,
                           2.0,
                           shape=(batch_size, num_categories),
                           dtype=np.float32)
        values_space = Box(0,
                           num_categories - 1,
                           shape=(batch_size, ),
                           dtype=np.int32)

        inputs = inputs_space.sample()

        for fw, sess in framework_iterator(session=True,
                                           frameworks=("tf", "tf2", "torch")):
            # Create the correct distribution object.
            cls = JAXCategorical if fw == "jax" else Categorical if \
                fw != "torch" else TorchCategorical
            categorical = cls(inputs, {})

            # Do a stability test using extreme NN outputs to see whether
            # sampling and logp'ing result in NaN or +/-inf values.
            self._stability_test(cls,
                                 inputs_space.shape,
                                 fw=fw,
                                 sess=sess,
                                 bounds=(0, num_categories - 1))

            # Batch of size=10000 and deterministic.
            expected = np.transpose(np.argmax(inputs, axis=-1))
            # Sample, expect always max value
            # (max likelihood for deterministic draw).
            out = categorical.deterministic_sample()
            check(out, expected)

            # Batch of size=10000 and non-deterministic -> expect roughly the mean.
            out = categorical.sample()
            check(np.mean(out) if fw == "jax" else tf.reduce_mean(out)
                  if fw != "torch" else torch.mean(out.float()),
                  1.0,
                  decimals=0)

            # Test log-likelihood outputs.
            probs = softmax(inputs)
            values = values_space.sample()

            out = categorical.logp(
                values if fw != "torch" else torch.Tensor(values))
            expected = []
            for i in range(batch_size):
                expected.append(np.sum(np.log(np.array(probs[i][values[i]]))))
            check(out, expected, decimals=4)

            # Test entropy outputs.
            out = categorical.entropy()
            expected_entropy = -np.sum(probs * np.log(probs), -1)
            check(out, expected_entropy)
Example #6
    def test_log_probs_from_logits_and_actions(self):
        """Tests log_probs_from_logits_and_actions."""
        seq_len = 7
        num_actions = 3
        batch_size = 4

        for fw, sess in framework_iterator(frameworks=("torch", "tf"),
                                           session=True):
            vtrace = vtrace_tf if fw == "tf" else vtrace_torch
            policy_logits = Box(-1.0, 1.0, (seq_len, batch_size, num_actions),
                                np.float32).sample()
            actions = np.random.randint(0,
                                        num_actions - 1,
                                        size=(seq_len, batch_size),
                                        dtype=np.int32)

            if fw == "torch":
                action_log_probs_tensor = \
                    vtrace.log_probs_from_logits_and_actions(
                        torch.from_numpy(policy_logits),
                        torch.from_numpy(actions))
            else:
                action_log_probs_tensor = \
                    vtrace.log_probs_from_logits_and_actions(
                        policy_logits, actions)

            # Ground Truth
            # Using broadcasting to create a mask that indexes action logits
            action_index_mask = actions[..., None] == np.arange(num_actions)

            def index_with_mask(array, mask):
                return array[mask].reshape(*array.shape[:-1])

            # Note: Normally log(softmax) is not a good idea because it's not
            # numerically stable. However, in this test we have well-behaved
            # values.
            ground_truth_v = index_with_mask(np.log(softmax(policy_logits)),
                                             action_index_mask)

            if sess:
                action_log_probs_tensor = sess.run(action_log_probs_tensor)
            check(action_log_probs_tensor, ground_truth_v)
Example #7
    def postprocess_trajectory(self,
                               policy: "Policy",
                               sample_batch: SampleBatch,
                               tf_sess: Optional["tf.Session"] = None):
        noisy_action_dist = noise_free_action_dist = None
        # Adjust the stddev depending on the action (pi)-distance.
        # Also see [1] for details.
        # TODO(sven): Find out whether this can be scrapped by simply using
        #  the `sample_batch` to get the noisy/noise-free action dist.
        _, _, fetches = policy.compute_actions(
            obs_batch=sample_batch[SampleBatch.CUR_OBS],
            # TODO(sven): What about state-ins and seq-lens?
            prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS),
            explore=self.weights_are_currently_noisy)

        # Categorical case (e.g. DQN).
        if policy.dist_class in (Categorical, TorchCategorical):
            action_dist = softmax(fetches[SampleBatch.ACTION_DIST_INPUTS])
        # Deterministic (Gaussian actions, e.g. DDPG).
        elif policy.dist_class in [Deterministic, TorchDeterministic]:
            action_dist = fetches[SampleBatch.ACTION_DIST_INPUTS]
        else:
            raise NotImplementedError  # TODO(sven): Other action-dist cases.

        if self.weights_are_currently_noisy:
            noisy_action_dist = action_dist
        else:
            noise_free_action_dist = action_dist

        _, _, fetches = policy.compute_actions(
            obs_batch=sample_batch[SampleBatch.CUR_OBS],
            prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS),
            explore=not self.weights_are_currently_noisy)

        # Categorical case (e.g. DQN).
        if policy.dist_class in (Categorical, TorchCategorical):
            action_dist = softmax(fetches[SampleBatch.ACTION_DIST_INPUTS])
        # Deterministic (Gaussian actions, e.g. DDPG).
        elif policy.dist_class in [Deterministic, TorchDeterministic]:
            action_dist = fetches[SampleBatch.ACTION_DIST_INPUTS]

        if noisy_action_dist is None:
            noisy_action_dist = action_dist
        else:
            noise_free_action_dist = action_dist

        delta = distance = None
        # Categorical case (e.g. DQN).
        if policy.dist_class in (Categorical, TorchCategorical):
            # Calculate KL-divergence (DKL(clean||noisy)) according to [2].
            # TODO(sven): Allow KL-divergence to be calculated by our
            #  Distribution classes (don't support off-graph/numpy yet).
            distance = np.nanmean(
                np.sum(
                    noise_free_action_dist *
                    np.log(noise_free_action_dist /
                           (noisy_action_dist + SMALL_NUMBER)), 1))
            current_epsilon = self.sub_exploration.get_info(
                sess=tf_sess)["cur_epsilon"]
            delta = -np.log(1 - current_epsilon +
                            current_epsilon / self.action_space.n)
        elif policy.dist_class in [Deterministic, TorchDeterministic]:
            # Calculate MSE between noisy and non-noisy output (see [2]).
            distance = np.sqrt(
                np.mean(np.square(noise_free_action_dist - noisy_action_dist)))
            current_scale = self.sub_exploration.get_info(
                sess=tf_sess)["cur_scale"]
            delta = getattr(self.sub_exploration, "ou_sigma", 0.2) * \
                current_scale

        # Adjust stddev according to the calculated action-distance.
        if distance <= delta:
            self.stddev_val *= 1.01
        else:
            self.stddev_val /= 1.01

        # Set self.stddev to calculated value.
        if self.framework == "tf":
            self.stddev.load(self.stddev_val, session=tf_sess)
        else:
            self.stddev = self.stddev_val

        return sample_batch
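
The adaptation rule at the end of this method compares the measured action-distribution distance against a target threshold delta (for the categorical case, delta = -log(1 - eps + eps / |A|) as in [2]) and scales the parameter-noise stddev up or down by 1%. A self-contained sketch of that rule with made-up numbers:

import numpy as np

current_epsilon = 0.1        # hypothetical epsilon of the sub-exploration
num_actions = 4
delta = -np.log(1 - current_epsilon + current_epsilon / num_actions)

distance = 0.02              # e.g. measured DKL(clean || noisy) on the batch
stddev_val = 0.5
if distance <= delta:
    stddev_val *= 1.01       # noise effect too small -> increase stddev
else:
    stddev_val /= 1.01       # noise effect too large -> decrease stddev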
Example #8
    def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
        noisy_action_dist = noise_free_action_dist = None
        # Adjust the stddev depending on the action (pi)-distance.
        # Also see [1] for details.
        distribution = policy.compute_action_distribution(
            obs_batch=sample_batch[SampleBatch.CUR_OBS],
            # TODO(sven): What about state-ins and seq-lens?
            prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS),
            explore=self.weights_are_currently_noisy)

        # Categorical case (e.g. DQN).
        if isinstance(distribution, Categorical):
            action_dist = softmax(distribution.inputs)
        else:  # TODO(sven): Other action-dist cases.
            raise NotImplementedError

        if self.weights_are_currently_noisy:
            noisy_action_dist = action_dist
        else:
            noise_free_action_dist = action_dist

        distribution = policy.compute_action_distribution(
            obs_batch=sample_batch[SampleBatch.CUR_OBS],
            # TODO(sven): What about state-ins and seq-lens?
            prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS),
            explore=not self.weights_are_currently_noisy)

        # Categorical case (e.g. DQN).
        if isinstance(distribution, Categorical):
            action_dist = softmax(distribution.inputs)

        if not self.weights_are_currently_noisy:
            noisy_action_dist = action_dist
        else:
            noise_free_action_dist = action_dist

        # Categorical case (e.g. DQN).
        if isinstance(distribution, Categorical):
            # Calculate KL-divergence (DKL(clean||noisy)) according to [2].
            # TODO(sven): Allow KL-divergence to be calculated by our
            #  Distribution classes (don't support off-graph/numpy yet).
            kl_divergence = np.nanmean(
                np.sum(
                    noise_free_action_dist *
                    np.log(noise_free_action_dist /
                           (noisy_action_dist + SMALL_NUMBER)), 1))
            current_epsilon = self.sub_exploration.get_info()["cur_epsilon"]
            if tf_sess is not None:
                current_epsilon = tf_sess.run(current_epsilon)
            delta = -np.log(1 - current_epsilon +
                            current_epsilon / self.action_space.n)
            if kl_divergence <= delta:
                self.stddev_val *= 1.01
            else:
                self.stddev_val /= 1.01

        # Set self.stddev to calculated value.
        if self.framework == "tf":
            self.stddev.load(self.stddev_val, session=tf_sess)
        else:
            self.stddev = self.stddev_val

        return sample_batch
Example #9
    def test_multi_action_distribution(self):
        """Tests the MultiActionDistribution (across all frameworks)."""
        batch_size = 1000
        input_space = Tuple([
            Box(-10.0, 10.0, shape=(batch_size, 4)),
            Box(-2.0, 2.0, shape=(
                batch_size,
                6,
            )),
            Dict({"a": Box(-1.0, 1.0, shape=(batch_size, 4))}),
        ])
        std_space = Box(-0.05, 0.05, shape=(
            batch_size,
            3,
        ))

        low, high = -1.0, 1.0
        value_space = Tuple([
            Box(0, 3, shape=(batch_size, ), dtype=np.int32),
            Box(-2.0, 2.0, shape=(batch_size, 3), dtype=np.float32),
            Dict({"a": Box(0.0, 1.0, shape=(batch_size, 2), dtype=np.float32)})
        ])

        for fw, sess in framework_iterator(session=True):
            if fw == "torch":
                cls = TorchMultiActionDistribution
                child_distr_cls = [
                    TorchCategorical, TorchDiagGaussian,
                    partial(TorchBeta, low=low, high=high)
                ]
            else:
                cls = MultiActionDistribution
                child_distr_cls = [
                    Categorical,
                    DiagGaussian,
                    partial(Beta, low=low, high=high),
                ]

            inputs = list(input_space.sample())
            distr = cls(np.concatenate([inputs[0], inputs[1], inputs[2]["a"]],
                                       axis=1),
                        model={},
                        action_space=value_space,
                        child_distributions=child_distr_cls,
                        input_lens=[4, 6, 4])

            # Adjust inputs for the Beta distr just as Beta itself does.
            inputs[2]["a"] = np.clip(inputs[2]["a"], np.log(SMALL_NUMBER),
                                     -np.log(SMALL_NUMBER))
            inputs[2]["a"] = np.log(np.exp(inputs[2]["a"]) + 1.0) + 1.0
            # Sample deterministically.
            expected_det = [
                np.argmax(inputs[0], axis=-1),
                inputs[1][:, :3],  # [:3]=Mean values.
                # Mean for a Beta distribution:
                # 1 / [1 + (beta/alpha)] * range + low
                (1.0 /
                 (1.0 + inputs[2]["a"][:, 2:] / inputs[2]["a"][:, 0:2])) *
                (high - low) + low,
            ]
            out = distr.deterministic_sample()
            if sess:
                out = sess.run(out)
            check(out[0], expected_det[0])
            check(out[1], expected_det[1])
            check(out[2]["a"], expected_det[2])

            # Stochastic sampling -> expect roughly the mean.
            inputs = list(input_space.sample())
            # Fix categorical inputs (not needed for distribution itself, but
            # for our expectation calculations).
            inputs[0] = softmax(inputs[0], -1)
            # Fix std inputs (shouldn't be too large for this test).
            inputs[1][:, 3:] = std_space.sample()
            # Adjust inputs for the Beta distr just as Beta itself does.
            inputs[2]["a"] = np.clip(inputs[2]["a"], np.log(SMALL_NUMBER),
                                     -np.log(SMALL_NUMBER))
            inputs[2]["a"] = np.log(np.exp(inputs[2]["a"]) + 1.0) + 1.0
            distr = cls(np.concatenate([inputs[0], inputs[1], inputs[2]["a"]],
                                       axis=1),
                        model={},
                        action_space=value_space,
                        child_distributions=child_distr_cls,
                        input_lens=[4, 6, 4])
            expected_mean = [
                np.mean(np.sum(inputs[0] * np.array([0, 1, 2, 3]), -1)),
                inputs[1][:, :3],  # [:3]=Mean values.
                # Mean for a Beta distribution:
                # 1 / [1 + (beta/alpha)] * range + low
                (1.0 / (1.0 + inputs[2]["a"][:, 2:] / inputs[2]["a"][:, :2])) *
                (high - low) + low,
            ]
            out = distr.sample()
            if sess:
                out = sess.run(out)
            out = list(out)
            if fw == "torch":
                out[0] = out[0].numpy()
                out[1] = out[1].numpy()
                out[2]["a"] = out[2]["a"].numpy()
            check(np.mean(out[0]), expected_mean[0], decimals=1)
            check(np.mean(out[1], 0), np.mean(expected_mean[1], 0), decimals=1)
            check(np.mean(out[2]["a"], 0),
                  np.mean(expected_mean[2], 0),
                  decimals=1)

            # Test log-likelihood outputs.
            # Make sure beta-values are within 0.0 and 1.0 for the numpy
            # calculation (which doesn't have scaling).
            inputs = list(input_space.sample())
            # Adjust inputs for the Beta distr just as Beta itself does.
            inputs[2]["a"] = np.clip(inputs[2]["a"], np.log(SMALL_NUMBER),
                                     -np.log(SMALL_NUMBER))
            inputs[2]["a"] = np.log(np.exp(inputs[2]["a"]) + 1.0) + 1.0
            distr = cls(np.concatenate([inputs[0], inputs[1], inputs[2]["a"]],
                                       axis=1),
                        model={},
                        action_space=value_space,
                        child_distributions=child_distr_cls,
                        input_lens=[4, 6, 4])
            inputs[0] = softmax(inputs[0], -1)
            values = list(value_space.sample())
            log_prob_beta = np.log(
                beta.pdf(values[2]["a"], inputs[2]["a"][:, :2],
                         inputs[2]["a"][:, 2:]))
            # Now do the up-scaling for [2] (beta values) to be between
            # low/high.
            values[2]["a"] = values[2]["a"] * (high - low) + low
            inputs[1][:, 3:] = np.exp(inputs[1][:, 3:])
            expected_log_llh = np.sum(
                np.concatenate([
                    np.expand_dims(
                        np.log(
                            [i[values[0][j]]
                             for j, i in enumerate(inputs[0])]), -1),
                    np.log(
                        norm.pdf(values[1], inputs[1][:, :3],
                                 inputs[1][:, 3:])), log_prob_beta
                ], -1), -1)

            values[0] = np.expand_dims(values[0], -1)
            if fw == "torch":
                values = tree.map_structure(lambda s: torch.Tensor(s), values)
            # Test all flattened input.
            concat = np.concatenate(tree.flatten(values),
                                    -1).astype(np.float32)
            out = distr.logp(concat)
            if sess:
                out = sess.run(out)
            check(out, expected_log_llh, atol=15)
            # Test structured input.
            out = distr.logp(values)
            if sess:
                out = sess.run(out)
            check(out, expected_log_llh, atol=15)
            # Test flattened input.
            out = distr.logp(tree.flatten(values))
            if sess:
                out = sess.run(out)
            check(out, expected_log_llh, atol=15)
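
The deterministic expectations above use the mean of a Beta(alpha, beta) distribution, alpha / (alpha + beta) = 1 / (1 + beta / alpha), rescaled from [0, 1] to [low, high]. A short worked check with arbitrary parameters:

import numpy as np

alpha, beta_, low, high = 2.0, 3.0, -1.0, 1.0
mean_unit = 1.0 / (1.0 + beta_ / alpha)        # = alpha / (alpha + beta) = 0.4
mean_scaled = mean_unit * (high - low) + low   # = -0.2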
Example #10
                def policy_mapping_fn(agent_id, episode, worker, **kwargs):

                    # Pick whether this is:
                    # LE: league-exploiter vs snapshot.
                    # ME: main-exploiter vs (any) main.
                    # M: Learning main vs itself.
                    type_ = np.random.choice(["LE", "ME", "M"],
                                             p=probs_match_types)

                    # Learning league exploiter vs a snapshot.
                    # Opponent snapshots should be selected based on a win-rate-
                    # derived probability.
                    if type_ == "LE":
                        if episode.episode_id % 2 == agent_id:
                            league_exploiter = np.random.choice([
                                p for p in trainable_policies
                                if p.startswith("league_ex")
                            ])
                            logger.debug(
                                f"Episode {episode.episode_id}: AgentID "
                                f"{agent_id} played by {league_exploiter} (training)"
                            )
                            return league_exploiter
                        # Play against any non-trainable policy (excluding itself).
                        else:
                            all_opponents = list(non_trainable_policies)
                            probs = softmax([
                                worker.global_vars["win_rates"][pid]
                                for pid in all_opponents
                            ])
                            opponent = np.random.choice(all_opponents, p=probs)
                            logger.debug(
                                f"Episode {episode.episode_id}: AgentID "
                                f"{agent_id} played by {opponent} (frozen)")
                            return opponent

                    # Learning main exploiter vs (learning main OR snapshot main).
                    elif type_ == "ME":
                        if episode.episode_id % 2 == agent_id:
                            main_exploiter = np.random.choice([
                                p for p in trainable_policies
                                if p.startswith("main_ex")
                            ])
                            logger.debug(
                                f"Episode {episode.episode_id}: AgentID "
                                f"{agent_id} played by {main_exploiter} (training)"
                            )
                            return main_exploiter
                        else:
                            # n% of the time, play against the learning main.
                            # Also always play against the learning main if no
                            # non-learning mains have been created yet.
                            if num_main_policies == 1 or (
                                    np.random.random() <
                                    prob_playing_learning_main):
                                main = "main_0"
                                training = "training"
                            # 100-n% of the time, play against a non-learning
                            # main. Opponent main snapshots should be selected
                            # based on a win-rate-derived probability.
                            else:
                                all_opponents = [
                                    f"main_{p}"
                                    for p in list(range(1, num_main_policies))
                                ]
                                probs = softmax([
                                    worker.global_vars["win_rates"][pid]
                                    for pid in all_opponents
                                ])
                                main = np.random.choice(all_opponents, p=probs)
                                training = "frozen"
                            logger.debug(
                                f"Episode {episode.episode_id}: AgentID "
                                f"{agent_id} played by {main} ({training})")
                            return main

                    # Main policy: Self-play.
                    else:
                        logger.debug(
                            f"Episode {episode.episode_id}: main_0 vs main_0")
                        return "main_0"
Example #11
    def test_multi_action_distribution(self):
        """Tests the MultiActionDistribution (only torch so far)."""
        batch_size = 1000
        input_space = Tuple([
            Box(-10.0, 10.0, shape=(batch_size, 4)),
            Box(-2.0, 2.0, shape=(
                batch_size,
                6,
            ))
        ])
        std_space = Box(-0.05, 0.05, shape=(
            batch_size,
            3,
        ))

        value_space = Tuple([
            Box(0, 3, shape=(batch_size, ), dtype=np.int32),
            Box(-2.0, 2.0, shape=(batch_size, 3), dtype=np.float32)
        ])

        for fw, sess in framework_iterator(frameworks="torch", session=True):
            if fw == "torch":
                cls = TorchMultiActionDistribution
                child_distr_cls = [TorchCategorical, TorchDiagGaussian]
            else:
                cls = MultiActionDistribution
                child_distr_cls = [Categorical, DiagGaussian]

            inputs = list(input_space.sample())
            distr = cls(np.concatenate([inputs[0], inputs[1]], axis=1),
                        model={},
                        action_space=value_space,
                        child_distributions=child_distr_cls,
                        input_lens=[4, 6])

            # Sample deterministically.
            expected_det = [
                np.argmax(inputs[0], axis=-1),
                inputs[1][:, :3],  # [:3]=Mean values.
            ]
            out = distr.deterministic_sample()
            if sess:
                out = sess.run(out)
            check(out[0], expected_det[0])
            check(out[1], expected_det[1])

            # Stochastic sampling -> expect roughly the mean.
            inputs = list(input_space.sample())
            # Fix categorical inputs (not needed for distribution itself, but
            # for our expectation calculations).
            inputs[0] = softmax(inputs[0], -1)
            # Fix std inputs (shouldn't be too large for this test).
            inputs[1][:, 3:] = std_space.sample()
            distr = cls(np.concatenate([inputs[0], inputs[1]], axis=1),
                        model={},
                        action_space=value_space,
                        child_distributions=child_distr_cls,
                        input_lens=[4, 6])
            expected_mean = [
                np.mean(np.sum(inputs[0] * np.array([0, 1, 2, 3]), -1)),
                inputs[1][:, :3],  # [:3]=Mean values.
            ]
            out = distr.sample()
            if sess:
                out = sess.run(out)
            out = list(out)
            if fw == "torch":
                out[0] = out[0].numpy()
                out[1] = out[1].numpy()
            check(np.mean(out[0]), expected_mean[0], decimals=1)
            check(np.mean(out[1], 0), np.mean(expected_mean[1], 0), decimals=1)

            # Test log-likelihood outputs.
            inputs = list(input_space.sample())
            distr = cls(np.concatenate([inputs[0], inputs[1]], axis=1),
                        model={},
                        action_space=value_space,
                        child_distributions=child_distr_cls,
                        input_lens=[4, 6])
            inputs[0] = softmax(inputs[0], -1)
            values = list(value_space.sample())
            inputs[1][:, 3:] = np.exp(inputs[1][:, 3:])
            expected_log_llh = np.sum(
                np.concatenate([
                    np.expand_dims(
                        np.log(
                            [i[values[0][j]]
                             for j, i in enumerate(inputs[0])]), -1),
                    np.log(
                        norm.pdf(values[1], inputs[1][:, :3], inputs[1][:,
                                                                        3:]))
                ], -1), -1)

            values[0] = np.expand_dims(values[0], -1)
            if fw == "torch":
                values = tree.map_structure(lambda s: torch.Tensor(s), values)
            # Test all flattened input.
            concat = np.concatenate(tree.flatten(values),
                                    -1).astype(np.float32)
            out = distr.logp(concat)
            if sess:
                out = sess.run(out)
            check(out, expected_log_llh, atol=15)
            # Test structured input.
            out = distr.logp(values)
            if sess:
                out = sess.run(out)
            check(out, expected_log_llh, atol=15)
            # Test flattened input.
            out = distr.logp(tree.flatten(values))
            if sess:
                out = sess.run(out)
            check(out, expected_log_llh, atol=15)
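
The expected log-likelihood above is simply the sum of the components' log-probabilities, since the (Categorical, DiagGaussian) parts are treated as independent. A minimal NumPy/SciPy sketch for a single tuple action with arbitrary values:

import numpy as np
from scipy.special import softmax
from scipy.stats import norm

cat_logits = np.array([0.2, 1.5, -0.3, 0.0])
cat_action = 1
gauss_mean = np.array([0.0, 0.5, -1.0])
gauss_std = np.array([0.3, 0.2, 0.5])
gauss_action = np.array([0.1, 0.4, -0.8])

logp = (np.log(softmax(cat_logits))[cat_action] +
        np.sum(norm.logpdf(gauss_action, gauss_mean, gauss_std)))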