Code example #1
    def _get_tgt_encoding_hand_code(self, observations, gps_pred,
                                    compass_pred):
        initial_pg = observations[self.goal_sensor_uuid]
        pred_pg = self._update_pg(initial_pg, gps_pred, compass_pred)

        if "pointgoal_with_gps" in observations:
            true_pg = observations["pointgoal_with_gps"]

            error = (pred_pg - true_pg).norm(dim=-1, keepdim=True)

            if AuxLosses.is_active():
                AuxLosses.register_loss("egomotion_error", error.mean(), 0.1)

            valid_ego_preds = error.detach() < self.ego_error_threshold

            pg = torch.where(valid_ego_preds, pred_pg.detach(), true_pg)
        else:
            pg = pred_pg

        # Back-propping from the policy into the goal seems odd -- it doesn't make sense
        # to move the goal just to make the policy more/less likely to predict a given action.
        # We also have perfect supervision on what this should be!
        goal_observations = _to_mag_and_unit_vec(pg.detach())

        return self.tgt_embeding(goal_observations)
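
Several of these snippets call a helper named `_to_mag_and_unit_vec` before embedding the goal. The helper itself is not shown; below is a minimal sketch of what it is assumed to do (convert a 2-D Cartesian goal offset into a magnitude-plus-unit-vector representation). The actual implementation in the codebase may differ.

import torch


def _to_mag_and_unit_vec(pg: torch.Tensor) -> torch.Tensor:
    # pg: (..., 2) goal offset in the agent's frame (assumed Cartesian).
    # Returns (..., 3): [distance, x / distance, y / distance].
    mag = torch.norm(pg, dim=-1, keepdim=True)
    unit = pg / mag.clamp(min=1e-8)
    return torch.cat([mag, unit], dim=-1)
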
Code example #2
    def forward(self, s, g):
        if g is None:
            return prior(s).sample()

        priv_logits = self._priv(s, g)
        d_cap_logits = self._d_cap(s)

        d_cap = torch.sigmoid(d_cap_logits)
        priv = torch.sigmoid(priv_logits)
        x = d_cap * priv + (1 - d_cap) * prior(s).sample()

        if AuxLosses.is_active():
            priv_log_prob = F.logsigmoid(priv)
            log_d_cap = F.logsigmoid(d_cap_logits)
            AuxLosses.register_loss(
                "information",
                (
                    -d_cap * log_d_cap
                    + (1 - d_cap) * (1 + priv_log_prob)
                    - (log_d_cap + priv_log_prob)
                ).mean(),
                self.beta,
            )

        return x
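
Every example registers auxiliary losses through `AuxLosses`, whose definition is not included here. The sketch below captures only the interface these snippets rely on (`activate`/`deactivate`, `clear`, `register_loss(name, loss, alpha)`, `get_loss`, `_loss_alphas`, and a weighted-sum `reduce`); the real class may store and reduce losses differently.

class AuxLosses:
    # Sketch of the registry interface assumed by the snippets above/below.
    _losses = {}
    _loss_alphas = {}
    _is_active = False

    @classmethod
    def clear(cls):
        cls._losses = {}
        cls._loss_alphas = {}

    @classmethod
    def register_loss(cls, name, loss, alpha=1.0):
        cls._losses[name] = loss
        cls._loss_alphas[name] = alpha

    @classmethod
    def get_loss(cls, name):
        return cls._losses[name]

    @classmethod
    def reduce(cls):
        # Weighted sum of all registered losses.
        return sum(cls._losses[k] * cls._loss_alphas[k] for k in cls._losses)

    @classmethod
    def is_active(cls):
        return cls._is_active

    @classmethod
    def activate(cls):
        cls._is_active = True

    @classmethod
    def deactivate(cls):
        cls._is_active = False
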
Code example #3
    def forward(self, s, obs):
        if "pointgoal_with_gps" not in obs:
            return self.prior(s).sample()
        
        if self.use_odometry:
            privileged_info = obs["pointgoal_with_gps_compass"]
        else:
            privileged_info = obs["pointgoal_with_gps"]

        mu, sigma = torch.chunk(
            self.encoder(
                torch.cat(
                    [
                        s,
                        self.priv_embed(
                            _to_mag_and_unit_vec(privileged_info)
                        ),
                    ],
                    -1,
                )
            ),
            2,
            s.dim() - 1,
        )

        if not self.use_info_bot:
            mu.fill_(0.0)
            sigma.fill_(1.0)

        sigma = F.softplus(sigma)
        dist = torch.distributions.Normal(mu, sigma)

        x = dist.rsample()

        # The following block is for running with Selective Noise Injection
        # (sample noise only during training, use the mean at evaluation):
        # if self.training:
        #     x = dist.rsample()
        # else:
        #     x = dist.mean

        if AuxLosses.is_active():
            AuxLosses.register_loss(
                "information",
                _subsampled_mean(
                    torch.distributions.kl_divergence(dist, self.prior(s))
                ),
                # torch.distributions.kl_divergence(dist, self.prior(s)).mean(),
                self.beta,
            )

        return x
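
`self.prior(s)` above must return a distribution that supports both `.sample()` and `kl_divergence` against a diagonal Normal. A plausible, state-independent choice is a standard Normal over the bottleneck variable; the factory below is a hypothetical helper, not the codebase's actual prior.

import torch


def make_standard_normal_prior(x_dim):
    # Hypothetical factory: N(0, I) over the bottleneck variable x,
    # broadcast to the batch shape of the state s.
    def prior(s):
        mu = torch.zeros(*s.shape[:-1], x_dim, device=s.device)
        return torch.distributions.Normal(mu, torch.ones_like(mu))

    return prior
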
Code example #4
    def evaluate_actions(
        self, observations, prev_observations, rnn_hidden_states, prev_actions, masks, action
    ):
        features, _ = self.net(observations, prev_observations, rnn_hidden_states, prev_actions, masks)
        value = self.critic(features)

        if self.supervise_stop:
            stop_distribution = self.stop_action_distribution(features)
            non_stop_distribution = self.non_stop_action_distribution(features)

            # Factorized log-prob: log p(STOP) when action == 0, otherwise
            # log p(not STOP) + log p(action - 1 | not STOP).
            action_log_probs = torch.where(
                action == 0,
                stop_distribution.log_probs(torch.full_like(action, 1)),
                stop_distribution.log_probs(torch.full_like(action, 0))
                + non_stop_distribution.log_probs(
                    torch.max(action - 1, torch.zeros_like(action))
                ),
            )

            distribution_entropy = (
                -1.0
                * (
                    stop_distribution.probs[:, -1] * stop_distribution.logits[:, -1]
                    + (
                        stop_distribution.probs[:, 0:1]
                        * non_stop_distribution.probs
                        * (
                            stop_distribution.logits[:, 0:1]
                            + non_stop_distribution.logits
                        )
                    ).sum(-1)
                ).mean()
            )

            stop_loss = F.cross_entropy(
                stop_distribution.logits,
                observations["stop_oracle"].long().squeeze(-1),
                weight=torch.tensor(
                    [1.0, 1.0 / np.sqrt(100.0)], device=features.device
                ),
            )

            AuxLosses.register_loss("stop_loss", stop_loss)
        else:
            action_distribution = self.action_distribution(features)

            action_log_probs = action_distribution.log_probs(action)
            distribution_entropy = action_distribution.entropy().mean()

        return value, action_log_probs, distribution_entropy
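
The stop-supervised branch factors the policy into a binary stop head and a head over the remaining actions. Below is a toy, self-contained version of the same log-probability bookkeeping; the names are illustrative only, and the real code uses custom distribution wrappers with a `log_probs` method rather than plain `Categorical`.

import torch
from torch.distributions import Categorical

stop = Categorical(logits=torch.randn(4, 2))      # class 1 = STOP
non_stop = Categorical(logits=torch.randn(4, 3))  # forward / left / right
action = torch.tensor([0, 1, 2, 3])               # 0 = STOP in the full space

# log p(a) = log p(STOP)                          if a == 0
#          = log p(not STOP) + log p(a - 1)       otherwise
log_probs = torch.where(
    action == 0,
    stop.log_prob(torch.ones_like(action)),
    stop.log_prob(torch.zeros_like(action))
    + non_stop.log_prob((action - 1).clamp(min=0)),
)
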
Code example #5
    def forward(self, s, obs):
        priv_emb = super().forward(s, obs)

        gps = self.gps_head(s)
        compass = self.compass_head(s)
        pg = _update_pg_gps(obs["pointgoal"], gps)

        embed_pg = self.predicted_embed(_to_mag_and_unit_vec(pg.detach()))

        if AuxLosses.is_active():
            AuxLosses.register_loss(
                "egomotion_error",
                torch.norm(pg - obs["pointgoal_with_gps"], dim=-1).mean(),
                0.0,
            )
            AuxLosses.register_loss(
                "compass_loss",
                _subsampled_mean(_angular_distance_loss(compass, obs["compass"])),
            )
            AuxLosses.register_loss(
                "gps_loss",
                _subsampled_mean(
                    F.mse_loss(gps, obs["gps"], reduction="none").mean(-1)
                ),
            )

        return self.combine_layer(torch.cat([priv_emb, embed_pg], dim=-1))
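
`_angular_distance_loss` is not shown. Below is a minimal sketch under the assumption that it penalizes the wrapped angular difference per element (so headings near ±π are treated as close) and reduces over the trailing angle dimension, which matches how it is masked and averaged in these snippets.

import torch


def _angular_distance_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Wrap the difference to (-pi, pi] before squaring, then reduce the
    # trailing (angle) dimension so the output matches the mask shape.
    diff = torch.atan2(torch.sin(pred - target), torch.cos(pred - target))
    return diff.pow(2).mean(dim=-1)
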
Code example #6
    def update(self, rollouts):
        advantages = self.get_advantages(rollouts)

        value_loss_epoch = 0
        action_loss_epoch = 0
        aux_losses_epoch = 0
        dist_entropy_epoch = 0

        AuxLosses.activate()
        for e in range(self.ppo_epoch):
            data_generator = rollouts.recurrent_generator(
                advantages, self.num_mini_batch)

            for sample in data_generator:
                (
                    obs_batch,
                    prev_obs_batch,
                    recurrent_hidden_states_batch,
                    actions_batch,
                    prev_actions_batch,
                    value_preds_batch,
                    return_batch,
                    masks_batch,
                    old_action_log_probs_batch,
                    adv_targ,
                ) = sample

                AuxLosses.clear()
                # Reshape to do in a single forward pass for all steps
                (
                    values,
                    action_log_probs,
                    dist_entropy,
                ) = self.actor_critic.evaluate_actions(
                    obs_batch,
                    prev_obs_batch,
                    recurrent_hidden_states_batch,
                    prev_actions_batch,
                    masks_batch,
                    actions_batch,
                )

                ratio = torch.exp(action_log_probs -
                                  old_action_log_probs_batch)
                surr1 = ratio * adv_targ
                surr2 = (torch.clamp(ratio, 1.0 - self.clip_param,
                                     1.0 + self.clip_param) * adv_targ)
                action_loss = -torch.min(surr1, surr2).mean()

                if self.use_clipped_value_loss:
                    value_pred_clipped = value_preds_batch + (
                        values - value_preds_batch).clamp(
                            -self.clip_param, self.clip_param)
                    value_losses = (values - return_batch).pow(2)
                    value_losses_clipped = (value_pred_clipped -
                                            return_batch).pow(2)
                    value_loss = (
                        0.5 *
                        torch.max(value_losses, value_losses_clipped).mean())
                else:
                    value_loss = 0.5 * (return_batch - values).pow(2).mean()

                self.optimizer.zero_grad()
                total_loss = (value_loss * self.value_loss_coef + action_loss -
                              dist_entropy * self.entropy_coef)

                if self.use_aux_losses:
                    aux_losses = AuxLosses.reduce()
                else:
                    aux_losses = (
                        AuxLosses.get_loss("information")
                        * AuxLosses._loss_alphas["information"]
                    )

                aux_losses_epoch += aux_losses.item()

                total_loss = total_loss + aux_losses

                self.before_backward(total_loss)
                total_loss.backward()
                self.after_backward(total_loss)

                self.before_step()
                self.optimizer.step()
                self.after_step()

                value_loss_epoch += value_loss.item()
                action_loss_epoch += action_loss.item()
                dist_entropy_epoch += dist_entropy.item()

        num_updates = self.ppo_epoch * self.num_mini_batch

        value_loss_epoch /= num_updates
        action_loss_epoch /= num_updates
        dist_entropy_epoch /= num_updates
        aux_losses_epoch /= num_updates

        AuxLosses.deactivate()

        return (
            value_loss_epoch,
            action_loss_epoch,
            dist_entropy_epoch,
            aux_losses_epoch,
        )
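
`get_advantages` is the usual PPO advantage computation. The sketch below shows what it is assumed to do here (returns minus value predictions, normalized per batch), assuming the rollout storage exposes `returns` and `value_preds` with one extra bootstrap step as in the standard habitat-baselines `RolloutStorage`; the actual method may differ in details.

    def get_advantages(self, rollouts):
        # Advantage = bootstrapped return - value prediction, then normalize
        # across the whole rollout batch for a better-conditioned surrogate.
        advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
        return (advantages - advantages.mean()) / (advantages.std() + 1e-5)
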
Code example #7
    def forward(self, observations, prev_observations, rnn_hidden_states, prev_actions, masks):
        if AuxLosses.is_active():
            AuxLosses.obs = observations

        depth_flag = "depth" in observations
        rgb_flag = "rgb" in observations

        if "visual_features" in observations:
            visual_features = observations["visual_features"]
            prev_visual_features = observations["prev_visual_features"]
        
        elif masks.size(0) != rnn_hidden_states.size(1):
            # Update-time path: the batch is a (T * N) flattened rollout.
            # Prepend one previous frame per environment so the encoder sees
            # T + 1 frames and consecutive frames can be differenced below.
            obs_input = {}
            N = rnn_hidden_states.size(1)
            T = masks.size(0) // N

            if depth_flag:
                prev_obs = prev_observations["depth"].view(T, N, *prev_observations["depth"].size()[1:])
                obs = observations["depth"].view(T, N, *observations["depth"].size()[1:])
                obs_input["depth"] = torch.cat((prev_obs[0:1], obs), dim=0)
                obs_input["depth"] = obs_input["depth"].view((T + 1) * N, *obs_input["depth"].size()[2:])

            if rgb_flag:
                prev_obs = prev_observations["rgb"].view(T, N, *prev_observations["rgb"].size()[1:])
                obs = observations["rgb"].view(T, N, *observations["rgb"].size()[1:])
                obs_input["rgb"] = torch.cat((prev_obs[0:1], obs), dim=0)
                obs_input["rgb"] = obs_input["rgb"].view((T + 1) * N, *obs_input["rgb"].size()[2:])

            obs_features = self.visual_encoder(obs_input)
            prev_visual_features = obs_features[:T*N, :, :, :]
            visual_features = obs_features[-T*N:, :, :, :]
                
        else:
            obs_input = {}

            if depth_flag:
                obs_input["depth"] = torch.cat((prev_observations["depth"], observations["depth"]), dim=0)
            if rgb_flag:
                obs_input["rgb"] = torch.cat((prev_observations["rgb"], observations["rgb"]), dim=0)

            obs_features = self.visual_encoder(obs_input)
            prev_visual_features, visual_features = obs_features.split(obs_features.size()[0] // 2, dim=0)

        visual_features = self.compression(visual_features)

        visual_emb = self.visual_fc(visual_features)

        # difference of frames (unit 1)
        flow_emb = self.visual_flow_encoder(
            (visual_features - self.compression(prev_visual_features))
            * masks.view(-1, 1, 1, 1)
        )

        prev_actions = self.prev_action_embedding(
            ((prev_actions.float() + 1) * masks).long().squeeze(-1)
        )

        context_emb = prev_actions + self._tgt_proj(observations["pointgoal"])
        x, rnn_hidden_states = self.state_encoder(
            torch.cat([visual_emb, flow_emb], dim=-1) + context_emb,
            rnn_hidden_states,
            masks,
        )

        tgt_encoding = self.get_tgt_encoding(observations, x)

        x = torch.cat([x, tgt_encoding], dim=-1)
        x = self.goal_mem_layer(x)

        if AuxLosses.is_active():
            n = rnn_hidden_states.size(1)
            t = int(x.size(0) / n)

            delta_ego = self.delta_egomotion_predictor(flow_emb).view(t, n, 3)
            gps_gt = observations["gps"].view(t, n, 2)
            compass_gt = observations["compass"].view(t, n, 1)
            masks = masks.view(t, n, 1)

            gt_delta = gps_gt[1:] - gps_gt[:-1]
            gt_delta = _update_pg_gps_compass(
                gt_delta.view((t - 1) * n, 2),
                torch.zeros_like(gt_delta).view((t - 1) * n, 2),
                compass_gt[:-1].view((t - 1) * n, 1),
            ).view(t - 1, n, 2)
            AuxLosses.register_loss(
                "delta_gps",
                _subsampled_mean(
                    torch.masked_select(
                        F.mse_loss(
                            delta_ego[1:, :, 0:2], gt_delta, reduction="none"
                        ).mean(dim=-1),
                        masks[1:, :, 0].bool(),
                    )
                ),
            )

            AuxLosses.register_loss(
                "delta_compass",
                _subsampled_mean(
                    torch.masked_select(
                        _angular_distance_loss(
                            delta_ego[1:, :, 2:], compass_gt[1:] - compass_gt[:-1]
                        ),
                        masks[1:, :, 0].bool(),
                    )
                ),
            )

        return x, rnn_hidden_states
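
Several auxiliary terms go through `_subsampled_mean` rather than a plain `.mean()`. The sketch below assumes it averages a random subset of the per-element losses to keep the auxiliary supervision cheap; the subsampling fraction here is a guess, not the codebase's value.

import torch


def _subsampled_mean(x: torch.Tensor, p: float = 0.1) -> torch.Tensor:
    # Average over a random fraction p of the per-element loss values.
    x = x.reshape(-1)
    num = max(1, int(p * x.numel()))
    idx = torch.randperm(x.numel(), device=x.device)[:num]
    return x[idx].mean()
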
Code example #8
    def train(self) -> None:
        r"""Main method for DD-PPO.

        Returns:
            None
        """
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend)
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                      tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        self.config.TASK_CONFIG.TASK.POINTGOAL_WITH_EGO_PREDICTION_SENSOR.MODEL.GPU_ID = self.local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (self.world_rank *
                                         self.config.NUM_PROCESSES)
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.envs = construct_envs(
            self.config,
            get_env_class(self.config.ENV_NAME),
            workers_ignore_signals=True,
        )

        ppo_cfg = self.config.RL.PPO
        if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
                and self.world_rank == 0):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
        self.agent.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info("agent number of trainable parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters()
                    if param.requires_grad)))

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        obs_space = self.envs.observation_spaces[0]
        if self._static_encoder:
            self._encoder = self.actor_critic.net.visual_encoder
            obs_space = SpaceDict({
                "visual_features":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                "prev_visual_features":
                spaces.Box(
                    low=np.finfo(np.float32).min,
                    high=np.finfo(np.float32).max,
                    shape=self._encoder.output_shape,
                    dtype=np.float32,
                ),
                **obs_space.spaces,
            })
            with torch.no_grad():
                batch["visual_features"] = self._encoder(batch)
                batch["prev_visual_features"] = torch.zeros_like(
                    batch["visual_features"])

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            obs_space,
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
        )
        rollouts.to(self.device)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])
            rollouts.previous_observations[sensor][0].copy_(
                torch.zeros_like(batch[sensor]))

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1, device=self.device),
            reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"])
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if self.world_rank == 0 else contextlib.suppress()) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                self.actor_critic.update_ib_beta(count_steps)

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ))

                    requeue_job()
                    return

                count_steps_delta = 0
                self.agent.eval()

                for step in range(ppo_cfg.num_steps):

                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (step >=
                            ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                        break

                num_rollouts_done_store.add("num_done", 1)

                self.agent.train()
                if self._static_encoder:
                    self._encoder.eval()

                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                    aux_loss,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0)
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                stats = torch.tensor(
                    [
                        value_loss,
                        action_loss,
                        aux_loss,
                        AuxLosses.get_loss("egomotion_error"),
                        AuxLosses.get_loss("information"),
                        0.1 * dist_entropy,
                        count_steps_delta,
                    ],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                count_steps += stats[-1].item()

                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    losses = [
                        stats[i].item() / self.world_size
                        for i in range(stats.size(0) - 1)
                    ]
                    deltas = {
                        k: ((v[-1] - v[0]).sum().item()
                            if len(v) > 1 else v[0].sum().item())
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    print(deltas["reward"])

                    writer.add_scalar(
                        "reward",
                        deltas["reward"] / deltas["count"],
                        count_steps,
                    )

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        writer.add_scalars("metrics", metrics, count_steps)

                    writer.add_scalars(
                        "losses",
                        {
                            k: l
                            for l, k in zip(losses, [
                                "value", "policy", "aux", "egomotion_error",
                                "information", "entropy"
                            ])
                        },
                        count_steps,
                    )

                    # log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info("update: {}\tfps: {:.3f}\t".format(
                            update,
                            count_steps /
                            ((time.time() - t_start) + prev_time),
                        ))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time,
                                                count_steps))
                        logger.info("Average window size: {}  {}".format(
                            len(window_episode_stats["count"]),
                            "  ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items() if k != "count"),
                        ))

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(step=count_steps),
                        )

                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )

                        save_interrupted_state(
                            dict(
                                state_dict=self.agent.state_dict(),
                                optim_state=self.agent.optimizer.state_dict(),
                                lr_sched_state=lr_scheduler.state_dict(),
                                config=self.config,
                                requeue_stats=requeue_stats,
                            ))

                        count_checkpoints += 1

            self.envs.close()
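
The `LambdaLR` schedule above uses `linear_decay`; in habitat-baselines this is a linear ramp of the learning-rate multiplier from 1 to 0 over the configured number of updates, roughly as sketched here.

def linear_decay(epoch: int, total_num_updates: int) -> float:
    # Multiplier applied to the base learning rate: decays linearly
    # from 1.0 at update 0 to 0.0 at the final update.
    return 1.0 - (epoch / float(total_num_updates))
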
Code example #9
    def forward(self, observations, rnn_hidden_states, prev_actions, masks):
        if AuxLosses.is_active():
            AuxLosses.obs = observations

        if "visual_features" in observations:
            visual_features = observations["visual_features"]
        else:
            visual_features = self.visual_encoder(observations)

        visual_features = self.compression(visual_features)

        visual_emb = self.visual_fc(visual_features)
        flow_emb = self.visual_flow_encoder(
            (visual_features - self.compression(observations["prev_visual_features"]))
            * masks.view(-1, 1, 1, 1)
        )

        prev_actions = self.prev_action_embedding(
            ((prev_actions.float() + 1) * masks).long().squeeze(-1)
        )

        context_emb = prev_actions + self._tgt_proj(observations["pointgoal"])
        x, rnn_hidden_states = self.state_encoder(
            torch.cat([visual_emb, flow_emb], dim=-1) + context_emb,
            rnn_hidden_states,
            masks,
        )

        tgt_encoding = self.get_tgt_encoding(observations, x)

        x = torch.cat([x, tgt_encoding], dim=-1)
        x = self.goal_mem_layer(x)

        if AuxLosses.is_active():
            n = rnn_hidden_states.size(1)
            t = int(x.size(0) / n)

            delta_ego = self.delta_egomotion_predictor(flow_emb).view(t, n, 3)
            gps_gt = observations["gps"].view(t, n, 2)
            compass_gt = observations["compass"].view(t, n, 1)
            masks = masks.view(t, n, 1)

            gt_delta = gps_gt[1:] - gps_gt[:-1]
            gt_delta = _update_pg_gps_compass(
                gt_delta.view((t - 1) * n, 2),
                torch.zeros_like(gt_delta).view((t - 1) * n, 2),
                compass_gt[:-1].view((t - 1) * n, 1),
            ).view(t - 1, n, 2)
            AuxLosses.register_loss(
                "delta_gps",
                _subsampled_mean(
                    torch.masked_select(
                        F.mse_loss(
                            delta_ego[1:, :, 0:2], gt_delta, reduction="none"
                        ).mean(dim=-1),
                        masks[1:, :, 0].bool(),
                    )
                ),
            )

            AuxLosses.register_loss(
                "delta_compass",
                _subsampled_mean(
                    torch.masked_select(
                        _angular_distance_loss(
                            delta_ego[1:, :, 2:], compass_gt[1:] - compass_gt[:-1]
                        ),
                        masks[1:, :, 0].bool(),
                    )
                ),
            )

        return x, rnn_hidden_states
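
The ground-truth delta above is rotated into the agent's frame with `_update_pg_gps_compass`. The sketch below assumes it subtracts the (predicted or zero) position estimate and then rotates the remaining offset by the heading; the sign and axis conventions of the real codebase may differ.

import torch


def _update_pg_gps_compass(pointgoal, gps, compass):
    # pointgoal, gps: (..., 2); compass: (..., 1) heading in radians.
    delta = pointgoal - gps
    cos_t = torch.cos(-compass)
    sin_t = torch.sin(-compass)
    x = cos_t * delta[..., 0:1] - sin_t * delta[..., 1:2]
    y = sin_t * delta[..., 0:1] + cos_t * delta[..., 1:2]
    return torch.cat([x, y], dim=-1)
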
Code example #10
    def seq_forward(self, x, context, mems, masks):
        r"""Forward for a sequence of length T
        Args:
            x: (T, N, -1) Tensor that has been flattened to (T * N, -1)
            hidden_states: The starting hidden state.
            masks: The masks to be applied to hidden state at every timestep.
                A (T, N) tensor flatten to (T * N)
        """
        # x is a (T, N, -1) tensor flattened to (T * N, -1)
        n = mems.size(1)
        t = int(x.size(0) / n)

        # unflatten
        x = x.view(t, n, x.size(1))
        context = context.view(t, n, context.size(1))
        masks = masks.view(t, n)

        if self.self_sup:
            ep_lens = []
            for i in range(n):
                last_zero = 0
                has_zeros = ((
                    masks[1:-1,
                          i] == 0.0).nonzero().squeeze(-1).cpu().unbind(0))
                for z in has_zeros:
                    z = z.item() + 1
                    ep_lens.append(z - last_zero)
                    last_zero = z

                ep_lens.append(t - last_zero)

            k = random.randint(
                1,
                max(
                    min(self.max_self_sup_K,
                        int(0.8 * np.mean(np.array(ep_lens)))), 2),
            )
        else:
            k = None

        content, mems, query = self.transformer.transformer_seq_forward(
            x, context, mems, masks, two_stream_k=k)

        if self.self_sup:
            positives = x
            negatives = []
            for _ in range(3):
                negative_inds = torch.randperm(t * n, device=x.device)
                negatives.append(
                    torch.gather(
                        x.view(t * n, -1),
                        dim=0,
                        index=negative_inds.view(t * n,
                                                 1).expand(t * n, x.size(-1)),
                    ).view(t, n, -1))

            negatives = torch.stack(negatives, dim=-1)

            positives = torch.einsum("...i, ...i -> ...", positives, query)
            negatives = torch.einsum("...ik, ...i -> ...k", negatives, query)
            # One positive and several negative scores per query; the positive
            # always sits in column 0, so the cross-entropy target is class 0.
            cpc_logits = torch.cat([positives.unsqueeze(-1), negatives],
                                   dim=-1)

            # A query is only valid if it has at least k steps of history
            # within its current episode.
            valid_modeling_queries = torch.ones(t,
                                                n,
                                                device=query.device,
                                                dtype=torch.bool)
            valid_modeling_queries[0:k] = 0
            for i in range(n):
                has_zeros_batch = ((
                    masks[:, i] == 0.0).nonzero().squeeze(-1).cpu().unbind(0))
                for z in has_zeros_batch:
                    z = z.item()
                    valid_modeling_queries[z:z + k, i] = 0

            cpc_loss = torch.masked_select(
                F.cross_entropy(
                    cpc_logits.view(t * n, -1),
                    torch.zeros(t * n,
                                dtype=torch.long,
                                device=cpc_logits.device),
                    reduction="none",
                ).view(t, n),
                valid_modeling_queries,
            ).mean()

            AuxLosses.register_loss("CPC|A", cpc_loss, 0.2)

        content = content.view(t * n, -1)

        return content, mems
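
For reference, the CPC|A term above is an InfoNCE-style classification: each query has to pick its positive score out of one positive plus a few shuffled negatives. A standalone toy with the same layout (column 0 holds the positive, so the target class is always 0); the shapes here are arbitrary.

import torch
import torch.nn.functional as F

scores_pos = torch.randn(8)       # one positive score per query
scores_neg = torch.randn(8, 3)    # three negative scores per query
logits = torch.cat([scores_pos.unsqueeze(-1), scores_neg], dim=-1)  # (8, 4)
loss = F.cross_entropy(logits, torch.zeros(8, dtype=torch.long))
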