Example #1
    def test_flatten(self):
        # We flatten Discrete to 1 value
        assert su.flatdim(self.space) == 25
        # gym flattens Discrete to one-hot
        assert gyms.flatdim(self.space) == 35

        asample = su.torch_point(self.space, self.space.sample())
        flattened = su.flatten(self.space, asample)
        unflattened = su.unflatten(self.space, flattened)
        assert self.same(asample, unflattened)

        # suppress `UserWarning: WARN: Box bound precision lowered by casting to float32`
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            flattened_space = su.flatten_space(self.space)
            assert flattened_space.shape == (25, )
            # The maximum comes from Discrete(11)
            assert flattened_space.high.max() == 11.0
            assert flattened_space.low.min() == -10.0

            gym_flattened_space = gyms.flatten_space(self.space)
            assert gym_flattened_space.shape == (35, )
            # The maximum comes from Box(-10, 10, (3, 4))
            assert gym_flattened_space.high.max() == 10.0
            assert gym_flattened_space.low.min() == -10.0
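
For reference, `self.space` is defined elsewhere in the test class. A composite space consistent with the asserted dimensions (a sketch for illustration only, assuming `gyms` is `gym.spaces` and an older gym that flattens `MultiDiscrete` to its raw values rather than one-hot; not necessarily the suite's actual fixture):

    space = gyms.Dict(
        {
            "first": gyms.Tuple(
                [
                    gyms.Box(-10, 10, (3, 4)),      # 12 dims under either flattener
                    gyms.MultiDiscrete([2, 3, 4]),  # 3 dims under either flattener
                    gyms.Box(-1, 1, ()),            # 1 dim
                ]
            ),
            "second": gyms.Tuple(
                [
                    gyms.Dict({"third": gyms.Discrete(11)}),  # 1 dim for su, 11 (one-hot) for gym
                    gyms.MultiBinary(8),                      # 8 dims
                ]
            ),
        }
    )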
Example #2
    def __init__(
        self,
        distr: Distr,
        obs: Dict[str, Any],
        action_space: gym.spaces.Space,
        num_active_samplers: Optional[int],
        approx_steps: Optional[int],
        teacher_forcing: Optional[TeacherForcingAnnealingType],
        tracking_info: Optional[Dict[str, Any]],
        always_enforce: bool = False,
    ):
        self.distr = distr
        self.is_sequential = isinstance(self.distr, SequentialDistr)

        # action_space is a gym.spaces.Dict for SequentialDistr, or any gym.Space for other Distr
        self.action_space = action_space
        self.num_active_samplers = num_active_samplers
        self.approx_steps = approx_steps
        self.teacher_forcing = teacher_forcing
        self.tracking_info = tracking_info
        self.always_enforce = always_enforce

        assert (
            "expert_action" in obs
        ), "When using teacher forcing, obs must contain an `expert_action` uuid"

        obs_space = Expert.flagged_space(
            self.action_space, use_dict_as_groups=self.is_sequential
        )
        # Recover the structured expert action (and its success flag) from the flattened
        # `expert_action` observation.
        self.expert = su.unflatten(obs_space, obs["expert_action"])
Example #3
    def _zeroed_observation(self) -> Union[OrderedDict, Tuple]:
        # AllenAct-style flattened space (to easily generate an all-zeroes observation as an array)
        flat_space = su.flatten_space(self.observation_space)
        # torch point to correctly unflatten `Discrete` for zeroed output
        flat_zeroed = su.torch_point(flat_space, np.zeros_like(flat_space.sample()))
        # unflatten the zeroed output and convert it to numpy
        return su.numpy_point(
            self.observation_space, su.unflatten(self.observation_space, flat_zeroed)
        )
Example #4
    def test_batched(self):
        samples = [self.space.sample() for _ in range(10)]
        flattened = [
            su.flatten(self.space, su.torch_point(self.space, sample))
            for sample in samples
        ]
        stacked = torch.stack(flattened, dim=0)
        unflattened = su.unflatten(self.space, stacked)
        for bidx, refsample in enumerate(samples):
            # Compare each torch-ified sample to the corresponding unflattened from the stack
            assert self.same(su.torch_point(self.space, refsample), unflattened, bidx)

        assert self.same(su.flatten(self.space, unflattened), stacked)
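
The same round trip works on a standalone space; a minimal sketch (assuming `su` is AllenAct's `allenact.utils.spaces_utils`, `gyms` is `gym.spaces`, and a hypothetical Dict space):

    import torch
    import gym.spaces as gyms
    import allenact.utils.spaces_utils as su

    space = gyms.Dict({"move": gyms.Discrete(4), "gps": gyms.Box(-1, 1, (2,))})
    samples = [su.torch_point(space, space.sample()) for _ in range(10)]
    # Each flattened sample is 1-D with length su.flatdim(space); stacking adds a batch dim.
    stacked = torch.stack([su.flatten(space, s) for s in samples], dim=0)
    # unflatten accepts the batched tensor and returns a point with the space's structure
    # (here an OrderedDict with batched "gps" and "move" entries).
    batched = su.unflatten(space, stacked)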
Example #5
    def enforce(
        self,
        sample: Any,
        action_space: gym.spaces.Space,
        teacher: OrderedDict,
        teacher_force_info: Optional[Dict[str, Any]],
        action_name: Optional[str] = None,
    ):
        actions = su.flatten(action_space, sample)

        assert (
            len(actions.shape) == 3
        ), f"Got flattened actions with shape {actions.shape} (it should be [1 x `samplers` x `flatdims`])"

        if self.num_active_samplers is not None:
            assert actions.shape[1] == self.num_active_samplers

        expert_actions = su.flatten(action_space,
                                    teacher[Expert.ACTION_POLICY_LABEL])
        assert (
            expert_actions.shape == actions.shape
        ), f"expert actions shape {expert_actions.shape} doesn't match the model's {actions.shape}"

        # expert_success is 0 if the expert action could not be computed and otherwise equals 1.
        expert_action_exists_mask = teacher[Expert.EXPERT_SUCCESS_LABEL]

        if not self.always_enforce:
            # Sample a Bernoulli mask at the current teacher-forcing probability, then zero it
            # wherever the expert failed to produce an action.
            teacher_forcing_mask = (
                torch.distributions.bernoulli.Bernoulli(
                    torch.tensor(self.teacher_forcing(self.approx_steps))
                )
                .sample(expert_action_exists_mask.shape)
                .long()
                .to(actions.device)
            ) * expert_action_exists_mask
        else:
            teacher_forcing_mask = expert_action_exists_mask

        if teacher_force_info is not None:
            teacher_force_info[
                "teacher_ratio/sampled{}".format(
                    f"_{action_name}" if action_name is not None else ""
                )
            ] = teacher_forcing_mask.float().mean().item()

        # Broadcast the mask over the trailing (flattened action) dimensions, then take the
        # expert action wherever teacher forcing applies and keep the sampled action elsewhere.
        extended_shape = teacher_forcing_mask.shape + (1,) * (
            len(actions.shape) - len(teacher_forcing_mask.shape)
        )

        # `.bool()` rather than `.byte()`: recent torch requires a boolean condition in `torch.where`.
        actions = torch.where(
            teacher_forcing_mask.bool().view(extended_shape), expert_actions, actions
        )

        return su.unflatten(action_space, actions)
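
At its core, `enforce` mixes expert and sampled actions with a Bernoulli teacher-forcing mask; a standalone sketch of just that mixing (hypothetical tensors, plain torch):

    import torch

    sampled = torch.tensor([[0.0, 1.0, 2.0]])  # stand-in for flattened sampled actions
    expert = torch.tensor([[9.0, 9.0, 9.0]])   # stand-in for flattened expert actions
    expert_exists = torch.tensor([[1, 0, 1]])  # the expert failed for the middle entry
    p = 0.5                                    # current teacher-forcing probability

    mask = (
        torch.distributions.Bernoulli(torch.tensor(p))
        .sample(expert_exists.shape)
        .long()
        * expert_exists
    )
    # Where mask == 1 the expert action is enforced; elsewhere the sampled action is kept.
    mixed = torch.where(mask.bool(), expert, sampled)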
Example #6
    def loss(  # type: ignore
        self,
        step_count: int,
        batch: ObservationType,
        actor_critic_output: ActorCriticOutput[Distr],
        *args,
        **kwargs,
    ):
        """Computes the imitation loss.

        # Parameters

        batch : A batch of data corresponding to the information collected when rolling out (possibly many) agents
            over a fixed number of steps. In particular this batch should have the same format as that returned by
            `RolloutStorage.recurrent_generator`.
            Here `batch["observations"]` must contain `"expert_action"` observations
            or `"expert_policy"` observations. See `ExpertActionSensor` (or `ExpertPolicySensor`) for an example of
            a sensor producing such observations.
        actor_critic_output : The output of calling an ActorCriticModel on the observations in `batch`.
        args : Extra args. Ignored.
        kwargs : Extra kwargs. Ignored.

        # Returns

        A (0-dimensional) torch.FloatTensor corresponding to the computed loss. `.backward()` will be called on this
        tensor in order to compute a gradient update to the ActorCriticModel's parameters.
        """
        observations = cast(Dict[str, torch.Tensor], batch["observations"])

        losses = OrderedDict()

        should_report_loss = False

        if "expert_action" in observations:
            if self.expert_sensor is None or not self.expert_sensor.use_groups:
                expert_actions_and_mask = observations["expert_action"]

                # Each `expert_action` observation packs (expert action, expert success flag)
                # in its last dimension.
                assert expert_actions_and_mask.shape[-1] == 2
                expert_actions_and_mask_reshaped = expert_actions_and_mask.view(-1, 2)

                expert_actions = expert_actions_and_mask_reshaped[:, 0].view(
                    *expert_actions_and_mask.shape[:-1], 1
                )
                expert_actions_masks = (
                    expert_actions_and_mask_reshaped[:, 1]
                    .float()
                    .view(*expert_actions_and_mask.shape[:-1], 1)
                )

                total_loss, expert_successes = self.group_loss(
                    cast(CategoricalDistr, actor_critic_output.distributions),
                    expert_actions,
                    expert_actions_masks,
                )

                should_report_loss = expert_successes.item() != 0
            else:
                expert_actions = su.unflatten(
                    self.expert_sensor.observation_space, observations["expert_action"]
                )

                total_loss = 0

                ready_actions = OrderedDict()

                # Walk the expert's action groups in lockstep with the sequential
                # distribution's conditional distributions.
                for group_name, cd in zip(
                    self.expert_sensor.group_spaces,
                    cast(
                        SequentialDistr, actor_critic_output.distributions
                    ).conditional_distrs,
                ):
                    assert group_name == cd.action_group_name

                    cd.reset()
                    cd.condition_on_input(**ready_actions)

                    expert_action = expert_actions[group_name][
                        AbstractExpertSensor.ACTION_POLICY_LABEL
                    ]
                    expert_action_masks = expert_actions[group_name][
                        AbstractExpertSensor.EXPERT_SUCCESS_LABEL
                    ]

                    ready_actions[group_name] = expert_action

                    current_loss, expert_successes = self.group_loss(
                        cd, expert_action, expert_action_masks,
                    )

                    should_report_loss = (
                        expert_successes.item() != 0 or should_report_loss
                    )

                    cd.reset()

                    if expert_successes.item() != 0:
                        losses[group_name + "_cross_entropy"] = current_loss.item()
                        total_loss = total_loss + current_loss
        elif "expert_policy" in observations:
            if self.expert_sensor is None or not self.expert_sensor.use_groups:
                assert isinstance(
                    actor_critic_output.distributions, CategoricalDistr
                ), "This implementation currently only supports `CategoricalDistr`"

                # The last channel of the `expert_policy` observation is the expert success
                # mask; the remaining channels are the expert's policy itself.
                expert_policies = observations["expert_policy"][..., :-1]
                expert_actions_masks = observations["expert_policy"][..., -1:]

                expert_successes = expert_actions_masks.sum()
                if expert_successes.item() > 0:
                    should_report_loss = True

                log_probs = cast(
                    CategoricalDistr, actor_critic_output.distributions
                ).log_probs_tensor

                # Add dimensions to `expert_actions_masks` on the right to allow for masking
                # if necessary.
                len_diff = len(log_probs.shape) - len(expert_actions_masks.shape)
                assert len_diff >= 0
                expert_actions_masks = expert_actions_masks.view(
                    *expert_actions_masks.shape, *((1,) * len_diff)
                )

                total_loss = (
                    -(log_probs * expert_policies) * expert_actions_masks
                ).sum() / torch.clamp(expert_successes, min=1)
            else:
                raise NotImplementedError(
                    "This implementation currently only supports `CategoricalDistr`"
                )
        else:
            raise NotImplementedError(
                "Imitation loss requires either `expert_action` or `expert_policy`"
                " sensor to be active."
            )
        return (
            total_loss,
            {"expert_cross_entropy": total_loss.item(), **losses}
            if should_report_loss
            else {},
        )
Example #7
    def pick_prev_actions_step(self, step: int) -> ActionType:
        # Unflatten the stored (flattened) previous actions for this single-step slice.
        return su.unflatten(self.action_space, self.prev_actions[step : step + 1])
Example #8
    def recurrent_generator(
        self,
        advantages: torch.Tensor,
        adv_mean: torch.Tensor,
        adv_std: torch.Tensor,
        num_mini_batch: int,
    ):
        normalized_advantages = (advantages - adv_mean) / (adv_std + 1e-5)

        num_samplers = self.rewards.shape[1]
        assert num_samplers >= num_mini_batch, (
            "The number of task samplers ({}) "
            "must be greater than or equal to the number of "
            "mini batches ({}).".format(num_samplers, num_mini_batch))

        # Partition the task samplers into `num_mini_batch` contiguous groups and visit the
        # groups in random order.
        inds = np.round(
            np.linspace(0, num_samplers, num_mini_batch + 1, endpoint=True)
        ).astype(np.int32)
        pairs = list(zip(inds[:-1], inds[1:]))
        random.shuffle(pairs)

        for start_ind, end_ind in pairs:
            cur_samplers = list(range(start_ind, end_ind))

            memory_batch = self.memory.step_squeeze(0).sampler_select(cur_samplers)
            observations_batch = self.unflatten_observations(
                self.observations.slice(dim=0, stop=-1).sampler_select(cur_samplers)
            )

            actions_batch = []
            prev_actions_batch = []
            value_preds_batch = []
            return_batch = []
            masks_batch = []
            old_action_log_probs_batch = []
            adv_targ = []
            norm_adv_targ = []

            for ind in cur_samplers:
                actions_batch.append(self.actions[:, ind])
                prev_actions_batch.append(self.prev_actions[:-1, ind])
                value_preds_batch.append(self.value_preds[:-1, ind])
                return_batch.append(self.returns[:-1, ind])
                masks_batch.append(self.masks[:-1, ind])
                old_action_log_probs_batch.append(self.action_log_probs[:, ind])

                adv_targ.append(advantages[:, ind])
                norm_adv_targ.append(normalized_advantages[:, ind])

            actions_batch = torch.stack(actions_batch, 1)  # type:ignore
            prev_actions_batch = torch.stack(prev_actions_batch, 1)  # type:ignore
            value_preds_batch = torch.stack(value_preds_batch, 1)  # type:ignore
            return_batch = torch.stack(return_batch, 1)  # type:ignore
            masks_batch = torch.stack(masks_batch, 1)  # type:ignore
            old_action_log_probs_batch = torch.stack(old_action_log_probs_batch, 1)  # type:ignore
            adv_targ = torch.stack(adv_targ, 1)  # type:ignore
            norm_adv_targ = torch.stack(norm_adv_targ, 1)  # type:ignore

            yield {
                "observations": observations_batch,
                "memory": memory_batch,
                "actions": su.unflatten(self.action_space, actions_batch),
                "prev_actions": su.unflatten(self.action_space,
                                             prev_actions_batch),
                "values": value_preds_batch,
                "returns": return_batch,
                "masks": masks_batch,
                "old_action_log_probs": old_action_log_probs_batch,
                "adv_targ": adv_targ,
                "norm_adv_targ": norm_adv_targ,
            }
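
A hypothetical consumer of this generator (the `storage` name and argument values are assumptions, not part of the snippet above):

    for minibatch in storage.recurrent_generator(
        advantages, adv_mean, adv_std, num_mini_batch=4
    ):
        # "actions" and "prev_actions" arrive already unflattened into the action space's
        # structure; the remaining entries are tensors stacked over the sampler group.
        actions = minibatch["actions"]
        norm_adv = minibatch["norm_adv_targ"]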