Example #1
    def forward(self, v_dist_r, map_structure_r, eval=False, summarize=False):

        ACPROF = False
        prof = SimpleProfiler(print=ACPROF, torch_sync=ACPROF)

        if isinstance(v_dist_r, list):
            inners = torch.cat([m.inner_distribution for m in v_dist_r], dim=0)
            outers = torch.cat([m.outer_prob_mass for m in v_dist_r], dim=0)
            v_dist_r = Partial2DDistribution(inners, outers)
            map_structure_r = torch.stack(map_structure_r, dim=0)

        if self.ignore_struct:
            map_structure_r = torch.zeros_like(map_structure_r)

        prof.tick("ac:inputs")
        avec = self.action_base(v_dist_r, map_structure_r)
        prof.tick("ac:networks - action_base")
        vvec = self.value_base(v_dist_r, map_structure_r)
        prof.tick("ac:networks - value_base")
        action_scores = self.action_head(avec)
        prof.tick("ac:networks - action_head")
        value_pred = self.value_head(vvec)
        prof.tick("ac:networks - value head")

        xvel_mean = action_scores[:, 0]
        xvel_std = F.softplus(action_scores[:, 2])
        xvel_dist = FixedNormal(xvel_mean, xvel_std)

        yawrate_mean = action_scores[:, 3]
        yawrate_std = F.softplus(action_scores[:, 5])
        yawrate_dist = FixedNormal(yawrate_mean, yawrate_std)

        # Skew it towards not stopping in the beginning
        stop_logits = action_scores[:, 6]
        stop_dist = FixedBernoulli(logits=stop_logits)

        prof.tick("ac:distributions")
        prof.loop()
        prof.print_stats(1)

        return xvel_dist, yawrate_dist, stop_dist, value_pred
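A minimal sketch of how the three heads above are typically consumed downstream, using torch.distributions in place of the repository's FixedNormal / FixedBernoulli wrappers (an assumption on my part; the wrappers may add extra bookkeeping such as log-prob reshaping):

import torch
import torch.nn.functional as F

# Fake action_head output for a batch of 4; columns 0/2 are the forward-velocity
# mean/raw-std, 3/5 the yaw-rate mean/raw-std, 6 the stop logit, as indexed above.
action_scores = torch.randn(4, 7)

xvel = torch.distributions.Normal(action_scores[:, 0],
                                  F.softplus(action_scores[:, 2])).sample()
yawrate = torch.distributions.Normal(action_scores[:, 3],
                                     F.softplus(action_scores[:, 5])).sample()
stop = torch.distributions.Bernoulli(logits=action_scores[:, 6]).sample()

actions = torch.stack([xvel, yawrate, stop], dim=1)   # B x 3 sampled actions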
Example #2
class LeakyIntegratorMap(MapTransformerBase):
    def __init__(self,
                 source_map_size,
                 world_size_px,
                 world_size_m,
                 lamda=0.2):
        super(LeakyIntegratorMap, self).__init__(source_map_size,
                                                 world_size_px, world_size_m)
        self.map_size = source_map_size
        self.world_size_px = world_size_px
        self.world_size_m = world_size_m
        self.child_transformer = MapTransformerBase(source_map_size,
                                                    world_size_px,
                                                    world_size_m)
        self.lamda = lamda

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.map_memory = MapTransformerBase(source_map_size, world_size_px,
                                             world_size_m)

        self.dbg_t = None
        self.seq = 0

    def init_weights(self):
        pass

    def reset(self):
        super(LeakyIntegratorMap, self).reset()
        self.map_memory.reset()
        self.child_transformer.reset()
        self.seq = 0

    def cuda(self, device=None):
        MapTransformerBase.cuda(self, device)
        self.child_transformer.cuda(device)
        self.map_memory.cuda(device)
        return self

    def dbg_write_extra(self, map, pose):
        if DebugWriter().should_write():
            map = map[0:1, 0:3]
            self.seq += 1
            # Initialize a transformer module
            if pose is not None:
                if self.dbg_t is None:
                    self.dbg_t = MapTransformerBase(
                        self.map_size, self.world_size_px,
                        self.world_size_m).to(map.device)

                # Transform the prediction to the global frame and write out to disk.
                self.dbg_t.set_map(map, pose)
                map_global, _ = self.dbg_t.get_map(None)
            else:
                map_global = map
            DebugWriter().write_img(map_global[0],
                                    "gif_overlaid",
                                    args={
                                        "world_size": self.world_size_px,
                                        "name": "sm"
                                    })

    def forward(self, images, coverages, cam_poses, add_mask=None, show=False):
        #show="li"
        self.prof.tick(".")
        batch_size = len(images)

        assert add_mask is None or add_mask[0] is not None, \
            "The first observation in a sequence needs to be used!"

        # Step 1: All local maps to global: # TODO: Allow inputting global maps when new projector is ready
        self.child_transformer.set_maps(images, cam_poses)
        observations_g, _ = self.child_transformer.get_maps(None)

        self.child_transformer.set_maps(coverages, cam_poses)
        coverages_g, _ = self.child_transformer.get_maps(None)

        masked_observations_g_add = self.lamda * observations_g * coverages_g

        all_maps_out_g = []

        self.prof.tick("maps_to_global")

        # TODO: Draw past trajectory on an extra channel of the semantic map

        # Step 2: Integrate serially in the global frame
        for i in range(batch_size):

            # If we don't have a map yet, initialize the map to this observation
            if self.map_memory.latest_maps is None:
                self.map_memory.set_map(observations_g[i:i + 1], None)
                #self.set_map(observations_g[i:i+1], None)

            # Allow masking of observations
            if add_mask is None or add_mask[i]:
                # Get the current global-frame map
                map_g, _ = self.map_memory.get_map(None)

                #obs_g = observations_g[i:i+1]
                cov_g = coverages_g[i:i + 1]
                obs_cov_g = masked_observations_g_add[i:i + 1]

                # Add the observation into the map using a leaky integrator rule (TODO: Output lamda from model)
                new_map_g = ((1 - self.lamda) * map_g + obs_cov_g +
                             self.lamda * map_g * (1 - cov_g))

                # Remember this new map
                self.map_memory.set_map(new_map_g, None)
                #self.set_map(new_map_g, None)

            map_g, _ = self.map_memory.get_map(None)

            # Return this map in the camera frame of reference
            #map_r, _ = self.get_map(cam_poses[i:i+1])

            if show != "":
                Presenter().show_image(map_g.data[0, 0:3],
                                       show,
                                       torch=True,
                                       scale=8,
                                       waitkey=50)

            all_maps_out_g.append(map_g)

        self.prof.tick("integrate")

        # Step 3: Convert all maps to local frame
        all_maps_g = torch.cat(all_maps_out_g, dim=0)

        # Write gifs for debugging
        self.dbg_write_extra(all_maps_g, None)

        self.child_transformer.set_maps(all_maps_g, None)
        maps_r, _ = self.child_transformer.get_maps(cam_poses)
        self.set_maps(maps_r, cam_poses)

        self.prof.tick("maps_to_local")
        self.prof.loop()
        self.prof.print_stats(10)

        return maps_r, cam_poses
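The update inside the loop is easier to see in isolation. Below is a standalone sketch of the leaky-integrator rule (not repository code): in fully covered cells the old map is blended with the new observation, and in uncovered cells the old value is kept unchanged.

import torch

lamda = 0.2
old_map = torch.full((1, 1, 4, 4), 0.5)   # previous global map
obs = torch.ones(1, 1, 4, 4)              # new observation
cov = torch.zeros(1, 1, 4, 4)
cov[..., :2, :] = 1.0                     # only the top half was observed

# Same rule as above: new = (1 - lamda) * old + lamda * obs * cov + lamda * old * (1 - cov)
obs_cov = lamda * obs * cov
new_map = (1 - lamda) * old_map + obs_cov + lamda * old_map * (1 - cov)
print(new_map[0, 0])                      # 0.6 in the covered rows, 0.5 elsewhere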
Example #3
class FPVToGlobalMap(MapTransformerBase):
    def __init__(self,
                 source_map_size,
                 world_size_px,
                 world_size_m,
                 img_w,
                 img_h,
                 res_channels,
                 map_channels,
                 cam_h_fov,
                 domain,
                 img_dbg=False):

        super(FPVToGlobalMap, self).__init__(source_map_size, world_size_px,
                                             world_size_m)

        self.image_debug = img_dbg

        self.use_lang_filter = False

        # Process images using a resnet to get a feature map
        if self.image_debug:
            self.img_to_features = nn.MaxPool2d(8)
        else:
            # Provide enough padding so that the map is scaled down by powers of 2.
            self.img_to_features = ImgToFeatures(res_channels, map_channels,
                                                 img_w, img_h)

        # Project feature maps to the global frame
        self.map_projection = PinholeCameraProjectionModuleGlobal(
            source_map_size, world_size_px, world_size_m, img_w, img_h,
            cam_h_fov, domain)

        self.grid_sampler = GridSampler()

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)

        self.actual_images = None

    def cuda(self, device=None):
        MapTransformerBase.cuda(self, device)
        self.map_projection.cuda(device)
        self.grid_sampler.cuda(device)
        self.img_to_features.cuda(device)
        if self.use_lang_filter:
            self.lang_filter.cuda(device)

    def init_weights(self):
        if not self.image_debug:
            self.img_to_features.init_weights()

    def reset(self):
        self.actual_images = None
        super(FPVToGlobalMap, self).reset()

    def forward_fpv_features(self, images, sentence_embeds, tensor_store=None):
        """
        Compute the first-person image features given the first-person images
        If grounding loss is enabled, will also return sentence_embedding conditioned image features
        :param images: images to compute features on
        :param sentence_embeds: sentence embeddings for each image
        :param tensor_store: optional store in which intermediate feature tensors are kept
        :return: features_fpv_vis - the visual features extracted using the ResNet
                 features_fpv_gnd - the grounded visual features obtained after applying a 1x1 language-conditioned conv
        """
        # Extract image features. If they've been precomputed ahead of time, just grab it by the provided index
        features_fpv_vis = self.img_to_features(images)

        if tensor_store is not None:
            tensor_store.keep_inputs("fpv_features", features_fpv_vis)
        #self.prof.tick("feat")

        # If required, pre-process image features by grounding them in language
        if self.use_lang_filter:
            self.lang_filter.precompute_conv_weights(sentence_embeds)
            features_gnd = self.lang_filter(features_fpv_vis)
            if tensor_store is not None:
                tensor_store.keep_inputs("fpv_features_g", features_gnd)
            #self.prof.tick("gnd")
            return features_fpv_vis, features_gnd

        return features_fpv_vis, None

    def forward(self,
                images,
                poses,
                sentence_embeds,
                tensor_store=None,
                show="",
                halfway=False):

        self.prof.tick("out")

        # self.map_projection is implemented in numpy on CPU.
        # If we give it poses on the GPU, it will transfer them to the CPU, which causes a CUDA SYNC and waits for the
        # ResNet forward pass to complete. To make use of full GPU/CPU concurrency, we move the poses to the cpu first
        poses_cpu = poses.cpu()

        features_fpv_vis_only, features_fpv_gnd_only = self.forward_fpv_features(
            images, sentence_embeds, tensor_store)

        # Halfway HAS to be True and not only truthy
        if halfway == True:
            return None, None

        # If we have grounding features, the overall features are a concatenation of grounded and non-grounded features
        if features_fpv_gnd_only is not None:
            features_fpv_all = torch.cat(
                [features_fpv_gnd_only, features_fpv_vis_only], dim=1)
        else:
            features_fpv_all = features_fpv_vis_only

        # Project first-person view features on to the map in egocentric frame
        grid_maps_cpu = self.map_projection(poses_cpu)
        grid_maps = grid_maps_cpu.to(features_fpv_all.device)

        self.prof.tick("proj_map_and_features")
        features_r = self.grid_sampler(features_fpv_all, grid_maps)

        if DEBUG_WITH_IMG:
            img_w = self.grid_sampler(images, grid_maps)
            if tensor_store is not None:
                tensor_store.keep_inputs("images_w", img_w)
            #Presenter().show_image(images.data[0], "fpv_raw", torch=True, scale=2, waitkey=1)
            #Presenter().show_image(img_w.data[0], "fpv_projected", torch=True, scale=2, waitkey=1)

        # Obtain an ego-centric map mask of where we have new information
        ones_size = list(features_fpv_all.size())
        ones_size[1] = 1
        tmp_ones = torch.ones(ones_size).to(features_r.device)
        new_coverages = self.grid_sampler(tmp_ones, grid_maps)

        # Rescale new_coverages to the [0, 1] range (grid_sampler applies bilinear interpolation)
        new_coverages = new_coverages - torch.min(new_coverages)
        new_coverages = new_coverages / (torch.max(new_coverages) + 1e-18)

        self.prof.tick("gsample")

        if show != "":
            Presenter().show_image(images.data[0, 0:3],
                                   show + "fpv_img",
                                   torch=True,
                                   scale=2,
                                   waitkey=1)

            grid_maps_np = grid_maps.data[0].numpy()

            Presenter().show_image(grid_maps_np,
                                   show + "_grid",
                                   torch=False,
                                   scale=4,
                                   waitkey=1)
            Presenter().show_image(features_fpv_all.data[0, 0:3],
                                   show + "_preproj",
                                   torch=True,
                                   scale=8,
                                   waitkey=1)
            Presenter().show_image(images.data[0, 0:3],
                                   show + "_img",
                                   torch=True,
                                   scale=1,
                                   waitkey=1)
            Presenter().show_image(features_r.data[0, 0:3],
                                   show + "_projected",
                                   torch=True,
                                   scale=6,
                                   waitkey=1)
            Presenter().show_image(new_coverages.data[0],
                                   show + "_covg",
                                   torch=True,
                                   scale=6,
                                   waitkey=1)

        self.prof.loop()
        self.prof.print_stats(10)

        return features_r, new_coverages
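The projection step hinges on resampling image features with a precomputed sampling grid. Assuming GridSampler is a thin wrapper around F.grid_sample (not confirmed by this excerpt), the core operation and the coverage-mask trick look roughly like this:

import torch
import torch.nn.functional as F

feat = torch.randn(1, 32, 24, 32)            # B x C x H x W first-person features
grid = torch.rand(1, 64, 64, 2) * 2 - 1      # B x Hout x Wout x 2, coords in [-1, 1]
proj = F.grid_sample(feat, grid, align_corners=False)   # features resampled onto the map

# Coverage mask: resample an all-ones tensor with the same grid, then rescale
# to [0, 1] exactly as done with new_coverages above.
ones = torch.ones(1, 1, 24, 32)
coverage = F.grid_sample(ones, grid, align_corners=False)
coverage = coverage - coverage.min()
coverage = coverage / (coverage.max() + 1e-18)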
Example #4
class MapAffine(nn.Module):
    # TODO: Cleanup unused run_params
    def __init__(self, source_map_size, dest_map_size, world_size_px,
                 world_size_m):
        super(MapAffine, self).__init__()
        self.source_map_size_px = source_map_size
        self.dest_map_size_px = dest_map_size
        self.world_in_map_size_px = world_size_px
        self.world_size_m = world_size_m

        self.affine_2d = Affine2D()

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)

        pos = np.asarray(
            [self.source_map_size_px / 2, self.source_map_size_px / 2])
        rot = np.asarray([0])
        self.canonical_pose_src = Pose(pos, rot)

        pos = np.asarray(
            [self.dest_map_size_px / 2, self.dest_map_size_px / 2])
        rot = np.asarray([0])
        self.canonical_pose_dst = Pose(pos, rot)

    def pose_2d_to_mat_np(self, pose_2d, map_size, inv=False):
        pos = pose_2d.position
        yaw = pose_2d.orientation

        # Transform the img so that the drone's position ends up at the origin
        # TODO: Add batch support
        t1 = get_affine_trans_2d(-pos)

        # Rotate the img so that it's aligned with the drone's orientation
        yaw = -yaw
        t2 = get_affine_rot_2d(-yaw)

        # Translate the img so that it's centered around the drone
        t3 = get_affine_trans_2d([map_size / 2, map_size / 2])

        mat = np.dot(t3, np.dot(t2, t1))

        # Swap x and y axes (because of the BxCxHxW a.k.a BxCxYxX convention)
        swapmat = mat[[1, 0, 2], :]
        mat = swapmat[:, [1, 0, 2]]

        if inv:
            mat = np.linalg.inv(mat)

        return mat

    def poses_2d_to_mat_np(self, pose_2d, map_size, inv=False):
        pos = np.asarray(pose_2d.position)
        yaw = np.asarray(pose_2d.orientation)

        # Transform the img so that the drone's position ends up at the origin
        # TODO: Add batch support
        t1 = get_affine_trans_2d(-pos, batch=True)

        # Rotate the img so that it's aligned with the drone's orientation
        yaw = -yaw
        t2 = get_affine_rot_2d(-yaw, batch=True)

        # Translate the img so that it's centered around the drone
        t3 = get_affine_trans_2d(np.asarray([map_size / 2, map_size / 2]),
                                 batch=False)

        t21 = np.matmul(t2, t1)
        mat = np.matmul(t3, t21)

        # Swap x and y axes (because of the BxCxHxW a.k.a BxCxYxX convention)
        swapmat = mat[:, [1, 0, 2], :]
        mat = swapmat[:, :, [1, 0, 2]]

        if inv:
            mat = np.linalg.inv(mat)

        return mat

    def get_old_to_new_pose_mat(self, old_pose, new_pose):
        old_T_inv = self.pose_2d_to_mat_np(old_pose,
                                           self.source_map_size_px,
                                           inv=True)
        new_T = self.pose_2d_to_mat_np(new_pose,
                                       self.dest_map_size_px,
                                       inv=False)
        mat = np.dot(new_T, old_T_inv)
        #mat = new_T
        mat_t = np_to_tensor(mat, cuda=False)
        return mat_t

    def get_old_to_new_pose_matrices(self, old_pose, new_pose):
        old_T_inv = self.poses_2d_to_mat_np(old_pose,
                                            self.source_map_size_px,
                                            inv=True)
        new_T = self.poses_2d_to_mat_np(new_pose,
                                        self.dest_map_size_px,
                                        inv=False)
        mat = np.matmul(new_T, old_T_inv)
        #mat = new_T
        mat_t = np_to_tensor(mat, insert_batch_dim=False, cuda=False)
        return mat_t

    def get_affine_matrices(self, map_poses, cam_poses, batch_size):
        # Convert the pose from airsim coordinates to image pixel coordinates
        # If the pose is None, use the canonical pose (global frame)
        if map_poses is not None:
            map_poses = map_poses.numpy()  # TODO: Check if we're gonna have a list here or something
            # TODO: This is the big bottleneck. Could we precompute it in the dataloader?
            map_poses_img = poses_m_to_px(
                map_poses,
                self.source_map_size_px,
                [self.world_in_map_size_px, self.world_in_map_size_px],
                self.world_size_m,
                batch_dim=True)
        else:
            map_poses_img = self.canonical_pose_src.repeat_np(batch_size)

        if cam_poses is not None:
            cam_poses = cam_poses.numpy()
            cam_poses_img = poses_m_to_px(
                cam_poses,
                self.dest_map_size_px,
                [self.world_in_map_size_px, self.world_in_map_size_px],
                self.world_size_m,
                batch_dim=True)
        else:
            cam_poses_img = self.canonical_pose_dst.repeat_np(batch_size)

        # Get the affine transformation matrix to transform the map to the new camera pose
        affines = self.get_old_to_new_pose_matrices(map_poses_img,
                                                    cam_poses_img)

        return affines

    def get_affine_i(self, map_poses, cam_poses, i):
        # Convert the pose from airsim coordinates to image pixel coordinates
        # If the pose is None, use the canonical pose (global frame)
        self.prof.tick("call")
        if map_poses is not None and map_poses[i] is not None:
            map_pose_i = map_poses[i].numpy()
            map_pose_img = poses_m_to_px(
                map_pose_i, self.source_map_size_px,
                [self.world_in_map_size_px, self.world_in_map_size_px],
                self.world_size_m)
        else:
            map_pose_img = self.canonical_pose_src

        if cam_poses is not None and cam_poses[i] is not None:
            cam_pose_i = cam_poses[i].numpy()
            cam_pose_img = poses_m_to_px(
                cam_pose_i, self.dest_map_size_px,
                [self.world_in_map_size_px, self.world_in_map_size_px],
                self.world_size_m)
        else:
            cam_pose_img = self.canonical_pose_dst

        self.prof.tick("convert_pose")

        # Get the affine transformation matrix to transform the map to the new camera pose
        affine_i = self.get_old_to_new_pose_mat(map_pose_img, cam_pose_img)

        self.prof.tick("calc_affine")
        return affine_i

    def forward(self, maps, current_poses, new_poses):
        """
        Affine transform the maps from being centered around current_poses in the canonical map frame to
        being centered around new_poses in the canonical map frame.
        The canonical map frame is the one where the map origin aligns with the environment origin, but the env may
        or may not take up the entire map.
        :param maps: batch of maps, each centered around the drone at the corresponding current pose
        :param current_poses: the previous drone poses in the canonical map frame
        :param new_poses: the new drone poses in the canonical map frame
        :return:
        """

        # TODO: Handle the case where cam_pose is None and return a map in the canonical frame
        self.prof.tick("out")
        batch_size = maps.size(0)

        self.prof.tick("init")

        affine_matrices_cpu = self.get_affine_matrices(current_poses,
                                                       new_poses, batch_size)

        self.prof.tick("affine_mat_and_pose")

        # Apply the affine transformation on the map
        # The affine matrices should be on CPU (if not, they'll be copied to CPU anyway!)
        # CUDA SYNC point:
        maps_out = self.affine_2d(
            maps,
            affine_matrices_cpu,
            out_size=[self.dest_map_size_px, self.dest_map_size_px])

        self.prof.tick("affine_2d")
        self.prof.loop()
        if batch_size > 1:
            self.prof.print_stats(20)

        return maps_out
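The pose-to-matrix construction composes three elementary transforms: move the drone to the origin, rotate into its heading, then re-center on the map. A standalone sketch with explicit 3x3 homogeneous matrices (the repo's helper functions and its x/y axis swap are omitted, and the yaw sign convention here is generic):

import numpy as np

def trans(t):
    m = np.eye(3)
    m[0:2, 2] = t
    return m

def rot(yaw):
    c, s = np.cos(yaw), np.sin(yaw)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

pos = np.array([10.0, 4.0])          # drone position in map pixels
yaw = np.pi / 2                      # drone heading
map_size = 32

mat = trans([map_size / 2, map_size / 2]) @ rot(yaw) @ trans(-pos)

# The drone's own position lands exactly at the map center:
print(mat @ np.array([10.0, 4.0, 1.0]))   # -> [16. 16.  1.]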
Example #5
class DrawStartPosOnGlobalMap(MapTransformerBase):

    def __init__(self, source_map_size, world_size_px, world_size_m, lamda=0.2):
        super(DrawStartPosOnGlobalMap, self).__init__(source_map_size, world_size_px, world_size_m)
        self.map_size = source_map_size
        self.world_size_px = world_size_px
        self.world_size_m = world_size_m
        self.child_transformer = MapTransformerBase(source_map_size, world_size_px, world_size_m)

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.start_pose = None
        self.last_emb = None

        self.dbg_t = None
        self.seq = 0

    def init_weights(self):
        pass

    def reset(self):
        super(DrawStartPosOnGlobalMap, self).reset()
        self.start_pose = None
        self.last_emb = None
        self.child_transformer.reset()
        self.seq = 0

    def cuda(self, device=None):
        MapTransformerBase.cuda(self, device)
        self.child_transformer.cuda(device)
        return self

    def get_start_poses(self, cam_poses_w, sentence_embeddings):
        # For each timestep, get the pose corresponding to the start of the instruction segment
        seq_len = len(sentence_embeddings)
        start_poses = []
        for i in range(seq_len):
            if self.last_emb is not None and (sentence_embeddings[i].data == self.last_emb).all():
                pass # Keep the same start pose since we're on the same segment
            else:
                self.last_emb = sentence_embeddings[i].data
                self.start_pose = cam_poses_w[i]
            start_poses.append(self.start_pose)
        return start_poses

    def forward(self, maps_w, sentence_embeddings, map_poses_w, cam_poses_w, show=False):
        #show="li
        self.prof.tick(".")
        batch_size = len(maps_w)

        # Initialize the layers of the same size as the maps, but with only one channel
        new_layer_size = list(maps_w.size())
        new_layer_size[1] = 1
        all_maps_out_w = empty_float_tensor(new_layer_size, self.is_cuda, self.cuda_device)

        start_poses = self.get_start_poses(cam_poses_w, sentence_embeddings)

        poses_img = [poses_m_to_px(as_pose, self.map_size, self.world_size_px, self.world_size_m) for as_pose in start_poses]
        #poses_img = poses_as_to_img(start_poses, self.world_size, batch_dim=True)

        for i in range(batch_size):
            x = min(max(int(poses_img[i].position.data[0]), 0), new_layer_size[2] - 1)
            y = min(max(int(poses_img[i].position.data[1]), 0), new_layer_size[2] - 1)
            all_maps_out_w[i, 0, x, y] = 10.0

        if show != "":
            Presenter().show_image(all_maps_out_w[0], show, torch=True, waitkey=1)

        self.prof.tick("draw")

        # Concatenate the start-position layer with the input maps along the channel dimension
        maps_out = torch.cat([Variable(all_maps_out_w), maps_w], dim=1)
        #all_maps_w = torch.cat(all_maps_out_w, dim=0)

        self.prof.loop()
        self.prof.print_stats(10)

        return maps_out, map_poses_w
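The marker-drawing step reduces to clamping a (possibly out-of-bounds) pixel position and setting one bright pixel on a single-channel layer. A tiny sketch with made-up coordinates (in the example above they come from poses_m_to_px):

import torch

map_size = 32
layer = torch.zeros(1, 1, map_size, map_size)
px, py = 37.6, 11.2                            # hypothetical pixel-space position
x = min(max(int(px), 0), map_size - 1)
y = min(max(int(py), 0), map_size - 1)
layer[0, 0, x, y] = 10.0                       # same marker value as used above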
Example #6
class MapAffine(nn.Module):

    # TODO: Cleanup unused run_params
    def __init__(self, map_size, world_size_px, world_size_m):
        super(MapAffine, self).__init__()
        self.map_size = map_size
        self.world_size_px = world_size_px
        self.world_size_m = world_size_m

        self.affine_2d = Affine2D()

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)

    def pose_2d_to_mat_np(self, pose_2d, inv=False):
        pos = pose_2d.position
        yaw = pose_2d.orientation

        # Transform the img so that the drone's position ends up at the origin
        # TODO: Add batch support
        t1 = get_affine_trans_2d(-pos)

        # Rotate the img so that it's aligned with the drone's orientation
        yaw = -yaw
        t2 = get_affine_rot_2d(-yaw)

        # Translate the img so that it's centered around the drone
        t3 = get_affine_trans_2d([self.map_size / 2, self.map_size / 2])

        mat = np.dot(t3, np.dot(t2, t1))

        # Swap x and y axes (because of the BxCxHxW a.k.a BxCxYxX convention)
        swapmat = mat[[1, 0, 2], :]
        mat = swapmat[:, [1, 0, 2]]

        if inv:
            mat = np.linalg.inv(mat)

        return mat

    def get_old_to_new_pose_mat(self, old_pose, new_pose):
        old_T_inv = self.pose_2d_to_mat_np(old_pose, inv=True)
        new_T = self.pose_2d_to_mat_np(new_pose, inv=False)
        mat = np.dot(new_T, old_T_inv)
        #mat = new_T
        mat_t = np_to_tensor(mat)
        return mat_t

    def get_canonical_frame_pose(self):
        pos = np.asarray([self.map_size / 2, self.map_size / 2])
        rot = np.asarray([0])

        return Pose(pos, rot)

    def forward(self, maps, map_pose, cam_pose):
        """
        Affine transform the map from being centered around map_pose in the canonical map frame to
        being centered around cam_pose in the canonical map frame.
        Canonical map frame is the one where the map origin aligns with the environment origin, but the env may
        or may not take up the entire map.
        :param maps: batch of maps, each centered around the drone in the corresponding map_pose
        :param map_pose: the previous drone pose in canonical map frame
        :param cam_pose: the new drone pose in canonical map frame
        :return:
        """

        # TODO: Handle the case where cam_pose is None and return a map in the canonical frame
        self.prof.tick("out")
        batch_size = maps.size(0)
        affine_matrices = torch.zeros([batch_size, 3, 3]).to(maps.device)

        self.prof.tick("init")
        for i in range(batch_size):

            # Convert the pose from airsim coordinates to image pixel coordinates
            # If the pose is None, use the canonical pose (global frame)
            if map_pose is not None and map_pose[i] is not None:
                map_pose_i = map_pose[i].numpy()
                map_pose_img = poses_m_to_px(
                    map_pose_i, self.map_size,
                    [self.world_size_px, self.world_size_px],
                    self.world_size_m)
            else:
                map_pose_img = self.get_canonical_frame_pose()

            if cam_pose is not None and cam_pose[i] is not None:
                cam_pose_i = cam_pose[i].numpy()
                cam_pose_img = poses_m_to_px(
                    cam_pose_i, self.map_size,
                    [self.world_size_px, self.world_size_px],
                    self.world_size_m)
            else:
                cam_pose_img = self.get_canonical_frame_pose()

            self.prof.tick("pose")

            # Get the affine transformation matrix to transform the map to the new camera pose
            affine_i = self.get_old_to_new_pose_mat(map_pose_img, cam_pose_img)
            affine_matrices[i] = affine_i
            self.prof.tick("affine")

        # TODO: Do the same with OpenCV and compare results for testing

        # Apply the affine transformation on the map
        maps_out = self.affine_2d(maps, affine_matrices)

        self.prof.tick("affine_sample")
        self.prof.loop()
        self.prof.print_stats(20)

        return maps_out
Example #7
class ModelTrajectoryToAction(ModuleWithAuxiliaries):
    def __init__(self, run_name=""):

        super(ModelTrajectoryToAction, self).__init__()
        self.model_name = "lsvd_action"
        self.run_name = run_name
        self.writer = LoggingSummaryWriter(log_dir="runs/" + run_name)

        self.params = get_current_parameters()["ModelPVN"]
        self.aux_weights = get_current_parameters()["AuxWeights"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        # Common
        # --------------------------------------------------------------------------------------------------------------
        self.map_transform_w_to_s = MapTransformerBase(
            source_map_size=self.params["global_map_size"],
            dest_map_size=self.params["local_map_size"],
            world_size=self.params["world_size_px"])

        self.map_transform_r_to_w = MapTransformerBase(
            source_map_size=self.params["local_map_size"],
            dest_map_size=self.params["global_map_size"],
            world_size=self.params["world_size_px"])

        # Output an action given the global semantic map
        if self.params["map_to_action"] == "downsample2":
            self.map_to_action = EgoMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                other_features_size=self.params["emb_size"])

        elif self.params["map_to_action"] == "cropped":
            self.map_to_action = CroppedMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                manual=self.params["manual_rule"],
                path_only=self.params["action_in_path_only"],
                recurrence=self.params["action_recurrence"])

        self.spatialsoftmax = SpatialSoftmax2d()
        self.gt_fill_missing = MapBatchFillMissing(
            self.params["local_map_size"], self.params["world_size_px"])

        # Don't freeze the trajectory to action weights, because it will be pre-trained during path-prediction training
        # and finetuned on all timesteps end-to-end
        enable_weight_saving(self.map_to_action,
                             "map_to_action",
                             alwaysfreeze=False,
                             neverfreeze=True)

        self.action_loss = ActionLoss()

        self.env_id = None
        self.seg_idx = None
        self.prev_instruction = None
        self.seq_step = 0
        self.get_act_start_pose = None
        self.gt_labels = None

    # TODO: Try to hide these in a superclass or something. They take up a lot of space:
    def cuda(self, device=None):
        ModuleWithAuxiliaries.cuda(self, device)
        self.map_to_action.cuda(device)
        self.action_loss.cuda(device)
        self.map_transform_w_to_s.cuda(device)
        self.map_transform_r_to_w.cuda(device)
        self.gt_fill_missing.cuda(device)
        return self

    def get_iter(self):
        return int(self.iter.data[0])

    def inc_iter(self):
        self.iter += 1

    def init_weights(self):
        self.map_to_action.init_weights()

    def reset(self):
        # TODO: This is error prone. Create a class StatefulModule, iterate submodules and reset all stateful modules
        super(ModelTrajectoryToAction, self).reset()
        self.map_transform_w_to_s.reset()
        self.map_transform_r_to_w.reset()
        self.gt_fill_missing.reset()

    def setEnvContext(self, context):
        print("Set env context to: " + str(context))
        self.env_id = context["env_id"]

    def start_segment_rollout(self):
        import rollout.run_metadata as md
        m_size = self.params["local_map_size"]
        w_size = self.params["world_size_px"]
        self.gt_labels = get_top_down_ground_truth_static_global(
            md.ENV_ID, md.START_IDX, md.END_IDX, m_size, m_size, w_size,
            w_size)
        self.seg_idx = md.SEG_IDX
        self.gt_labels = self.maybe_cuda(self.gt_labels)
        if self.params["clear_history"]:
            self.start_sequence()

    def get_action(self, state, instruction):
        """
        Given a DroneState (from PomdpInterface) and instruction, produce a numpy 4D action (x, y, theta, pstop)
        :param state: DroneState object with the raw image from the simulator
        :param instruction: Tokenized instruction given the corpus
        #TODO: Absorb corpus within model
        :return:
        """
        prof = SimpleProfiler(print=True)
        prof.tick(".")
        # TODO: Simplify this
        self.eval()
        images_np_pure = state.image
        state_np = state.state
        state = Variable(none_padded_seq_to_tensor([state_np]))

        #print("Act: " + debug_untokenize_instruction(instruction))

        # Add the batch dimension

        first_step = True
        if instruction == self.prev_instruction:
            first_step = False
        self.prev_instruction = instruction
        if first_step:
            self.get_act_start_pose = self.cam_poses_from_states(state[0:1])

        self.seq_step += 1

        # This is for training the policy to mimic the ground-truth state distribution with oracle actions
        # b_traj_gt_w_select = b_traj_ground_truth[b_plan_mask_t[:, np.newaxis, np.newaxis, np.newaxis].expand_as(b_traj_ground_truth)].view([-1] + gtsize)
        traj_gt_w = Variable(self.gt_labels)
        b_poses = self.cam_poses_from_states(state)
        # TODO: These source and dest should go as arguments to get_maps (in forward pass not params)
        transformer = MapTransformerBase(
            source_map_size=self.params["global_map_size"],
            world_size=self.params["world_size_px"],
            dest_map_size=self.params["local_map_size"])
        self.maybe_cuda(transformer)
        transformer.set_maps(traj_gt_w, None)
        traj_gt_r, _ = transformer.get_maps(b_poses)
        self.clear_inputs("traj_gt_r_select")
        self.clear_inputs("traj_gt_w_select")
        self.keep_inputs("traj_gt_r_select", traj_gt_r)
        self.keep_inputs("traj_gt_w_select", traj_gt_w)

        action = self(traj_gt_r, firstseg=[self.seq_step == 1])

        output_action = action.squeeze().data.cpu().numpy()

        stop_prob = output_action[3]
        output_stop = 1 if stop_prob > self.params["stop_threshold"] else 0
        output_action[3] = output_stop

        return output_action

    def deterministic_action(self, action_mean, action_std, stop_prob):
        batch_size = action_mean.size(0)
        action = Variable(
            empty_float_tensor((batch_size, 4), self.is_cuda,
                               self.cuda_device))
        action[:, 0:3] = action_mean[:, 0:3]
        action[:, 3] = stop_prob
        return action

    def sample_action(self, action_mean, action_std, stop_prob):
        action = torch.normal(action_mean, action_std)
        stop = torch.bernoulli(stop_prob)
        return action, stop

    # This is called before beginning an execution sequence
    def start_sequence(self):
        self.seq_step = 0
        self.reset()
        print("RESETTED!")
        return

    def cam_poses_from_states(self, states):
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]
        pose = Pose(cam_pos, cam_rot)
        return pose

    def save(self, epoch):
        filename = self.params[
            "map_to_action_file"] + "_" + self.run_name + "_" + str(epoch)
        save_pytorch_model(self.map_to_action, filename)
        print("Saved action model to " + filename)

    def forward(self, traj_gt_r, firstseg=None):
        """
        :param traj_gt_r: batch of ground-truth trajectory distributions transformed into the egocentric (robot) frame
        :param firstseg: list of booleans of length B indicating whether the given timestep starts a new instruction segment
        :return: batch of predicted actions
        """
        action_pred = self.map_to_action(traj_gt_r,
                                         None,
                                         fistseg_mask=firstseg)
        out_action = self.deterministic_action(action_pred[:, 0:3], None,
                                               action_pred[:, 3])
        self.keep_inputs("action", out_action)
        self.prof.tick("map_to_action")

        return out_action

    def maybe_cuda(self, tensor):
        if self.is_cuda:
            if False:
                if type(tensor) is Variable:
                    tensor.data.pin_memory()
                elif type(tensor) is Pose:
                    pass
                elif type(tensor) is torch.FloatTensor:
                    tensor.pin_memory()
            return tensor.cuda()
        else:
            return tensor

    def cuda_var(self, tensor):
        return cuda_var(tensor, self.is_cuda, self.cuda_device)

    # Forward pass for training (with batch optimizations)
    def sup_loss_on_batch(self, batch, eval):
        self.prof.tick("out")

        action_loss_total = Variable(
            empty_float_tensor([1], self.is_cuda, self.cuda_device))

        if batch is None:
            print("Skipping None Batch")
            return action_loss_total

        actions = self.maybe_cuda(batch["actions"])
        states = self.maybe_cuda(batch["states"])

        firstseg_mask = batch["firstseg_mask"]

        # Auxiliary labels
        traj_ground_truth_select = self.maybe_cuda(batch["traj_ground_truth"])
        # stops = self.maybe_cuda(batch["stops"])
        metadata = batch["md"]
        batch_size = actions.size(0)
        count = 0

        # Loop thru batch
        for b in range(batch_size):
            seg_idx = -1

            self.reset()

            self.prof.tick("out")
            b_seq_len = len_until_nones(metadata[b])

            # TODO: Generalize this
            # Slice the data according to the sequence length
            b_metadata = metadata[b][:b_seq_len]
            b_actions = actions[b][:b_seq_len]
            b_traj_ground_truth_select = traj_ground_truth_select[b]
            b_states = states[b][:b_seq_len]

            self.keep_inputs("traj_gt_global_select",
                             b_traj_ground_truth_select)

            #b_firstseg = get_obs_mask_segstart(b_metadata)
            b_firstseg = firstseg_mask[b][:b_seq_len]

            # ----------------------------------------------------------------------------
            # Optional Auxiliary Inputs
            # ----------------------------------------------------------------------------
            gtsize = list(b_traj_ground_truth_select.size())[1:]
            b_poses = self.cam_poses_from_states(b_states)
            # TODO: These source and dest should go as arguments to get_maps (in forward pass not params)
            transformer = MapTransformerBase(
                source_map_size=self.params["global_map_size"],
                world_size=self.params["world_size_px"],
                dest_map_size=self.params["local_map_size"])
            self.maybe_cuda(transformer)
            transformer.set_maps(b_traj_ground_truth_select, None)
            traj_gt_local_select, _ = transformer.get_maps(b_poses)
            self.keep_inputs("traj_gt_r_select", traj_gt_local_select)
            self.keep_inputs("traj_gt_w_select", b_traj_ground_truth_select)

            # ----------------------------------------------------------------------------

            self.prof.tick("inputs")

            actions = self(traj_gt_local_select, firstseg=b_firstseg)
            action_losses, _ = self.action_loss(b_actions,
                                                actions,
                                                batchreduce=False)
            action_losses = self.action_loss.batch_reduce_loss(action_losses)
            action_loss = self.action_loss.reduce_loss(action_losses)
            action_loss_total = action_loss
            count += b_seq_len

            self.prof.tick("loss")

        action_loss_avg = action_loss_total / (count + 1e-9)
        prefix = self.model_name + ("/eval" if eval else "/train")
        self.writer.add_scalar(prefix + "/action_loss",
                               action_loss_avg.data.cpu()[0], self.get_iter())

        self.prof.tick("out")

        prefix = self.model_name + ("/eval" if eval else "/train")
        self.writer.add_dict(prefix, get_current_meters(), self.get_iter())

        self.inc_iter()

        self.prof.tick("summaries")
        self.prof.loop()
        self.prof.print_stats(1)

        return action_loss_avg

    def get_dataset(self, data=None, envs=None, dataset_name=None, eval=False):
        # TODO: Maybe use eval here
        data_sources = []
        data_sources.append(aup.PROVIDER_TRAJECTORY_GROUND_TRUTH_STATIC)
        return SegmentDataset(data=data,
                              env_list=envs,
                              dataset_name=dataset_name,
                              aux_provider_names=data_sources,
                              segment_level=True)
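The post-processing at the end of get_action() binarizes the predicted stop probability against a threshold before the 4-D action is handed to the environment. A small sketch with placeholder numbers (the real threshold comes from params["stop_threshold"]):

import torch

action_pred = torch.tensor([[0.7, 0.0, 0.1, 0.83]])   # (x, y, theta, pstop), batch of 1
stop_threshold = 0.5                                   # placeholder value
output_action = action_pred.squeeze().clone()
output_action[3] = 1.0 if output_action[3].item() > stop_threshold else 0.0
print(output_action)                                   # tensor([0.7000, 0.0000, 0.1000, 1.0000])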
Example #8
class MapBatchFillMissing(MapTransformerBase):

    def __init__(self, source_map_size, world_in_map_size):
        super(MapBatchFillMissing, self).__init__(source_map_size, world_in_map_size)
        self.map_size = source_map_size
        self.world_size = world_in_map_size
        self.child_transformer = MapTransformerBase(source_map_size, world_in_map_size)

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.map_memory = MapTransformerBase(source_map_size, world_in_map_size)

        self.last_observation = None

        self.dbg_t = None
        self.seq = 0

    def init_weights(self):
        pass

    def reset(self):
        super(MapBatchFillMissing, self).reset()
        self.map_memory.reset()
        self.child_transformer.reset()
        self.seq = 0
        self.last_observation = None

    def cuda(self, device=None):
        MapTransformerBase.cuda(self, device)
        self.child_transformer.cuda(device)
        self.map_memory.cuda(device)
        return self

    def dbg_write_extra(self, map, pose):
        if DebugWriter().should_write():
            map = map[0:1, 0:3]
            self.seq += 1
            # Initialize a transformer module
            if pose is not None:
                if self.dbg_t is None:
                    self.dbg_t = MapTransformerBase(self.map_size, self.world_size)
                    if self.is_cuda:
                        self.dbg_t.cuda(self.cuda_device)

                # Transform the prediction to the global frame and write out to disk.
                self.dbg_t.set_map(map, pose)
                map_global, _ = self.dbg_t.get_map(None)
            else:
                map_global = map
            DebugWriter().write_img(map_global[0], "gif_overlaid", args={"world_size": self.world_size, "name": "identity_integrator"})

    def forward(self, select_images, all_cam_poses, plan_mask=None, show=False):
        #show="li"
        self.prof.tick(".")

        # During rollout, plan_mask will alternate between [True] and [False]
        if plan_mask is None:
            all_images = select_images
            return all_images, all_cam_poses

        full_batch_size = len(all_cam_poses)

        all_maps_out_r = []

        self.prof.tick("maps_to_global")

        # For each timestep, take the latest map that was available, transformed into this timestep
        # Do only a maximum of one transformation for any map to avoid cascading of errors!
        ptr = 0
        for i in range(full_batch_size):
            this_pose = all_cam_poses[i:i+1]
            if plan_mask[i]:
                this_obs = (select_images[ptr:ptr+1], this_pose)
                ptr += 1
                self.last_observation = this_obs
            else:
                assert self.last_observation is not None, "The first observation in a sequence needs to be used!"
                last_map, last_pose = self.last_observation

                # TODO: See if we can speed this up. Perhaps batch for all timesteps in between observations
                self.child_transformer.set_map(last_map, last_pose)
                this_obs = self.child_transformer.get_map(this_pose)

            all_maps_out_r.append(this_obs[0])

            if show != "":
                Presenter().show_image(this_obs[0].data[0, 0:3], show, torch=True, scale=8, waitkey=50)

        self.prof.tick("integrate")

        # Step 3: Convert all maps to local frame
        all_maps_r = torch.cat(all_maps_out_r, dim=0)

        # Write gifs for debugging
        #self.dbg_write_extra(all_maps_r, None)

        self.set_maps(all_maps_r, all_cam_poses)

        self.prof.tick("maps_to_local")
        self.prof.loop()
        self.prof.print_stats(10)

        return all_maps_r, all_cam_poses
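The fill-missing logic above is essentially a forward-fill over the plan mask: keep a pointer into the planned outputs and reuse the latest one (re-transformed into the current pose, which the toy version below omits) for every timestep where the mask is False.

# Toy illustration with plain lists instead of maps and poses:
plan_mask = [True, False, False, True, False]
select_outputs = ["plan_t0", "plan_t3"]        # one entry per True in the mask
filled, ptr, last = [], 0, None
for flag in plan_mask:
    if flag:
        last = select_outputs[ptr]
        ptr += 1
    filled.append(last)
print(filled)   # ['plan_t0', 'plan_t0', 'plan_t0', 'plan_t3', 'plan_t3']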
Example #9
    def train_epoch(self,
                    env_list=None,
                    data_list_real=None,
                    data_list_sim=None,
                    eval=False,
                    restricted_domain=False):

        if eval:
            self.model_real.eval()
            self.model_sim.eval()
            self.model_critic.eval()
            inference_type = "eval"
            epoch_num = self.train_epoch_num
            self.test_epoch_num += 1
        else:
            self.model_real.train()
            self.model_sim.train()
            self.model_critic.train()
            inference_type = "train"
            epoch_num = self.train_epoch_num
            self.train_epoch_num += 1

        # Allow testing with both domains being simulation domain
        if self.params["sim_domain_only"]:
            dataset_real = self.model_sim.get_dataset(
                data=data_list_sim,
                envs=env_list,
                dataset_names=self.sim_datasets,
                dataset_prefix="supervised",
                eval=eval)
            self.model_real = self.model_sim
        else:
            dataset_real = self.model_real.get_dataset(
                data=data_list_real,
                envs=env_list,
                dataset_names=self.real_datasets,
                dataset_prefix="supervised",
                eval=eval)

        dataset_sim = self.model_sim.get_dataset(
            data=data_list_sim,
            envs=env_list,
            dataset_names=self.sim_datasets,
            dataset_prefix="supervised",
            eval=eval)

        dual_model_loader = DualDataloader(dataset_a=dataset_real,
                                           dataset_b=dataset_sim,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=self.num_loaders,
                                           pin_memory=False,
                                           timeout=0,
                                           drop_last=False,
                                           joint_length="max")

        dual_critic_loader = DualDataloader(dataset_a=dataset_real,
                                            dataset_b=dataset_sim,
                                            batch_size=self.batch_size,
                                            shuffle=True,
                                            num_workers=self.num_loaders,
                                            pin_memory=False,
                                            timeout=0,
                                            drop_last=False,
                                            joint_length="infinite")
        dual_critic_iterator = iter(dual_critic_loader)

        #wloss_before_updates_writer = LoggingSummaryWriter(log_dir=f"runs/{self.run_name}/discriminator_before_updates")
        #wloss_after_updates_writer = LoggingSummaryWriter(log_dir=f"runs/{self.run_name}/discriminator_after_updates")

        samples_real = len(dataset_real)
        samples_sim = len(dataset_sim)
        if samples_real == 0 or samples_sim == 0:
            print(
                f"DATASET HAS NO DATA: REAL: {samples_real > 0}, SIM: {samples_sim > 0}"
            )
            return -1.0

        num_batches = len(dual_model_loader)

        epoch_loss = 0
        count = 0
        critic_elapsed_iterations = 0

        prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)

        prof.tick("out")

        # Alternate training critic and model
        for real_batch, sim_batch in dual_model_loader:
            if restricted_domain == "real":
                sim_batch = real_batch
            elif restricted_domain == "simulator":
                real_batch = sim_batch
            if real_batch is None or sim_batch is None:
                continue

            # We run more updates on the sim data than on the real data to speed up training and
            # avoid overfitting on the scarce real data
            if self.sim_steps_per_real_step == 1 or self.sim_steps_per_real_step == 0 or count % self.sim_steps_per_real_step == 0:
                train_sim_only = False
            else:
                train_sim_only = True

            if sim_batch is None or (not train_sim_only
                                     and real_batch is None):
                continue

            prof.tick("load_model_data")
            # Train the model for model_steps in a row, then train the critic, and repeat
            critic_batch_num = 0

            if count % self.model_steps == 0 and not eval and not self.disable_wloss:
                #print("\nTraining critic\n")
                # Train the critic for self.critic_steps steps
                if critic_elapsed_iterations > self.critic_warmup_iterations:
                    critic_steps = self.critic_steps
                    if self.critic_steps_cycle:
                        critic_steps_delta = int(
                            self.critic_steps_amplitude *
                            math.sin(count * 3.14159 /
                                     self.critic_steps_period) + 0.5)
                        critic_steps += critic_steps_delta
                else:
                    critic_steps = self.critic_warmup_steps

                assert critic_steps > 0, "Need at least one critic step!"
                for cstep in range(critic_steps):

                    # Each batch is actually a single rollout (we batch the rollout data across the sequence)
                    # To collect a batch of rollouts, we need to keep iterating
                    real_store = KeyTensorStore()
                    sim_store = KeyTensorStore()
                    for b in range(self.critic_batch_size):
                        # Get the next non-None batch
                        real_c_batch, sim_c_batch = None, None
                        while real_c_batch is None or sim_c_batch is None:
                            real_c_batch, sim_c_batch = next(
                                dual_critic_iterator)
                        prof.tick("critic_load_data")
                        # When training the critic, we don't backprop into the model, so we don't need gradients here
                        with torch.no_grad():
                            real_loss, real_store_b = self.model_real.sup_loss_on_batch(
                                real_c_batch, eval=eval, halfway=True)
                            sim_loss, sim_store_b = self.model_sim.sup_loss_on_batch(
                                sim_c_batch, eval=eval, halfway=True)
                        prof.tick("critic_features")
                        real_store.append(real_store_b)
                        sim_store.append(sim_store_b)
                        prof.tick("critic_store_append")

                    # Forward the critic
                    # The real_store and sim_store should really be a batch of multiple rollouts
                    wdist_loss_a, critic_store = self.model_critic.calc_domain_loss(
                        real_store, sim_store)

                    prof.tick("critic_domain_loss")

                    # Store the first and last critic loss
                    #if cstep == 0:
                    #    wdist_loss_before_updates = wdist_loss_a.detach().cpu()
                    #if cstep == critic_steps - 1:
                    #    wdist_loss_after_updates = wdist_loss_a.detach().cpu()

                    # Update the critic
                    critic_batch_num += 1
                    self.optim_critic.zero_grad()
                    # Wasserstein distance is maximum distance transport cost under Lipschitz constraint, so we maximize it
                    (-wdist_loss_a).backward()
                    self.optim_critic.step()
                    sys.stdout.write(
                        f"\r    Critic batch: {critic_batch_num}/{critic_steps} d_loss: {wdist_loss_a.data.item()}"
                    )
                    sys.stdout.flush()
                    prof.tick("critic_backward")

                # Write the Wasserstein loss before and after the Wasserstein loss updates
                #prefix = "pvn_critic" + ("/eval" if eval else "/train")
                #wloss_before_updates_writer.add_scalar(f"{prefix}/w_score_before_updates", wdist_loss_before_updates.item(), self.model_critic.get_iter())
                #wloss_after_updates_writer.add_scalar(f"{prefix}/w_score_before_updates", wdist_loss_after_updates.item(), self.model_critic.get_iter())

                critic_elapsed_iterations += 1
                print("Continuing model training\n")

                # Clean up GPU memory
                del wdist_loss_a
                del critic_store
                del real_store
                del sim_store
                del real_store_b
                del sim_store_b
                prof.tick("del")

            # Forward the model
            real_store = KeyTensorStore()
            sim_store = KeyTensorStore()
            real_loss = None
            sim_loss = None
            # TODO: Get rid of this loop! It doesn't even loop over and sample new batches
            for b in range(self.model_batch_size):
                real_loss_b, real_store_b = self.model_real.sup_loss_on_batch(
                    real_batch,
                    eval,
                    halfway=train_sim_only,
                    grad_noise=self.real_grad_noise)
                real_loss = (real_loss +
                             real_loss_b) if real_loss else real_loss_b
                real_store.append(real_store_b)

                sim_loss_b, sim_store_b = self.model_sim.sup_loss_on_batch(
                    sim_batch, eval, halfway=False)
                sim_loss = (sim_loss + sim_loss_b) if sim_loss else sim_loss_b
                sim_store.append(sim_store_b)
                prof.tick("model_forward")

            sim_loss = sim_loss / self.model_batch_size
            if train_sim_only:
                total_loss = sim_loss
            else:
                real_loss = real_loss / self.model_batch_size
                total_loss = real_loss + sim_loss

            if not self.disable_wloss:
                # Forward the critic
                wdist_loss_b, critic_store = self.model_critic.calc_domain_loss(
                    real_store, sim_store)

                prof.tick("model_domain_loss")
                # Minimize average real/sim losses, maximize domain loss
                total_loss = total_loss + wdist_loss_b

            # Grad step
            if not eval and total_loss.requires_grad:
                self.optim_models.zero_grad()
                try:
                    total_loss.backward()
                except Exception as e:
                    print("Error backpropping: ")
                    print(e)
                self.optim_models.step()
                prof.tick("model_backward")

            sys.stdout.write(
                f"\r Batch: {count}/{num_batches} r_loss: {real_loss.data.item() if real_loss else None} s_loss: {sim_loss.data.item()}"
            )
            sys.stdout.flush()

            # Get losses as floats
            epoch_loss += total_loss.data.item()
            count += 1

            self.train_segment += 0 if eval else 1
            self.test_segment += 1 if eval else 0

            prof.loop()
            prof.print_stats(self.model_steps)

        print("")
        epoch_loss /= (count + 1e-15)

        return epoch_loss
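
# A minimal, self-contained sketch of the adversarial pattern used in the
# trainer above: the critic is updated to maximize an estimate of the
# Wasserstein distance between "real" and "sim" features, while the model is
# updated to minimize its task loss plus that same estimate. The toy networks,
# data and wdist_estimate function below are illustrative assumptions, not the
# project's models or calc_domain_loss API.
import torch
import torch.nn as nn

torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))
model = nn.Linear(8, 8)
optim_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)
optim_model = torch.optim.Adam(model.parameters(), lr=1e-3)

def wdist_estimate(real_feat, sim_feat):
    # Mean critic score difference; a Lipschitz constraint (gradient penalty or
    # weight clipping) would be needed for a proper W1 estimate.
    return critic(real_feat).mean() - critic(sim_feat).mean()

real_x, sim_x = torch.randn(16, 8), torch.randn(16, 8) + 1.0

# Critic step: maximize the distance estimate by minimizing its negation.
wdist = wdist_estimate(model(real_x).detach(), model(sim_x).detach())
optim_critic.zero_grad()
(-wdist).backward()
optim_critic.step()

# Model step: task loss plus the domain loss, minimized so that real and sim
# features become indistinguishable to the critic.
total_loss = model(sim_x).pow(2).mean() + wdist_estimate(model(real_x), model(sim_x))
optim_model.zero_grad()
total_loss.backward()
optim_model.step()
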
Exemplo n.º 10
0
class SegmentDataset(Dataset):
    def __init__(self,
                 data=None,
                 env_list=None,
                 dataset_names=["simulator"],
                 dataset_prefix="supervised",
                 domain="sim",
                 max_traj_length=None,
                 aux_provider_names=[],
                 segment_level=False,
                 cache=False):
        """
        Dataset for the replay memory
        :param data: if data is pre-loaded in memory, this is the training data
        :param env_list: if data is to be loaded by the dataset, this is the list of environments for which to include data
        :param dataset_names: list of datasets from which to load data
        :param dataset_prefix: name of the dataset. The default, "supervised", uses data collected with collect_supervised_data
        :param domain: "sim" or "real"; stored in each sample's metadata
        :param max_traj_length: truncate trajectories to this length
        :param aux_provider_names: names of auxiliary data providers used to add extra labels to each sample
        :param segment_level: if True, index the data by (environment, segment) instead of by environment
        :param cache: unused here; caching is controlled by the "cache" entry in the Data parameters
        """

        # If data is already loaded in memory, use it
        self.data = data
        self.prof = SimpleProfiler(torch_sync=False, print=PROFILE)
        self.min_seg_len = P.get_current_parameters()["Data"].get("min_seg_len", 3)
        self.do_cache = P.get_current_parameters()["Data"].get("cache", False)
        self.dataset_prefix = dataset_prefix
        self.dataset_names = dataset_names
        self.domain = domain

        self.env_restrictions = P.get_current_parameters()["Data"].get("dataset_env_restrictions")
        if self.env_restrictions:
            self.dataset_restricted_envs = {
                dname: P.get_current_parameters()["Data"]["EnvRestrictionGroups"][self.env_restrictions[dname]]
                for dname in dataset_names if dname in self.env_restrictions}
            print(f"Using restricted envs: {list(self.dataset_restricted_envs.keys())}")
        else:
            self.dataset_restricted_envs = {}

        self.max_traj_length = max_traj_length
        train_instr, dev_instr, test_instr, corpus = get_all_instructions()
        # TODO: This shouldn't have access to all instructions. We should really make distinct train, dev, test modes
        self.all_instr = {**train_instr, **dev_instr, **test_instr}

        train_instr_full, dev_instr_full, test_instr_full, corpus = get_all_instructions(full=True)
        self.all_instr_full = {**train_instr_full, **dev_instr_full, **test_instr_full}

        self.segment_level = segment_level
        self.sample_ids = []

        if self.data is None:
            assert env_list is not None
            for i, dataset_name in enumerate(self.dataset_names):
                dataset_env_list = filter_env_list_has_data(dataset_name, env_list, dataset_prefix)
                if self.segment_level:
                    dataset_env_list, dataset_seg_list = self.split_into_segments(dataset_env_list, dataset_name)
                else:
                    dataset_seg_list = [0 for _ in dataset_env_list]
                for env, seg in zip(dataset_env_list, dataset_seg_list):
                    self.sample_ids.append((dataset_name, env, seg))

        self.token2word, self.word2token = get_word_to_token_map(corpus)
        self.aux_provider_names = aux_provider_names
        self.aux_label_names = get_aux_label_names(aux_provider_names)
        self.stackable_names = get_stackable_label_names(aux_provider_names)
        self.data_cache = {dataset_name:{} for dataset_name in dataset_names}

        self.traj_len = P.get_current_parameters()["Setup"]["trajectory_length"]

    def load_env_data(self, dataset_name, env_id):
        if self.do_cache:
            if env_id not in self.data_cache[dataset_name]:
                self.data_cache[dataset_name][env_id] = load_single_env_from_dataset(dataset_name, env_id, self.dataset_prefix)
            return self.data_cache[dataset_name][env_id]
        else:
            return load_single_env_from_dataset(dataset_name, env_id, self.dataset_prefix)

    def __len__(self):
        if self.data is not None:
            return len(self.data)
        else:
            return len(self.sample_ids)

    def __getitem__(self, idx):
        self.prof.tick("out")
        # If data is already loaded, use it
        if self.data is not None:
            seg_data = self.data[idx]
            raise NotImplementedError("Using pre-loaded in-memory data is not implemented and tested")
            if type(seg_data) is int:
                raise NotImplementedError("Mixing dynamically loaded envs with training data is no longer supported.")
        else:
            dataset_name, env_id, seg_idx = self.sample_ids[idx]
            env_data = self.load_env_data(dataset_name, env_id)

            if self.segment_level:
                seg_data = []
                segs_in_data = set()
                for sample in env_data:
                    # This is a hack around the dataset format change - some stuff used to be inside the metadata dict,
                    # but is now moved into the root level
                    if "metadata" not in sample:
                        sample["metadata"] = sample
                    # TODO: Set this at rollout time - we know which domain we're rolling out, but this can potentially be mixed up
                    sample["metadata"]["domain"] = self.domain
                    segs_in_data.add(sample["metadata"]["seg_idx"])

                # Keep the segments for which we have instructions
                segs_in_data_and_instructions = set()
                for _seg_idx in segs_in_data:
                    if get_instruction_segment(env_id, 0, _seg_idx, all_instr=self.all_instr_full) is not None:
                        segs_in_data_and_instructions.add(_seg_idx)

                if seg_idx not in segs_in_data_and_instructions:
                    if DEBUG: print(f"Segment {env_id}::{seg_idx} not in (data)and(instructions)")
                    # If there's a single segment in this entire dataset, just return that segment even if it's not a match.
                    if len(segs_in_data) == 1:
                        seg_data = env_data
                        if DEBUG: print(f"  Only one seg in data ({segs_in_data}): returning that")
                    # Otherwise return a random segment instead
                    elif len(segs_in_data_and_instructions) > 0:
                        seg_idx = random.choice(list(segs_in_data_and_instructions))
                        if DEBUG: print(f"  Returning a random segment from (data)and(instructions): {seg_idx}")
                    elif dataset_name == "real" and len(segs_in_data) > 0:
                        seg_idx = random.choice(list(segs_in_data))
                        if DEBUG: print(f"  REAL dataset. Returning a random seg from data: {seg_idx}")
                    else:
                        seg_idx = -1
                        if DEBUG: print(f"  No segment found. Skipping example")

                if len(seg_data) == 0:
                    if DEBUG: print(f"   Grabing segment: {seg_idx}")
                    for sample in env_data:
                        if sample["metadata"]["seg_idx"] == seg_idx:
                            seg_data.append(sample)
                if DEBUG: print(f"   Returning segment data of length: {len(seg_data)}")
            else:
                seg_data = env_data
        # We get a lot of Nones here in RL training because the dataset index is built from different data than what is actually available!
        # TODO: in RL training, treat entire environment as a single segment and don't distinguish.
        # How? Check above
        if len(seg_data) < self.min_seg_len:
            print(f"   None reason: len:{len(seg_data)} in {dataset_name}, env:{env_id}, seg:{seg_idx}")
            return None

        if len(seg_data) > self.traj_len:
            seg_data = seg_data[:self.traj_len]

        seg_idx = seg_data[0]["metadata"]["seg_idx"]
        set_idx = seg_data[0]["metadata"]["set_idx"]
        env_id = seg_data[0]["metadata"]["env_id"]
        instr = get_instruction_segment(env_id, set_idx, seg_idx, all_instr=self.all_instr)
        if instr is None and dataset_name != "real":
            #print(f"{dataset_name} Seg {env_id}:{set_idx}:{seg_idx} not present in instruction data")
            return None

        instr = get_instruction_segment(env_id, set_idx, seg_idx, all_instr=self.all_instr_full)
        if instr is None:
            print(f"{dataset_name} Seg {env_id}:{set_idx}:{seg_idx} not present in FULL instruction data. WTF?")
            return None

        # Convert to tensors, replacing Nones with zeros
        images_in = [seg_data[i]["state"].image if i < len(seg_data) else None for i in range(len(seg_data))]
        states = [seg_data[i]["state"].state if i < len(seg_data) else None for i in range(len(seg_data))]

        images_np = standardize_images(images_in)
        images = none_padded_seq_to_tensor(images_np)

        #depth_images_np = standardize_depth_images(images_in)
        #depth_images = none_padded_seq_to_tensor(depth_images_np)

        states = none_padded_seq_to_tensor(states)

        actions = [s["ref_action"] for s in seg_data]
        actions = none_padded_seq_to_tensor(actions)
        stops = [1.0 if s["done"] else 0.0 for s in seg_data]

        # e.g. [1 1 1 1 1 1 0 0 0 0 .. 0] for segment with 6 samples
        mask = [1.0 if s["ref_action"] is not None else 0.0 for s in seg_data]

        stops = torch.FloatTensor(stops)
        mask = torch.FloatTensor(mask)

        # This is a list, converted to tensor in collate_fn
        #if INSTRUCTIONS_FROM_FILE:
        #    tok_instructions = [tokenize_instruction(load_instruction(md["env_id"], md["set_idx"], md["seg_idx"]), self.word2token) if s["md"] is not None else None for s in seg_data]
        #else:
        tok_instructions = [tokenize_instruction(s["instruction"], self.word2token) if s["instruction"] is not None else None for s in seg_data]

        md = [seg_data[i]["metadata"] for i in range(len(seg_data))]
        flag = md[0]["flag"] if "flag" in md[0] else None

        data = {
            "instr": tok_instructions,
            "images": images,
            #"depth_images": depth_images,
            "states": states,
            "actions": actions,
            "stops": stops,
            "masks": mask,
            "flags": flag,
            "md": md
        }

        self.prof.tick("getitem_core")
        for aux_provider_name in self.aux_provider_names:
            aux_datas = resolve_data_provider(aux_provider_name)(seg_data, data)
            for d in aux_datas:
                data[d[0]] = d[1]
            self.prof.tick("getitem_" + aux_provider_name)

        return data

    def split_into_segments(self, env_list, dname):
        envs = []
        segs = []
        sk = 0  # count of segments skipped due to merge_len constraints (currently never incremented)
        skipenv = 0
        for env_id in env_list:
            # If we only allow certain envs from a dataset, and this env is not allowed, skip it
            # (Intended use is to train Stage1 with limited real-world data and compare)
            if dname in self.dataset_restricted_envs:
                if env_id not in self.dataset_restricted_envs[dname]:
                    skipenv += 1
                    continue
            # 0th instr set
            instruction_set = self.all_instr[env_id][0]["instructions"]
            for seg in instruction_set:
                seg_idx = seg["seg_idx"]
                if DEBUG: print(f"For env {env_id} including segment: {seg_idx}")
                envs.append(env_id)
                segs.append(seg_idx)
        print(f"Skipped {sk} segments due to merge_len constraints from dataset: {dname}")
        print(f"Skipped {skipenv} environments due to restriction on dataset: {dname}")
        print(f"  kept {len(segs)} segments")
        #envs, segs = self.filter_segment_availability(dname, envs, segs)
        return envs, segs

    def filter_segment_availability(self, dname, envs, segs):
        data_env = None
        envs_out, segs_out = [], []
        # TODO: When saving envs, also save metadata for which segments are present
        for env_id, seg_id in zip(envs, segs):
            md = load_single_env_metadata_from_dataset(dname, env_id, self.dataset_prefix)
            if md is None or seg_id in md["seg_ids"]:
                envs_out.append(env_id)
                segs_out.append(seg_id)
            else:
                print(f"Env {env_id} doesn't have seg {seg_id}")
        return envs_out, segs_out

    def set_word2token(self, token2term, word2token):
        self.token2term = token2term
        self.word2token = word2token

    def stack_tensors(self, one):
        if one is None:
            return None
        one = torch.stack(one, dim=0)
        one = Variable(one)
        return one

    def collate_fn(self, list_of_samples):
        self.prof.tick("out")
        if None in list_of_samples:
            return None

        data_batch = dict_zip(list_of_samples)

        data_t = dict_map(data_batch, self.stack_tensors,
                          ["images", "states", "actions", "stops", "masks"] + self.stackable_names)

        instructions_t, instruction_lengths = instruction_sequence_batch_to_tensor(data_batch["instr"])

        data_t["instr"] = instructions_t
        data_t["instr_len"] = instruction_lengths
        data_t["cam_pos"] = data_t["states"][:, 9:12].clone()
        data_t["cam_rot"] = data_t["states"][:, 12:16].clone()

        self.prof.tick("collate")
        self.prof.loop()
        self.prof.print_stats(5)
        return data_t
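
# A minimal sketch of the None-propagation convention used above, with a toy
# dataset in place of SegmentDataset: __getitem__ returns None for unusable
# segments, collate_fn propagates the None, and the training loop skips such
# batches. All names below (ToySegments, min_len) are illustrative assumptions.
import torch
from torch.utils.data import Dataset, DataLoader

class ToySegments(Dataset):
    def __init__(self, lengths, min_len=3):
        self.lengths = lengths
        self.min_len = min_len

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        n = self.lengths[idx]
        if n < self.min_len:
            return None  # too short - dropped, like segments below min_seg_len above
        return {"states": torch.randn(n, 4), "mask": torch.ones(n)}

def collate_none_aware(samples):
    if None in samples:
        return None
    return samples  # the real collate_fn stacks/pads tensors here

loader = DataLoader(ToySegments([5, 2, 7]), batch_size=1, collate_fn=collate_none_aware)
for batch in loader:
    if batch is None:
        continue  # skip unusable segments, as the trainer does
    # ... forward pass / loss on batch ...
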
Exemplo n.º 11
0
class PVN_Wrapper_Bidomain(nn.Module):
    def __init__(self,
                 run_name="",
                 model_instance_name="only",
                 oracle_stage1=False):
        super(PVN_Wrapper_Bidomain, self).__init__()
        self.instance_name = model_instance_name
        self.s1_params = P.get_current_parameters()["ModelPVN"]["Stage1"]
        self.s2_params = P.get_current_parameters()["ModelPVN"]["Stage2"]
        self.wrapper_params = P.get_current_parameters()["PVNWrapper"]
        self.oracle_stage1 = oracle_stage1

        self.real_drone = P.get_current_parameters()["Setup"]["real_drone"]
        self.rviz = None
        if self.real_drone and P.get_current_parameters()["Setup"].get(
                "use_rviz", False):
            self.rviz = RvizInterface(
                base_name="/pvn/",
                map_topics=["semantic_map", "visitation_dist"],
                markerarray_topics=["instruction"])

        self.stage1_visitation_prediction = PVN_Stage1_Bidomain(
            run_name, model_instance_name)

        self.keyboard = self.wrapper_params.get("keyboard")
        self.rl = self.wrapper_params[
            "learning_mode"] == "reinforcement_learning"

        if self.keyboard:
            self.stage2_action_generation = PVN_Stage2_Keyboard(
                run_name, model_instance_name)
        else:  # self.wrapper_params["pvn_version"] == "v2":
            # Wrap Stage 1 to hide it from PyTorch
            self.stage2_action_generation = PVN_Stage2_ActorCritic(
                run_name, model_instance_name)
        #else:
        #    self.stage2_action_generation = PVN_Stage2_Bidomain(run_name, model_instance_name)

        self.load_models_from_file()

        #self.spatialsoftmax = SpatialSoftmax2d()
        self.visitation_softmax = VisitationSoftmax()

        self.map_transformer_w_to_r = MapTransformer(
            source_map_size=self.s1_params["global_map_size"],
            dest_map_size=self.s1_params["local_map_size"],
            world_size_m=self.s1_params["world_size_m"],
            world_size_px=self.s1_params["world_size_px"])

        self.visitation_reward = VisitationReward(
            world_size_px=self.s1_params["world_size_px"],
            world_size_m=self.s1_params["world_size_m"])
        self.visitation_and_exploration_reward = VisitationAndExplorationReward(
            world_size_px=self.s1_params["world_size_px"],
            world_size_m=self.s1_params["world_size_m"])
        self.wd_visitation_and_exploration_reward = WDVisitationAndExplorationReward(
            world_size_px=self.s1_params["world_size_px"],
            world_size_m=self.s1_params["world_size_m"],
            params=self.wrapper_params["wd_reward"])

        self.map_coverage_reward = MapCoverageReward()
        self.action_oob_reward = ActionOutOfBoundsReward()

        self.map_boundary = self.make_map_boundary()

        self.prev_instruction = None
        self.start_cam_poses = None
        self.seq_step = 0
        self.log_v_dist_w = None
        self.v_dist_w = None
        self.map_uncoverage_w = None
        self.current_segment = None
        self.presenter = Presenter()

        self.actprof = SimpleProfiler(print=ACTPROF, torch_sync=ACTPROF)

    def make_picklable(self):
        self.stage1_visitation_prediction.make_picklable()
        self.stage2_action_generation.make_picklable()

    def load_models_from_file(self, stage1_file=None, stage2_file=None):
        if self.real_drone:
            stage1_file = stage1_file or self.wrapper_params.get(
                "stage1_file_real")
        else:
            stage1_file = stage1_file or self.wrapper_params.get(
                "stage1_file_sim")
        stage2_file = stage2_file or self.wrapper_params.get("stage2_file")
        if stage1_file:
            print("PVNWrapper: Loading Stage 1")
            try:
                load_pytorch_model(self.stage1_visitation_prediction,
                                   stage1_file)
            except RuntimeError as e:
                print(f"Couldn't load Stage1 without namespace: {e}")
                load_pytorch_model(self.stage1_visitation_prediction,
                                   stage1_file,
                                   namespace="stage1_visitation_prediction")
        if stage2_file:
            print("PVNWrapper: Loading Stage 2")
            try:
                load_pytorch_model(self.stage2_action_generation, stage2_file)
            except RuntimeError as e:
                print(f"Couldn't load Stage2 without namespace: {e}")
                load_pytorch_model(self.stage2_action_generation,
                                   stage2_file,
                                   namespace="stage2_action_generation")

    def parameters(self, recurse=True):
        #print("WARNING: RETURNING STAGE 2 PARAMS AS WRAPPER PARAMS")
        return self.stage2_action_generation.parameters(recurse)

    # Policy state is whatever needs to be updated during RL training.
    # Right now we only update the stage 2 weights.
    def get_policy_state(self):
        return self.stage2_action_generation.state_dict()

    def set_policy_state(self, state):
        self.stage2_action_generation.load_state_dict(state)

    def set_static_state(self, state):
        self.stage1_visitation_prediction.load_state_dict(state)

    def get_static_state(self):
        return self.stage1_visitation_prediction.state_dict()

    def init_weights(self):
        self.stage1_visitation_prediction.init_weights()
        self.stage2_action_generation.init_weights()
        self.load_models_from_file()

    def reset(self):
        self.stage1_visitation_prediction.reset()
        self.stage2_action_generation.reset()
        self.prev_instruction = None
        self.start_cam_poses = None
        self.log_v_dist_w = None
        self.v_dist_w = None
        self.map_uncoverage_w = None
        self.map_coverage_reward.reset()
        self.visitation_reward.reset()
        self.wd_visitation_and_exploration_reward.reset()

    def start_sequence(self):
        self.seq_step = 0
        self.reset()

    def start_segment_rollout(self, env_id, set_idx, seg_idx):
        self.current_segment = [env_id, set_idx, seg_idx]
        self.start_sequence()

    def cam_poses_from_states(self, states):
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]
        pose = Pose(cam_pos, cam_rot)
        return pose

    def drone_poses_from_states(self, states):
        drn_pos = states[:, 0:3]
        drn_rot_euler = states[:, 3:6].detach().cpu().numpy()
        quats = [euler2quat(a[0], a[1], a[2]) for a in drn_rot_euler]
        quats = torch.from_numpy(np.asarray(quats)).to(drn_pos.device)
        pose = Pose(drn_pos, quats)
        return pose

    def poses_from_states(self, states):
        USE_DRONE_POSES = True
        if USE_DRONE_POSES:
            return self.drone_poses_from_states(states)
        else:
            return self.cam_poses_from_states(states)

    def make_map_boundary(self):
        mapsize = self.s1_params["global_map_size"]
        boundary = torch.zeros([1, 1, mapsize, mapsize])
        boundary[:, :, 0, :] = 1.0
        boundary[:, :, mapsize - 1, :] = 1.0
        boundary[:, :, :, 0] = 1.0
        boundary[:, :, :, mapsize - 1] = 1.0
        return boundary

    def calc_intrinsic_rewards(self, next_state, action, done, first):
        if self.v_dist_w is None or self.map_uncoverage_w is None:
            raise ValueError(
                "Computing intrinsic reward prior to any rollouts!")
        else:
            states_np = next_state.state[np.newaxis, :]
            states = torch.from_numpy(states_np)
            cam_pos = states[:, 0:12]

            if self.s1_params.get(
                    "clip_observability") and self.wrapper_params.get(
                        "wasserstein_reward"):
                visitation_reward, stop_reward, exploration_reward, stop_oob_reward, stop_p_reward = self.wd_visitation_and_exploration_reward(
                    self.v_dist_w, cam_pos, action, done, first)

            elif self.s1_params.get("clip_observability"):
                visitation_reward, stop_reward, exploration_reward = self.visitation_and_exploration_reward(
                    self.v_dist_w, self.goal_oob_prob_w, cam_pos, action, done)
            else:
                visitation_reward, stop_reward = self.visitation_reward(
                    self.v_dist_w, cam_pos, action, done)
                exploration_reward = 0.0

            if self.wrapper_params.get("explore_reward_only"):
                visitation_reward = 0.0
                stop_reward = 0.0

            negative_per_step_reward = -self.wrapper_params["wd_reward"][
                "step_alpha"]
            action_oob_reward = self.action_oob_reward.get_reward(action)

            return {
                "visitation_reward": visitation_reward,
                "stop_reward": stop_reward,
                "exploration_reward": exploration_reward,
                "negative_per_step_reward": negative_per_step_reward,
                "action_oob_reward": action_oob_reward,
                "stop_oob_reward": stop_oob_reward,
                "stop_p_reward": stop_p_reward
            }

    def states_to_torch(self, state):
        states_np = state.state[np.newaxis, :]
        images_np = state.image[np.newaxis, :]
        images_np = standardize_images(images_np, out_np=True)
        images_fpv = torch.from_numpy(images_np).float()
        states = torch.from_numpy(states_np)
        return states, images_fpv

    def build_map_structured_input(self, map_uncoverage_w, cam_poses):
        map_uncoverage_r, _ = self.map_transformer_w_to_r(
            map_uncoverage_w, None, cam_poses)
        if self.s2_params["use_map_boundary"]:
            # Change device if necessary
            self.map_boundary = self.map_boundary.to(map_uncoverage_r.device)
            batch_size = map_uncoverage_w.shape[0]
            map_boundary_r, _ = self.map_transformer_w_to_r(
                self.map_boundary.repeat([batch_size, 1, 1, 1]), None,
                cam_poses)
            structured_map_info_r = torch.cat(
                [map_uncoverage_r, map_boundary_r], dim=1)
        else:
            structured_map_info_r = map_uncoverage_r

        batch_size = map_uncoverage_w.shape[0]
        struct_info_w = torch.cat([
            map_uncoverage_w,
            self.map_boundary.repeat([batch_size, 1, 1, 1])
        ])
        return structured_map_info_r, struct_info_w

    def make_viz_dict_from_stage1_internals(self):
        F_C = self.stage1_visitation_prediction.tensor_store.get_latest_input(
            "fpv_features")
        F_W = self.stage1_visitation_prediction.tensor_store.get_latest_input(
            "F_w")
        M_W = self.stage1_visitation_prediction.tensor_store.get_latest_input(
            "M_w")
        SM_W = self.stage1_visitation_prediction.tensor_store.get_latest_input(
            "SM_w")
        S_W = self.stage1_visitation_prediction.tensor_store.get_latest_input(
            "S_W_select")
        R_W = self.stage1_visitation_prediction.tensor_store.get_latest_input(
            "R_W_select")
        return {
            "F_C": F_C.detach().cpu().numpy(),
            "F_W": F_W.detach().cpu().numpy(),
            "M_W": M_W.detach().cpu().numpy(),
            "S_W": S_W.detach().cpu().numpy(),
            "R_W": R_W.detach().cpu().numpy(),
            "SM_W": SM_W.detach().cpu().numpy()
        }

    def get_action(self, state, instruction, sample=False, rl_rollout=False):
        """
        Given a DroneState (from PomdpInterface) and instruction, produce a numpy 4D action (x, y, theta, pstop)
        :param state: DroneState object with the raw image from the simulator
        :param instruction: Tokenized instruction given the corpus
        :param sample: (only applies if self.rl) if True, sample an action from the action distribution; if False, take the most likely action
        :param rl_rollout: if True, also return the data needed for PPO/A2C training instead of visualization data
        :return: 4D numpy action [xvel, 0, yawrate, stop], plus either the RL data or the visualization data
        """
        self.eval()

        states, images_fpv = self.states_to_torch(state)

        first_step = True
        if instruction == self.prev_instruction:
            first_step = False
        if first_step:
            self.reset()
            self.start_cam_poses = self.cam_poses_from_states(states)
            if self.rviz is not None:
                dbg_instr = "\n".join(Presenter().split_lines(
                    debug_untokenize_instruction(instruction), maxchars=45))
                self.rviz.publish_instruction_text("instruction", dbg_instr)

        self.prev_instruction = instruction
        self.seq_step += 1

        instr_len = [len(instruction)] if instruction is not None else None
        instructions = torch.LongTensor(instruction).unsqueeze(0)
        plan_now = self.seq_step % self.s1_params[
            "plan_every_n_steps"] == 0 or first_step

        # Run stage1 visitation prediction
        # TODO: There's a bug here where we ignore images between planning timesteps. That's why we must plan at every timestep
        if plan_now or True:
            device = next(self.parameters()).device
            images_fpv = images_fpv.to(device)
            states = states.to(device)
            instructions = instructions.to(device)
            self.start_cam_poses = self.start_cam_poses.cuda(device)

            self.actprof.tick("start")
            #print("Planning for: " + debug_untokenize_instruction(list(instructions[0].detach().cpu().numpy())))
            self.log_v_dist_w, v_dist_w_poses, rl_outputs = self.stage1_visitation_prediction(
                images_fpv,
                states,
                instructions,
                instr_len,
                plan=[True],
                firstseg=[first_step],
                noisy_start_poses=self.start_cam_poses,
                start_poses=self.start_cam_poses,
                select_only=True,
                rl=True,
                noshow=True)
            self.actprof.tick("stage1")
            self.map_uncoverage_w = rl_outputs["map_uncoverage_w"]
            self.v_dist_w = self.log_v_dist_w.softmax()
            # TODO: Fix
            if self.rviz:  # typically None unless running on the real drone with use_rviz enabled
                v_dist_w_np = self.v_dist_w.inner_distribution[0].data.cpu(
                ).numpy().transpose(1, 2, 0)
                # expand to 0-1 range
                v_dist_w_np[:, :, 0] /= (np.max(v_dist_w_np[:, :, 0]) + 1e-10)
                v_dist_w_np[:, :, 1] /= (np.max(v_dist_w_np[:, :, 1]) + 1e-10)
                self.rviz.publish_map("visitation_dist", v_dist_w_np,
                                      self.s1_params["world_size_m"])

        # Transform to robot reference frame
        drn_poses = self.poses_from_states(states)
        # Log-distributions CANNOT be transformed - the transformer fills empty space with zeroes, which makes sense for
        # probability distributions, but makes no sense for likelihood scores
        x = self.v_dist_w.inner_distribution
        xr, r_poses = self.map_transformer_w_to_r(x, None, drn_poses)
        v_dist_r = Partial2DDistribution(xr, self.v_dist_w.outer_prob_mass)

        structured_map_info_r, map_info_w = self.build_map_structured_input(
            self.map_uncoverage_w, drn_poses)

        if self.oracle_stage1:
            # Ground truth visitation distributions (in start and global frames)
            assert self.current_segment is not None, "start_segment_rollout must be called before rolling out model that uses ground truth"
            env_id, set_idx, seg_idx = self.current_segment
            v_dist_w_gt = aup.resolve_and_get_ground_truth_static_global(
                env_id, set_idx, seg_idx, self.s1_params["global_map_size"],
                self.s1_params["world_size_px"]).to(images_fpv.device)
            v_dist_r_ground_truth_select, poses_r = self.map_transformer_w_to_r(
                v_dist_w_gt, None, drn_poses)
            map_uncoverage_r = structured_map_info_r[:, 0, :, :]
            # PVNv2: Mask the visitation distributions according to observability thus far:
            if self.s1_params["clip_observability"]:
                v_dist_r_gt_masked = Partial2DDistribution.from_distribution_and_mask(
                    v_dist_r_ground_truth_select, 1 - map_uncoverage_r)
            # PVNv1: Have P(oob)=0, and use unmasked ground-truth visitation distributions
            else:
                v_dist_r_gt_masked = Partial2DDistribution(
                    v_dist_r_ground_truth_select,
                    torch.zeros_like(v_dist_r_ground_truth_select[:, :, 0, 0]))
            v_dist_r = v_dist_r_gt_masked

        if False:
            Presenter().show_image(structured_map_info_r,
                                   "map_struct",
                                   scale=4,
                                   waitkey=1)
            v_dist_r.show("v_dist_r", scale=4, waitkey=1)

        # Run stage2 action generation
        if self.rl:
            self.actprof.tick("pipes")
            # If RL, stage 2 outputs distributions over actions (following torch.distributions API)
            xvel_dist, yawrate_dist, stop_dist, value = self.stage2_action_generation(
                v_dist_r, structured_map_info_r, eval=True)

            self.actprof.tick("stage2")
            if sample:
                xvel, yawrate, stop = self.stage2_action_generation.sample_action(
                    xvel_dist, yawrate_dist, stop_dist)
            else:
                xvel, yawrate, stop = self.stage2_action_generation.mode_action(
                    xvel_dist, yawrate_dist, stop_dist)

            self.actprof.tick("sample")
            xvel_logprob, yawrate_logprob, stop_logprob = self.stage2_action_generation.action_logprob(
                xvel_dist, yawrate_dist, stop_dist, xvel, yawrate, stop)

            xvel = xvel.detach().cpu().numpy()
            yawrate = yawrate.detach().cpu().numpy()
            stop = stop.detach().cpu().numpy()
            xvel_logprob = xvel_logprob.detach()
            yawrate_logprob = yawrate_logprob.detach()
            stop_logprob = stop_logprob.detach()
            if value is not None:
                value = value.detach()  #.cpu().numpy()

            if self.s2_params.get("use_stop_threshold"):
                stop_prob = stop_dist.probs.detach().item()
                stop = 1 if stop_prob > self.s2_params.get(
                    "stop_threshold") else 0
                stop = np.ones_like(xvel) * stop
                print("stop prob: ", stop_prob, " -> ", stop)

            # Add an empty column for sideways velocity
            act = np.concatenate([xvel, np.zeros(xvel.shape), yawrate, stop])
            # This will be needed to compute rollout statistics later on
            #v_dist_w = self.visitation_softmax(self.log_v_dist_w, self.log_goal_oob_score)

            # Keep all the info we will need later for A2C / PPO training
            # TODO: We assume independence between velocity and stop distributions. Not true, but what ya gonna do?
            rl_data = {
                "policy_input": v_dist_r.detach(),
                "policy_input_b": structured_map_info_r[0].detach(),
                "v_dist_w": self.v_dist_w.inner_distribution[0].detach(),
                "value_pred": value[0] if value else None,
                "xvel": xvel,
                "yawrate": yawrate,
                "stop": stop,
                "xvel_logprob": xvel_logprob,
                "yawrate_logprob": yawrate_logprob,
                "stop_logprob": stop_logprob,
                "action_logprob": xvel_logprob + stop_logprob + yawrate_logprob
            }
            self.actprof.tick("end")
            self.actprof.loop()
            self.actprof.print_stats(1)
            if rl_rollout:
                return act, rl_data
            else:
                viz_data = self.make_viz_dict_from_stage1_internals()
                viz_data["v_dist_r_inner"] = v_dist_r.inner_distribution[
                    0].detach().cpu().numpy()
                viz_data["v_dist_r_outer"] = v_dist_r.outer_prob_mass[
                    0].detach().cpu().numpy()
                viz_data["v_dist_w_inner"] = self.v_dist_w.inner_distribution[
                    0].detach().cpu().numpy()
                viz_data["v_dist_w_outer"] = self.v_dist_w.outer_prob_mass[
                    0].detach().cpu().numpy()
                viz_data["map_struct"] = structured_map_info_r[0].detach().cpu(
                ).numpy()
                viz_data["BM_W"] = map_info_w[0].detach().cpu().numpy()
                return act, viz_data

        else:
            raise NotImplementedError(
                "Non-RL learning mode is no longer supported - just use the RL learning mode, as it is more general!"
            )

    def unbatch(self, batch):
        # Inputs
        states = self.stage2_action_generation.cuda_var(batch["states"][0])
        seq_len = len(states)
        firstseg_mask = batch["firstseg_mask"][
            0]  # True for every timestep that is a new instruction segment
        plan_mask = batch["plan_mask"][
            0]  # True for every timestep that we do visitation prediction
        actions = self.stage2_action_generation.cuda_var(batch["actions"][0])

        actions_select = self.stage2_action_generation.batch_select.one(
            actions, plan_mask, actions.device)

        # Ground truth visitation distributions (in start and global frames)
        v_dist_w_ground_truth_select = self.stage2_action_generation.cuda_var(
            batch["traj_ground_truth"][0])
        poses = self.poses_from_states(states)
        poses_select = self.stage2_action_generation.batch_select.one(
            poses, plan_mask, actions.device)
        v_dist_r_ground_truth_select, poses_r = self.map_transformer_w_to_r(
            v_dist_w_ground_truth_select, None, poses_select)

        #Presenter().show_image(v_dist_w_ground_truth_select.detach().cpu()[0,0], "v_dist_w_ground_truth_select", waitkey=1, scale=4)
        #Presenter().show_image(v_dist_r_ground_truth_select.detach().cpu()[0,0], "v_dist_r_ground_truth_select", waitkey=1, scale=4)

        return states, actions_select, v_dist_r_ground_truth_select, poses_select, plan_mask, firstseg_mask

    # Forward pass for training (with batch optimizations)
    def sup_loss_on_batch(self, batch, eval, halfway=False):
        self.reset()
        states, actions_gt_select, v_dist_r_ground_truth_select, poses_select, plan_mask, firstseg_mask = self.unbatch(
            batch)
        images, states, instructions, instr_len, plan_mask, firstseg_mask, \
         start_poses, noisy_start_poses, metadata = self.stage1_visitation_prediction.unbatch(batch, halfway=halfway)
        batch_size = images.shape[0]

        # ----------------------------------------------------------------------------
        with torch.no_grad():
            map_uncoverage_w = self.stage1_visitation_prediction(
                images,
                states,
                instructions,
                instr_len,
                plan=plan_mask,
                firstseg=firstseg_mask,
                noisy_start_poses=start_poses,
                start_poses=start_poses,
                select_only=True,
                halfway="observability")

        # ----------------------------------------------------------------------------
        poses = self.poses_from_states(states)
        structured_map_info_r = self.build_map_structured_input(
            map_uncoverage_w, poses)
        map_uncoverage_r = structured_map_info_r[:, 0, :, :]

        v_dist_r_gt_masked = Partial2DDistribution.from_distribution_and_mask(
            v_dist_r_ground_truth_select, 1 - map_uncoverage_r)

        if False:
            for i in range(batch_size):
                v_dist_r_gt_masked.show("v_dist_r_masked",
                                        scale=4,
                                        waitkey=True,
                                        idx=i)

        xvel_dist, yawrate_dist, stop_dist, value_pred = self.stage2_action_generation(
            v_dist_r_gt_masked, structured_map_info_r, eval=False)
        xvel_logprob = xvel_dist.log_probs(actions_gt_select[:, 0])
        yawrate_logprob = yawrate_dist.log_probs(actions_gt_select[:, 2])
        # TODO: Figure out why this doesn't already sum
        stop_logprob = stop_dist.log_probs(actions_gt_select[:, 3]).sum()
        total_logprob = xvel_logprob + yawrate_logprob + stop_logprob

        avg_logprob = total_logprob / batch_size
        avg_xvel_logprob = xvel_logprob / batch_size
        avg_yawrate_logprob = yawrate_logprob / batch_size
        avg_stop_logprob = stop_logprob / batch_size

        squared_xvel_dst = ((xvel_dist.mean -
                             actions_gt_select[:, 0])**2).mean()
        squared_yawrate_dst = ((yawrate_dist.mean -
                                actions_gt_select[:, 2])**2).mean()

        action_loss = -avg_stop_logprob + squared_xvel_dst + squared_yawrate_dst

        prefix = self.stage2_action_generation.model_name + ("/eval" if eval
                                                             else "/train")
        self.stage2_action_generation.writer.add_scalar(
            prefix + "/action_loss",
            action_loss.data.cpu().item(),
            self.stage2_action_generation.get_iter())
        self.stage2_action_generation.writer.add_scalar(
            prefix + "/avg_logprob",
            avg_logprob.data.cpu().item(),
            self.stage2_action_generation.get_iter())
        self.stage2_action_generation.writer.add_scalar(
            prefix + "/avg_xvel_logprob",
            avg_xvel_logprob.data.cpu().item(),
            self.stage2_action_generation.get_iter())
        self.stage2_action_generation.writer.add_scalar(
            prefix + "/avg_yawrate_logprob",
            avg_yawrate_logprob.data.cpu().item(),
            self.stage2_action_generation.get_iter())
        self.stage2_action_generation.writer.add_scalar(
            prefix + "/avg_stop_logprob",
            avg_stop_logprob.data.cpu().item(),
            self.stage2_action_generation.get_iter())
        self.stage2_action_generation.writer.add_scalar(
            prefix + "/squared_xvel_dst",
            squared_xvel_dst.data.cpu().item(),
            self.stage2_action_generation.get_iter())
        self.stage2_action_generation.writer.add_scalar(
            prefix + "/squared_yawrate_dst",
            squared_yawrate_dst.data.cpu().item(),
            self.stage2_action_generation.get_iter())

        #self.writer.add_dict(prefix, get_current_meters(), self.get_iter())
        self.stage2_action_generation.inc_iter()
        return action_loss, None

    def get_dataset(self,
                    data=None,
                    envs=None,
                    domain=None,
                    dataset_names=None,
                    dataset_prefix=None,
                    eval=False):
        return self.stage1_visitation_prediction.get_dataset(
            data=data,
            envs=envs,
            domain=domain,
            dataset_names=dataset_names,
            dataset_prefix=dataset_prefix,
            eval=eval)
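
# A small sketch of the action post-processing at the end of get_action() above:
# the 4D numpy action is [xvel, 0 (no sideways velocity), yawrate, stop], and the
# sampled stop flag can be overridden by thresholding the stop probability.
# The numeric values and stop_threshold below are illustrative assumptions.
import numpy as np

xvel = np.array([0.7])
yawrate = np.array([-0.2])
stop_prob = 0.85
stop_threshold = 0.5  # hypothetical value of s2_params["stop_threshold"]

stop = np.ones_like(xvel) * (1 if stop_prob > stop_threshold else 0)
act = np.concatenate([xvel, np.zeros(xvel.shape), yawrate, stop])
print(act)  # [ 0.7  0.  -0.2  1. ]
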
Exemplo n.º 12
0
class PVN_Stage2_Bidomain(nn.Module):

    def __init__(self, run_name="", model_instance_name="only"):
        super(PVN_Stage2_Bidomain, self).__init__()
        self.model_name = "pvn_stage2"
        self.run_name = run_name
        self.instance_name = model_instance_name
        self.writer = LoggingSummaryWriter(log_dir=f"{get_logging_dir()}/runs/{run_name}/{self.instance_name}")

        self.params_s1 = get_current_parameters()["ModelPVN"]["Stage1"]
        self.params = get_current_parameters()["ModelPVN"]["Stage2"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        self.tensor_store = KeyTensorStore()

        # Common
        # --------------------------------------------------------------------------------------------------------------
        self.map_transform_w_to_r = MapTransformer(source_map_size=self.params_s1["global_map_size"],
                                                       dest_map_size=self.params_s1["local_map_size"],
                                                       world_size_px=self.params_s1["world_size_px"],
                                                       world_size_m=self.params_s1["world_size_m"])

        self.map_to_action = CroppedMapToActionTriplet(
            map_channels=self.params["map_to_act_channels"],
            map_size=self.params_s1["local_map_size"],
            manual=False,
            path_only=self.params["action_in_path_only"],
            recurrence=self.params["action_recurrence"])

        self.batch_select = MapBatchSelect()

        self.action_loss = ActionLoss()

        self.env_id = None
        self.seg_idx = None
        self.prev_instruction = None
        self.seq_step = 0
        self.get_act_start_pose = None
        self.gt_labels = None

    def steal_cross_domain_modules(self, other_self):
        self.map_to_action = other_self.map_to_action

    def both_domain_parameters(self, other_self):
        # This function iterates and yields parameters from this module and the other module, but does not yield
        # shared parameters twice.
        # Since all the parameters are shared, it's fine to just iterate over this module's parameters
        for p in self.parameters():
            yield p
        return

    def make_picklable(self):
        self.writer = DummySummaryWriter()

    def get_iter(self):
        return int(self.iter.data[0])

    def inc_iter(self):
        self.iter += 1

    def init_weights(self):
        self.map_to_action.init_weights()

    def reset(self):
        self.tensor_store.reset()

    def setEnvContext(self, context):
        print("Set env context to: " + str(context))
        self.env_id = context["env_id"]

    def start_segment_rollout(self):
        import rollout.run_metadata as md
        m_size = self.params["local_map_size"]
        w_size = self.params["world_size_px"]
        self.gt_labels = get_top_down_ground_truth_static_global(
            md.ENV_ID, md.START_IDX, md.END_IDX, m_size, m_size, w_size, w_size)
        self.seg_idx = md.SEG_IDX
        self.gt_labels = self.maybe_cuda(self.gt_labels)
        if self.params["clear_history"]:
            self.start_sequence()

    def cam_poses_from_states(self, states):
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]
        pose = Pose(cam_pos, cam_rot)
        return pose

    def forward(self, visit_dist_r, structured_map_info_r, firstseg=None, eval=False):
        structured_map_info_r = None # not used in CoRL 2018 model

        action_scores = self.map_to_action(visit_dist_r, None, fistseg_mask=firstseg)

        self.prof.tick("map_to_action")
        xvel_mean = action_scores[:, 0]
        xvel_std = F.softplus(action_scores[:, 2])
        xvel_dist = FixedNormal(xvel_mean, xvel_std)

        yawrate_mean = action_scores[:, 3]
        if eval and self.params.get("test_time_amplifier"):
            yawrate_mean = yawrate_mean * self.params["test_time_amplifier"]

        yawrate_std = F.softplus(action_scores[:, 5])
        yawrate_dist = FixedNormal(yawrate_mean, yawrate_std)

        # Skew it towards not stopping in the beginning
        stop_logits = action_scores[:, 6]
        stop_dist = FixedBernoulli(logits=stop_logits)

        # TODO: This PVNv1 CoRL 2018 head is incompatible with actor-critic training for now
        value = None
        return xvel_dist, yawrate_dist, stop_dist, value

    def sample_action(self, xvel_dist, yawrate_dist, stop_dist):
        # Sample action from the predicted distributions
        xvel_sample = xvel_dist.sample()
        yawrate_sample = yawrate_dist.sample()
        stop = stop_dist.sample()
        return xvel_sample, yawrate_sample, stop

    def mode_action(self, xvel_dist, yawrate_dist, stop_dist):
        xvel_sample = xvel_dist.mode()
        yawrate_sample = yawrate_dist.mode()
        stop = stop_dist.mean
        return xvel_sample, yawrate_sample, stop

    def action_logprob(self, xvel_dist, yawrate_dist, stop_dist, xvel, yawrate, stop):
        xvel_logprob = xvel_dist.log_prob(xvel)
        yawrate_logprob = yawrate_dist.log_prob(yawrate)
        stop_logprob = stop_dist.log_prob(stop)
        return xvel_logprob, yawrate_logprob, stop_logprob

    def cuda_var(self, tensor):
        return tensor.to(next(self.parameters()).device)

    def unbatch(self, batch):
        # Inputs
        states = self.cuda_var(batch["states"][0])
        seq_len = len(states)
        firstseg_mask = batch["firstseg_mask"][0]          # True for every timestep that is a new instruction segment
        plan_mask = batch["plan_mask"][0]                  # True for every timestep that we do visitation prediction
        actions = self.cuda_var(batch["actions"][0])

        actions_select = self.batch_select.one(actions, plan_mask, actions.device)

        # Ground truth visitation distributions (in start and global frames)
        v_dist_w_ground_truth_select = self.cuda_var(batch["traj_ground_truth"][0])
        cam_poses = self.cam_poses_from_states(states)
        cam_poses_select = self.batch_select.one(cam_poses, plan_mask, actions.device)
        v_dist_r_ground_truth_select, poses_r = self.map_transform_w_to_r(v_dist_w_ground_truth_select, None, cam_poses_select)
        self.tensor_store.keep_inputs("v_dist_w_ground_truth_select", v_dist_w_ground_truth_select)
        self.tensor_store.keep_inputs("v_dist_r_ground_truth_select", v_dist_r_ground_truth_select)

        Presenter().show_image(v_dist_w_ground_truth_select.detach().cpu()[0,0], "v_dist_w_ground_truth_select", waitkey=1, scale=4)
        Presenter().show_image(v_dist_r_ground_truth_select.detach().cpu()[0,0], "v_dist_r_ground_truth_select", waitkey=1, scale=4)

        return states, actions_select, v_dist_r_ground_truth_select, cam_poses_select, plan_mask, firstseg_mask

    # TODO: This code heavily overlaps with that in ModelPvnWrapperBidomain.
    # PVN Wrapper should be used for training Stage 2 and not this
    # Forward pass for training Stage 2 only (with batch optimizations)
    def sup_loss_on_batch(self, batch, eval, halfway=False):
        raise ValueError("This code still works but is deprecated. Train Stage2 using PVNWrapper instead (it can compute observability of maps etc)")
        self.prof.tick("out")
        action_loss_total = self.cuda_var(torch.zeros([1]))
        if batch is None:
            print("Skipping None Batch")
            return action_loss_total

        self.reset()
        states, actions_gt_select, v_dist_r_ground_truth_select, cam_poses_select, plan_mask, firstseg_mask = self.unbatch(batch)
        count = 0
        self.prof.tick("inputs")

        batch_size = actions_gt_select.shape[0]

        # ----------------------------------------------------------------------------
        xvel_dist, yawrate_dist, stop_dist, _ = self(v_dist_r_ground_truth_select, firstseg_mask)

        stop_logprob = stop_dist.log_probs(actions_gt_select[:,3]).sum()
        avg_stop_logprob = stop_logprob / batch_size
        squared_xvel_dst = ((xvel_dist.mean - actions_gt_select[:,0]) ** 2).sum()
        squared_yawrate_dst = ((yawrate_dist.mean - actions_gt_select[:,2]) ** 2).sum()
        action_loss = -stop_logprob + squared_xvel_dst + squared_yawrate_dst

        self.prof.tick("loss")

        prefix = self.model_name + ("/eval" if eval else "/train")
        self.writer.add_scalar(prefix + "/action_loss", action_loss.data.cpu().item(), self.get_iter())
        self.writer.add_scalar(prefix + "/x_sqrdst", squared_xvel_dst.data.cpu().item(), self.get_iter())
        self.writer.add_scalar(prefix + "/yaw_sqrdst", (squared_yawrate_dst / batch_size).data.cpu().item(), self.get_iter())
        self.writer.add_scalar(prefix + "/stop_logprob", (avg_stop_logprob / batch_size).data.cpu().item(), self.get_iter())
        self.writer.add_dict(prefix, get_current_meters(), self.get_iter())

        self.inc_iter()

        self.prof.tick("summaries")
        self.prof.loop()
        self.prof.print_stats(1)

        return action_loss, self.tensor_store

    def get_dataset(self, data=None, envs=None, domain=None, dataset_names=None, dataset_prefix=None, eval=False):
        # TODO: Maybe use eval here
        data_sources = []
        data_sources.append(aup.PROVIDER_TRAJECTORY_GROUND_TRUTH_STATIC)
        return SegmentDataset(data=data, env_list=envs, domain=domain, dataset_names=dataset_names, dataset_prefix=dataset_prefix, aux_provider_names=data_sources, segment_level=True)
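
# A minimal sketch of how the 7-dimensional action_scores vector above is turned
# into velocity, yaw-rate and stop distributions, using torch.distributions in
# place of the project's FixedNormal/FixedBernoulli wrappers (assumed here to
# behave like Normal/Bernoulli with the same interface).
import torch
import torch.nn.functional as F
from torch.distributions import Normal, Bernoulli

action_scores = torch.randn(2, 7)  # a batch of 2 toy score vectors

xvel_dist = Normal(action_scores[:, 0], F.softplus(action_scores[:, 2]))
yawrate_dist = Normal(action_scores[:, 3], F.softplus(action_scores[:, 5]))
stop_dist = Bernoulli(logits=action_scores[:, 6])

xvel, yawrate, stop = xvel_dist.sample(), yawrate_dist.sample(), stop_dist.sample()
action_logprob = (xvel_dist.log_prob(xvel) + yawrate_dist.log_prob(yawrate)
                  + stop_dist.log_prob(stop))
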
Exemplo n.º 13
0
class TrainerRL:
    def __init__(self, params, save_rollouts_to_dataset="", device=None):
        self.iterations_per_epoch = params.get("iterations_per_epoch", 1)
        self.test_iterations_per_epoch = params.get(
            "test_iterations_per_epoch", 1)
        self.num_workers = params.get("num_workers")
        self.num_rollouts_per_iter = params.get("num_rollouts_per_iter")
        self.model_name = params.get("model") or params.get("rl_model")
        self.init_model_file = params.get("model_file")
        self.num_steps = params.get("trajectory_len")
        self.device = device

        self.summary_every_n = params.get("plot_every_n")

        self.roller = SimpleParallelPolicyRoller(
            num_workers=self.num_workers,
            device=self.device,
            policy_name=self.model_name,
            policy_file=self.init_model_file,
            dataset_save_name=save_rollouts_to_dataset)

        self.rollout_sampler = RolloutSampler(self.roller)

        # This should load its own weights from file, based on the model name
        self.full_model, _ = load_model(self.model_name)
        self.full_model = self.full_model.to(self.device)
        self.actor_critic = self.full_model.stage2_action_generation
        # Train in eval mode to disable dropout
        #self.actor_critic.eval()
        self.full_model.stage1_visitation_prediction.eval()
        self.writer = LoggingSummaryWriter(
            log_dir=f"{get_logging_dir()}/runs/{params['run_name']}/ppo")

        self.global_step = 0
        self.stage1_updates = 0

        clip_param = params.get("clip")
        num_mini_batch = params.get("num_mini_batch")
        value_loss_coef = params.get("value_loss_coef")
        lr = params.get("lr")
        eps = params.get("eps")
        max_grad_norm = params.get("max_grad_norm")
        use_clipped_value_loss = params.get("use_clipped_value_loss")

        self.entropy_coef = params.get("entropy_coef")
        self.entropy_schedule_epochs = params.get("entropy_schedule_epochs",
                                                  [])
        self.entropy_schedule_multipliers = params.get(
            "entropy_schedule_multipliers", [])

        self.minibatch_size = params.get("minibatch_size")

        self.use_gae = params.get("use_gae")
        self.gamma = params.get("gamma")
        self.gae_lambda = params.get("gae_lambda")
        self.intrinsic_reward_only = params.get("intrinsic_reward_only")

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)

        print(
            f"PPO trainable parameters: {get_n_trainable_params(self.actor_critic)}"
        )
        print(
            f"PPO actor-critic all parameters: {get_n_params(self.actor_critic)}"
        )

        self.ppo = PPO(self.actor_critic,
                       clip_param=clip_param,
                       ppo_epoch=1,
                       num_mini_batch=num_mini_batch,
                       value_loss_coef=value_loss_coef,
                       entropy_coef=self.entropy_coef,
                       lr=lr,
                       eps=eps,
                       max_grad_norm=max_grad_norm,
                       use_clipped_value_loss=use_clipped_value_loss)

    def set_start_epoch(self, epoch):
        prints_per_epoch = int(self.iterations_per_epoch /
                               self.summary_every_n)
        self.global_step = epoch * prints_per_epoch

    def save_rollouts(self, rollouts, dataset_name):
        for rollout in rollouts:
            # This saves just a single segment per environment, as opposed to all segments that the oracle saves. Problem?
            if len(rollout) > 0:
                env_id = rollout[0]["env_id"]
                save_dataset(dataset_name, rollout, env_id=env_id, lock=True)

    def reload_stage1(self, module_state_dict):
        print("Reloading stage 1 model in RL trainer")
        self.full_model.stage1_visitation_prediction.load_state_dict(
            module_state_dict)
        print("Reloading stage 1 model in rollout sampler")
        self.rollout_sampler.update_stage1_on_workers(
            self.full_model.stage1_visitation_prediction)
        print("Done reloading stage1")
        self.stage1_updates += 1

    def train_epoch(self, epoch_num, eval=False, envs="train"):

        rewards = []
        returns = []
        value_losses = []
        action_losses = []
        dist_entropies = []
        value_preds = []
        vels = []
        stopprobs = []

        step_rollout_metrics = {}

        # Update entropy coefficient by applying scaling
        if len(self.entropy_schedule_epochs) > 0:
            scaled_entropy_coeff = self.entropy_coef
            for e_multiplier, e_epoch in zip(self.entropy_schedule_multipliers,
                                             self.entropy_schedule_epochs):
                if epoch_num > e_epoch:
                    scaled_entropy_coeff = e_multiplier * self.entropy_coef
                else:
                    break
            self.ppo.set_entropy_coef(scaled_entropy_coeff)
        else:
            scaled_entropy_coeff = self.entropy_coef
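        # Example with hypothetical values: entropy_coef=0.01,
        # entropy_schedule_epochs=[10, 20], entropy_schedule_multipliers=[0.5, 0.1]
        # -> epochs 0-10 use 0.01, epochs 11-20 use 0.005, epochs 21+ use 0.001.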

        self.prof.tick("out")

        # TODO: Make the 100 a parameter
        iterations = self.test_iterations_per_epoch if eval else self.iterations_per_epoch

        for i in range(iterations):
            policy_state = self.full_model.get_policy_state()
            device = policy_state[next(iter(policy_state))].device
            print("TrainerRL: Sampling N Rollouts")
            rollouts = self.rollout_sampler.sample_n_rollouts(
                self.num_rollouts_per_iter,
                policy_state,
                sample=not eval,
                envs=envs)
            #if save_rollouts_to_dataset is not None:
            #    self.save_rollouts(rollouts, save_rollouts_to_dataset)

            self.prof.tick("sample_rollouts")
            print("TrainerRL: Calculating Rollout Metrics")
            i_rollout_metrics = calc_rollout_metrics(rollouts)
            step_rollout_metrics = dictlist_append(step_rollout_metrics,
                                                   i_rollout_metrics)

            assert len(rollouts) > 0

            # Convert our rollouts to the format used by Ilya Kostrikov
            device = next(self.full_model.parameters()).device
            rollout_storage = RolloutStorage.from_rollouts(
                rollouts,
                device=device,
                intrinsic_reward_only=self.intrinsic_reward_only)
            next_value = None

            rollout_storage.compute_returns(next_value, self.use_gae,
                                            self.gamma, self.gae_lambda, False)

            self.prof.tick("compute_storage")

            reward = rollout_storage.rewards.mean().detach().cpu().item()
            avg_return = (((rollout_storage.returns[1:] *
                            rollout_storage.masks[:-1]).sum() +
                           rollout_storage.returns[0]) /
                          (rollout_storage.masks[:-1].sum() + 1)).cpu().item()
            avg_value = rollout_storage.value_preds.mean().detach().cpu().item(
            )
            avg_vel = rollout_storage.actions[:, 0,
                                              0:3].detach().cpu().numpy().mean(
                                                  axis=0, keepdims=False)
            avg_stopprob = rollout_storage.actions[:, 0, 3].mean().detach(
            ).cpu().item()

            print("TrainerRL: PPO Update")
            if not eval:
                value_loss, action_loss, dist_entropy, avg_ratio = self.ppo.update(
                    rollout_storage, self.global_step, self.minibatch_size)
                print(
                    f"Iter: {i}/{iterations}, Value loss: {value_loss}, Action loss: {action_loss}, Entropy: {dist_entropy}, Reward: {reward}"
                )
            else:
                value_loss = 0
                action_loss = 0
                dist_entropy = 0
                avg_ratio = 0

            self.prof.tick("ppo_update")
            print("TrainerRL: PPO Update Done")

            returns.append(avg_return)
            rewards.append(reward)
            value_losses.append(value_loss)
            action_losses.append(action_loss)
            dist_entropies.append(dist_entropy)
            value_preds.append(avg_value)
            vels.append(avg_vel)
            stopprobs.append(avg_stopprob)

            if i % self.summary_every_n == self.summary_every_n - 1:
                avg_reward = np.mean(
                    np.asarray(rewards[-self.summary_every_n:]))
                avg_return = np.mean(
                    np.asarray(returns[-self.summary_every_n:]))
                avg_vel = np.mean(np.asarray(vels[-self.summary_every_n:]),
                                  axis=0,
                                  keepdims=False)

                metrics = {
                    "value_loss":
                    np.mean(np.asarray(value_losses[-self.summary_every_n:])),
                    "action_loss":
                    np.mean(np.asarray(action_losses[-self.summary_every_n:])),
                    "dist_entropy":
                    np.mean(np.asarray(
                        dist_entropies[-self.summary_every_n:])),
                    "avg_value":
                    np.mean(np.asarray(value_preds[-self.summary_every_n:])),
                    "avg_vel_x":
                    avg_vel[0],
                    "avg_yaw_rate":
                    avg_vel[2],
                    "avg_stopprob":
                    np.mean(np.asarray(stopprobs[-self.summary_every_n:])),
                    "ratio":
                    avg_ratio
                }

                # Reduce average
                step_rollout_metrics = dict_map(step_rollout_metrics,
                                                lambda m: np.asarray(m).mean())

                mode = "eval" if eval else "train"

                self.writer.add_scalar(f"ppo_{mode}/reward", avg_reward,
                                       self.global_step)
                self.writer.add_scalar(f"ppo_{mode}/return", avg_return,
                                       self.global_step)
                self.writer.add_scalar(f"ppo_{mode}/stage1_updates",
                                       self.stage1_updates, self.global_step)
                self.writer.add_dict(f"ppo_{mode}/", metrics, self.global_step)
                self.writer.add_dict(f"ppo_{mode}/", step_rollout_metrics,
                                     self.global_step)
                self.writer.add_scalar(f"ppo_{mode}/scaled_entropy_coeff",
                                       scaled_entropy_coeff, self.global_step)
                step_rollout_metrics = {}

                self.global_step += 1

            self.prof.tick("logging")
            print("TrainerRL: Finished Step")

        # TODO: Remove code duplication (this was easier for now)
        avg_reward = np.mean(np.asarray(rewards))
        avg_vel = np.mean(np.asarray(vels), axis=0, keepdims=False)
        metrics = {
            "value_loss": np.mean(np.asarray(value_losses)),
            "action_loss": np.mean(np.asarray(action_losses)),
            "dist_entropy": np.mean(np.asarray(dist_entropies)),
            "avg_value": np.mean(np.asarray(value_preds)),
            "avg_vel_x": avg_vel[0],
            "avg_yaw_rate": avg_vel[2],
            "avg_stopprob": np.mean(np.asarray(stopprobs))
        }
        #pprint(metrics)

        self.prof.tick("logging")
        self.prof.loop()
        self.prof.print_stats(1)

        return avg_reward, metrics
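
The per-iteration rollout metrics above are accumulated with dictlist_append and then reduced with dict_map before logging. A minimal sketch of the semantics these two helpers are assumed to have (the helpers are reimplemented here only for illustration, and the metric names are hypothetical):

import numpy as np

def dictlist_append(acc, new):
    # Append every value in `new` to a per-key list in `acc`.
    for k, v in new.items():
        acc.setdefault(k, []).append(v)
    return acc

def dict_map(d, fn):
    # Apply `fn` to each value in the dict.
    return {k: fn(v) for k, v in d.items()}

acc = {}
acc = dictlist_append(acc, {"success_rate": 0.4, "path_length": 12.0})
acc = dictlist_append(acc, {"success_rate": 0.6, "path_length": 10.0})
reduced = dict_map(acc, lambda m: np.asarray(m).mean())
# reduced == {"success_rate": 0.5, "path_length": 11.0}
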
Exemplo n.º 14
0
    def get_action(self, state, instruction, sample=False, rl_rollout=False):
        """
        Given a DroneState (from PomdpInterface) and instruction, produce a numpy 4D action (x, y, theta, pstop)
        :param state: DroneState object with the raw image from the simulator
        :param instruction: Tokenized instruction given the corpus
        :param sample: (Only applies if self.rl): If true, sample action from action distribution. If False, take most likely action.
        #TODO: Absorb corpus within model
        :return:
        """
        self.eval()

        ACTPROF = False
        actprof = SimpleProfiler(print=ACTPROF, torch_sync=ACTPROF)

        states, images_fpv = self.states_to_torch(state)

        first_step = True
        if instruction == self.prev_instruction:
            first_step = False
        if first_step:
            self.reset()
            self.start_poses = self.cam_poses_from_states(states)
            if self.rviz is not None:
                dbg_instr = "\n".join(Presenter().split_lines(
                    debug_untokenize_instruction(instruction), maxchars=45))
                self.rviz.publish_instruction_text("instruction", dbg_instr)

        self.prev_instruction = instruction
        self.seq_step += 1

        instr_len = [len(instruction)] if instruction is not None else None
        instructions = torch.LongTensor(instruction).unsqueeze(0)
        plan_now = self.seq_step % self.s1_params[
            "plan_every_n_steps"] == 0 or first_step

        # Run stage1 visitation prediction
        # TODO: There's a bug here where we ignore images between planning timesteps. That's why must plan every timestep
        if plan_now or True:
            device = next(self.parameters()).device
            images_fpv = images_fpv.to(device)
            states = states.to(device)
            instructions = instructions.to(device)
            self.start_poses = self.start_poses.to(device)

            actprof.tick("start")
            #print("Planning for: " + debug_untokenize_instruction(list(instructions[0].detach().cpu().numpy())))
            self.log_v_dist_w, v_dist_w_poses, self.log_goal_oob_score, rl_outputs = self.stage1_visitation_prediction(
                images_fpv,
                states,
                instructions,
                instr_len,
                plan=[True],
                firstseg=[first_step],
                noisy_start_poses=self.start_poses,
                start_poses=self.start_poses,
                select_only=True,
                rl=True)
            actprof.tick("stage1")

            self.map_coverage_w = rl_outputs["map_coverage_w"]
            self.map_uncoverage_w = rl_outputs["map_uncoverage_w"]
            self.v_dist_w, self.goal_oob_prob_w = self.visitation_softmax(
                self.log_v_dist_w, self.log_goal_oob_score)
            if self.rviz:
                v_dist_w_np = self.v_dist_w[0].data.cpu().numpy().transpose(
                    1, 2, 0)
                # expand to 0-1 range
                v_dist_w_np[:, :, 0] /= (np.max(v_dist_w_np[:, :, 0]) + 1e-10)
                v_dist_w_np[:, :, 1] /= (np.max(v_dist_w_np[:, :, 1]) + 1e-10)
                self.rviz.publish_map("visitation_dist", v_dist_w_np,
                                      self.s1_params["world_size_m"])

        # Transform to robot reference frame
        cam_poses = self.cam_poses_from_states(states)
        # Log-distributions CANNOT be transformed - the transformer fills empty space with zeroes, which makes sense for
        # probability distributions, but makes no sense for likelihood scores
        map_coverage_r, _ = self.map_transformer_w_to_r(
            self.map_coverage_w, None, cam_poses)
        map_uncoverage_r, _ = self.map_transformer_w_to_r(
            self.map_uncoverage_w, None, cam_poses)
        v_dist_r, r_poses = self.map_transformer_w_to_r(
            self.v_dist_w, None, cam_poses)

        # Run stage2 action generation
        if self.rl:
            actprof.tick("pipes")
            # If RL, stage 2 outputs distributions over actions (following torch.distributions API)
            xvel_dist, yawrate_dist, stop_dist, value = self.stage2_action_generation(
                v_dist_r, map_uncoverage_r, eval=True)

            actprof.tick("stage2")
            if sample:
                xvel, yawrate, stop = self.stage2_action_generation.sample_action(
                    xvel_dist, yawrate_dist, stop_dist)
            else:
                xvel, yawrate, stop = self.stage2_action_generation.mode_action(
                    xvel_dist, yawrate_dist, stop_dist)

            actprof.tick("sample")
            xvel_logprob, yawrate_logprob, stop_logprob = self.stage2_action_generation.action_logprob(
                xvel_dist, yawrate_dist, stop_dist, xvel, yawrate, stop)

            xvel = xvel.detach().cpu().numpy()
            yawrate = yawrate.detach().cpu().numpy()
            stop = stop.detach().cpu().numpy()
            xvel_logprob = xvel_logprob.detach()
            yawrate_logprob = yawrate_logprob.detach()
            stop_logprob = stop_logprob.detach()
            value = value.detach()  #.cpu().numpy()

            # Add an empty column for sideways velocity
            act = np.concatenate([xvel, np.zeros(xvel.shape), yawrate, stop])
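            # Resulting 4D action layout: [forward_vel, sideways_vel (always 0), yaw_rate, stop],
            # matching the (x, y, theta, pstop) convention described in the docstring.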

            # This will be needed to compute rollout statistics later on
            #v_dist_w = self.visitation_softmax(self.log_v_dist_w, self.log_goal_oob_score)

            # Keep all the info we will need later for A2C / PPO training
            # TODO: We assume independence between velocity and stop distributions. Not true, but what ya gonna do?
            rl_data = {
                "policy_input": v_dist_r[0].detach(),
                "v_dist_w": self.v_dist_w[0].detach(),
                "policy_input_b": map_uncoverage_r[0].detach(),
                "value_pred": value[0],
                "xvel": xvel,
                "yawrate": yawrate,
                "stop": stop,
                "xvel_logprob": xvel_logprob,
                "yawrate_logprob": yawrate_logprob,
                "stop_logprob": stop_logprob,
                "action_logprob": xvel_logprob + stop_logprob + yawrate_logprob
            }
            actprof.tick("end")
            actprof.loop()
            actprof.print_stats(1)
            if rl_rollout:
                return act, rl_data
            else:
                return act

        else:
            action = self.stage2_action_generation(v_dist_r,
                                                   firstseg=[first_step],
                                                   eval=True)
            output_action = action.squeeze().data.cpu().numpy()
            stop_prob = output_action[3]
            output_stop = 1 if stop_prob > self.s2_params[
                "stop_threshold"] else 0
            output_action[3] = output_stop

            return output_action
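
For reference, a minimal sketch of the sample / mode / log-prob pattern used by the RL branch above, written directly against torch.distributions. FixedNormal and FixedBernoulli in the snippet appear to be thin wrappers around these classes, but that is an assumption here:

import torch
from torch.distributions import Normal, Bernoulli

xvel_dist = Normal(torch.tensor([0.6]), torch.tensor([0.1]))
yawrate_dist = Normal(torch.tensor([0.05]), torch.tensor([0.2]))
stop_dist = Bernoulli(logits=torch.tensor([-2.0]))

# sample=True: stochastic rollout, draw from each distribution
xvel, yawrate, stop = xvel_dist.sample(), yawrate_dist.sample(), stop_dist.sample()

# sample=False: greedy rollout, take the mode of each distribution instead
xvel, yawrate, stop = xvel_dist.mean, yawrate_dist.mean, (stop_dist.probs > 0.5).float()

# Joint log-probability of the chosen action (velocity and stop treated as independent),
# which is what gets stored for the later PPO ratio computation.
action_logprob = (xvel_dist.log_prob(xvel) + yawrate_dist.log_prob(yawrate)
                  + stop_dist.log_prob(stop))
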
Exemplo n.º 15
0
class LeakyIntegratorGlobalMap(MapTransformerBase):
    def __init__(self, source_map_size, world_in_map_size, lamda=0.2):
        super(LeakyIntegratorGlobalMap, self).__init__(source_map_size,
                                                       world_in_map_size)
        self.map_size = source_map_size
        self.world_size = world_in_map_size
        self.child_transformer = MapTransformerBase(source_map_size,
                                                    world_in_map_size)
        self.lamda = lamda

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.map_memory = []

        self.dbg_t = None
        self.seq = 0

    def init_weights(self):
        pass

    def reset(self):
        super(LeakyIntegratorGlobalMap, self).reset()
        self.map_memory = []
        self.child_transformer.reset()
        self.seq = 0

    def cuda(self, device=None):
        MapTransformerBase.cuda(self, device)
        self.child_transformer.cuda(device)
        return self

    def dbg_write_extra(self, map, pose):
        if DebugWriter().should_write():
            map = map[0:1, 0:3]
            self.seq += 1
            # Initialize a transformer module
            if pose is not None:
                if self.dbg_t is None:
                    self.dbg_t = MapTransformerBase(self.map_size,
                                                    self.world_size)
                    if self.is_cuda:
                        self.dbg_t.cuda(self.cuda_device)

                # Transform the prediction to the global frame and write out to disk.
                self.dbg_t.set_map(map, pose)
                map_global, _ = self.dbg_t.get_map(None)
            else:
                map_global = map
            DebugWriter().write_img(map_global[0],
                                    "gif_overlaid",
                                    args={
                                        "world_size": self.world_size,
                                        "name": "sm"
                                    })

    def forward(self,
                images_w,
                coverages_w,
                add_mask=None,
                reset_mask=None,
                show=False):
        #show="li"
        self.prof.tick(".")
        batch_size = len(images_w)

        assert add_mask is None or add_mask[
            0] is not None, "The first observation in a sequence needs to be used!"

        masked_observations_w_add = self.lamda * images_w * coverages_w

        all_maps_out_w = []

        self.prof.tick("maps_to_global")

        # TODO: Draw past trajectory on an extra channel of the semantic map
        # Step 2: Integrate serially in the global frame
        for i in range(batch_size):
            if len(self.map_memory) == 0 or (reset_mask is not None
                                             and reset_mask[i]):
                new_map_w = images_w[i:i + 1]

            # Allow masking of observations
            elif add_mask is None or add_mask[i]:
                # Get the current global-frame map
                map_g = self.map_memory[-1]
                cov_w = coverages_w[i:i + 1]
                obs_cov_g = masked_observations_w_add[i:i + 1]

                # Add the observation into the map using a leaky integrator rule (TODO: Output lamda from model)
                new_map_w = (1 - self.lamda
                             ) * map_g + obs_cov_g + self.lamda * map_g * (
                                 1 - cov_w)
            else:
                new_map_w = self.map_memory[-1]

            self.map_memory.append(new_map_w)
            all_maps_out_w.append(new_map_w)

            if show != "":
                Presenter().show_image(new_map_w.data[0, 0:3],
                                       show,
                                       torch=True,
                                       scale=8,
                                       waitkey=50)

        self.prof.tick("integrate")

        # Step 3: Stack all integrated maps; this global variant keeps them in the world frame
        all_maps_w = torch.cat(all_maps_out_w, dim=0)

        # Write gifs for debugging
        #self.dbg_write_extra(all_maps_w, None)

        self.prof.tick("maps_to_local")
        self.prof.loop()
        self.prof.print_stats(10)

        return all_maps_w
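
The update inside forward() can be read as a per-cell leaky average toward the new observation wherever the observation covers the map. A small sketch with dummy tensors, verifying that the expression used above is equivalent to that simpler form:

import torch

lamda = 0.2
m = torch.rand(1, 8, 32, 32)                   # previous global-frame map
x = torch.rand(1, 8, 32, 32)                   # new observation, already in the world frame
c = (torch.rand(1, 1, 32, 32) > 0.5).float()   # coverage mask of the observation

update = (1 - lamda) * m + lamda * x * c + lamda * m * (1 - c)
equivalent = m + lamda * c * (x - m)           # move toward x only where covered
assert torch.allclose(update, equivalent)
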
Exemplo n.º 16
0
class Affine2D(nn.Module):
    def __init__(self):
        super(Affine2D, self).__init__()
        self.is_cuda = False
        self.cuda_device = None

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)

    def cuda(self, device=None):
        nn.Module.cuda(self, device)
        self.is_cuda = True
        self.cuda_device = device
        return self

    def get_pytorch_to_img_mat(self, img_size, inv=False):
        """
        Returns an affine transformation matrix that maps image coordinates from the pytorch range [-1, 1]
        to pixel coordinates in the range [0, W] x [0, H]
        :param img_size: (W, H)
        :param inv: if True, return the inverse transformation (pixel coordinates to pytorch coordinates)
        :return: the transformation matrix as a torch tensor
        """
        # First move the image so that the origin is in the top-left corner
        # (in pytorch, the origin is in the center of the image)
        """
        t1 = np.asarray([
            [1.0, 0, 1.0],
            [0, 1.0, 1.0],
            [0, 0, 1.0]
        ])

        # Then scale the image up to the required size
        scale_w = img_size[0] / 2
        scale_h = img_size[1] / 2
        t2 = np.asarray([
            [scale_h, 0, 0],
            [0, scale_w, 0],
            [0, 0, 1]
        ])
        """

        # First scale the image to pixel coordinates
        scale_w = img_size[0] / 2
        scale_h = img_size[1] / 2

        t1 = np.asarray([[scale_h, 0, 0], [0, scale_w, 0], [0, 0, 1]])

        # Then move it such that the corner is at the origin
        t2 = np.asarray([[1.0, 0, scale_h], [0, 1.0, scale_w], [0, 0, 1.0]])

        T = np.dot(t2, t1)
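        # e.g. for a 64x64 image, T maps the pytorch corner (-1, -1) to pixel (0, 0)
        # and (1, 1) to (64, 64).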

        if inv:
            T = np.linalg.inv(T)

        T_t = np_to_tensor(T, cuda=False)

        return T_t

    def img_affines_to_pytorch_cpu(self, img_affines, img_in_size, out_size):
        T_src = self.get_pytorch_to_img_mat(img_in_size, inv=False)
        Tinv_dst = self.get_pytorch_to_img_mat(out_size, inv=True)

        self.prof.tick("getmat")

        # Convert from pytorch coords to image pixel coords, apply the transformation, then convert the result back.
        batch_size = img_affines.size(0)
        T_src = T_src.repeat(batch_size, 1, 1)
        Tinv_dst = Tinv_dst.repeat(batch_size, 1, 1)

        x = torch.bmm(
            img_affines, T_src
        )  # Convert pytorch coords to pixel coords and apply the transformation
        pyt_affines = torch.bmm(
            Tinv_dst, x)  # Convert the transformation back to pytorch coords

        self.prof.tick("convert")

        inverses = [torch.inverse(affine) for affine in pyt_affines]

        self.prof.tick("inverse")

        pyt_affines_inv = torch.stack(inverses, dim=0)

        self.prof.tick("stack")

        return pyt_affines_inv

    def forward(self, image, affine_mat, out_size=None):
        """
        Applies the given batch of affine transformation matrices to the batch of images
        :param image:   batch of images to transform
        :param affine_mat: batch of affine matrices to apply. Specified in image coordinates (internally converted to pytorch coords)
        :param out_size: optional output size; defaults to the size of the input images
        :return:        batch of transformed images (same size as the input batch unless out_size is given)
        """

        batch_size = image.size(0)

        self.prof.tick(".")

        # Cut off the batch and channel to get the image size as the source size
        img_size = list(image.size())[2:4]
        if out_size is None:
            out_size = img_size

        affines_pytorch = self.img_affines_to_pytorch_cpu(
            affine_mat, img_size, out_size)

        if self.is_cuda:
            affines_pytorch = affines_pytorch.cuda(self.cuda_device)

        # Build the affine grid
        grid = F.affine_grid(
            affines_pytorch[:, [0, 1], :],
            torch.Size((batch_size, 1, out_size[0], out_size[1]))).float()

        self.prof.tick("affine_grid")

        # Rotate the input image
        rot_img = F.grid_sample(image, grid, padding_mode="zeros")

        self.prof.tick("grid_sample")
        self.prof.loop()
        self.prof.print_stats(10)

        return rot_img
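
Affine2D ultimately relies on F.affine_grid and F.grid_sample, which operate in the normalized [-1, 1] coordinate convention that get_pytorch_to_img_mat converts to and from. A minimal sketch of that convention; the align_corners argument assumes a recent PyTorch version, whereas the snippet above predates it:

import torch
import torch.nn.functional as F

image = torch.rand(1, 3, 64, 64)

# Identity affine in normalized [-1, 1] coordinates: the output reproduces the input.
theta = torch.tensor([[[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]]])
grid = F.affine_grid(theta, torch.Size((1, 3, 64, 64)), align_corners=False)
out = F.grid_sample(image, grid, padding_mode="zeros", align_corners=False)
assert torch.allclose(out, image, atol=1e-5)
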
Exemplo n.º 17
0
class ModelGSMNBiDomain(nn.Module):
    def __init__(self, run_name="", model_instance_name=""):

        super(ModelGSMNBiDomain, self).__init__()
        self.model_name = "gsmn_bidomain"
        self.run_name = run_name
        self.name = model_instance_name
        if not self.name:
            self.name = ""
        self.writer = LoggingSummaryWriter(
            log_dir=f"runs/{run_name}/{self.name}")

        self.params = get_current_parameters()["Model"]
        self.aux_weights = get_current_parameters()["AuxWeights"]
        self.use_aux = self.params["UseAuxiliaries"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        self.tensor_store = KeyTensorStore()
        self.aux_losses = AuxiliaryLosses()

        self.rviz = None
        if self.params.get("rviz"):
            self.rviz = RvizInterface(
                base_name="/gsmn/",
                map_topics=["semantic_map", "grounding_map", "goal_map"],
                markerarray_topics=["instruction"])

        # Path-pred FPV model definition
        # --------------------------------------------------------------------------------------------------------------

        self.img_to_features_w = FPVToGlobalMap(
            source_map_size=self.params["global_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"],
            res_channels=self.params["resnet_channels"],
            map_channels=self.params["feature_channels"],
            img_w=self.params["img_w"],
            img_h=self.params["img_h"],
            cam_h_fov=self.params["cam_h_fov"],
            img_dbg=IMG_DBG)

        self.map_accumulator_w = LeakyIntegratorGlobalMap(
            source_map_size=self.params["global_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])

        # Pre-process the accumulated map to do language grounding if necessary - in the world reference frame
        if self.use_aux[
                "grounding_map"] and not self.use_aux["grounding_features"]:
            self.map_processor_a_w = LangFilterMapProcessor(
                embed_size=self.params["emb_size"],
                in_channels=self.params["feature_channels"],
                out_channels=self.params["relevance_channels"],
                spatial=False,
                cat_out=True)
        else:
            self.map_processor_a_w = IdentityMapProcessor(
                source_map_size=self.params["global_map_size"],
                world_size_px=self.params["world_size_px"],
                world_size_m=self.params["world_size_m"])

        if self.use_aux["goal_map"]:
            self.map_processor_b_r = LangFilterMapProcessor(
                embed_size=self.params["emb_size"],
                in_channels=self.params["relevance_channels"],
                out_channels=self.params["goal_channels"],
                spatial=self.params["spatial_goal_filter"],
                cat_out=self.params["cat_rel_and_goal"])
        else:
            self.map_processor_b_r = IdentityMapProcessor(
                source_map_size=self.params["local_map_size"],
                world_size_px=self.params["world_size_px"],
                world_size_m=self.params["world_size_m"])

        # Common
        # --------------------------------------------------------------------------------------------------------------

        # Sentence Embedding
        self.sentence_embedding = SentenceEmbeddingSimple(
            self.params["word_emb_size"],
            self.params["emb_size"],
            self.params["emb_layers"],
            dropout=0.0)

        self.map_transform_w_to_r = MapTransformerBase(
            source_map_size=self.params["global_map_size"],
            dest_map_size=self.params["local_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])
        self.map_transform_r_to_w = MapTransformerBase(
            source_map_size=self.params["local_map_size"],
            dest_map_size=self.params["global_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])

        # Output an action given the global semantic map
        if self.params["map_to_action"] == "downsample2":
            self.map_to_action = EgoMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                other_features_size=self.params["emb_size"])

        elif self.params["map_to_action"] == "cropped":
            self.map_to_action = CroppedMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"])

        # Auxiliary Objectives
        # --------------------------------------------------------------------------------------------------------------

        # We add all auxiliaries that are necessary. The first argument is the auxiliary name, followed by parameters,
        # followed by variable number of names of inputs. ModuleWithAuxiliaries will automatically collect these inputs
        # that have been saved with keep_auxiliary_input() during execution
        if self.use_aux["class_features"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary2D("aux_class", self.params["feature_channels"],
                                 self.params["num_landmarks"],
                                 self.params["dropout"], "fpv_features",
                                 "lm_pos_fpv_features", "lm_indices",
                                 "tensor_store"))
        if self.use_aux["grounding_features"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary2D("aux_ground",
                                 self.params["relevance_channels"], 2,
                                 self.params["dropout"], "fpv_features_g",
                                 "lm_pos_fpv_features", "lm_mentioned",
                                 "tensor_store"))
        if self.use_aux["class_map"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary2D("aux_class_map",
                                 self.params["feature_channels"],
                                 self.params["num_landmarks"],
                                 self.params["dropout"], "map_S_W",
                                 "lm_pos_map", "lm_indices", "tensor_store"))
        if self.use_aux["grounding_map"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary2D("aux_grounding_map",
                                 self.params["relevance_channels"], 2,
                                 self.params["dropout"], "map_R_W",
                                 "lm_pos_map", "lm_mentioned", "tensor_store"))
        if self.use_aux["goal_map"]:
            self.aux_losses.add_auxiliary(
                GoalAuxiliary2D("aux_goal_map", self.params["goal_channels"],
                                self.params["global_map_size"], "map_G_W",
                                "goal_pos_map"))
        # RSS model uses templated data for landmark and side prediction
        if self.use_aux["language"] and self.params["templates"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary("aux_lang_lm", self.params["emb_size"],
                               self.params["num_landmarks"], 1,
                               "sentence_embed", "lm_mentioned_tplt"))
            self.aux_losses.add_auxiliary(
                ClassAuxiliary("aux_lang_side", self.params["emb_size"],
                               self.params["num_sides"], 1, "sentence_embed",
                               "side_mentioned_tplt"))
        # CoRL model uses alignment-model groundings
        elif self.use_aux["language"]:
            # one output for each landmark, 2 classes per output. This is for finetuning, so use the embedding that's gonna be fine tuned
            self.aux_losses.add_auxiliary(
                ClassAuxiliary("aux_lang_lm_nl", self.params["emb_size"], 2,
                               self.params["num_landmarks"], "sentence_embed",
                               "lang_lm_mentioned"))
        if self.use_aux["l1_regularization"]:
            self.aux_losses.add_auxiliary(
                FeatureRegularizationAuxiliary2D("aux_regularize_features",
                                                 "l1", "map_S_W"))
            self.aux_losses.add_auxiliary(
                FeatureRegularizationAuxiliary2D("aux_regularize_features",
                                                 "l1", "map_R_W"))

        self.goal_acc_meter = MovingAverageMeter(10)

        self.aux_losses.print_auxiliary_info()

        self.action_loss = ActionLoss()

        self.env_id = None
        self.prev_instruction = None
        self.seq_step = 0

    def cuda(self, device=None):
        CudaModule.cuda(self, device)
        self.aux_losses.cuda(device)
        self.sentence_embedding.cuda(device)
        self.map_accumulator_w.cuda(device)
        self.map_processor_a_w.cuda(device)
        self.map_processor_b_r.cuda(device)
        self.img_to_features_w.cuda(device)
        self.map_to_action.cuda(device)
        self.action_loss.cuda(device)
        self.map_transform_w_to_r.cuda(device)
        self.map_transform_r_to_w.cuda(device)
        return self

    def steal_cross_domain_modules(self, other_self):
        # TODO: Consider whether to share auxiliary losses, and if so, all of them?
        self.aux_losses = other_self.aux_losses
        self.action_loss = other_self.action_loss

        # TODO: Make sure that none of these things are stateful, or that there are resets after every forward pass
        self.sentence_embedding = other_self.sentence_embedding
        self.map_accumulator_w = other_self.map_accumulator_w
        self.map_processor_a_w = other_self.map_processor_a_w
        self.map_processor_b_r = other_self.map_processor_b_r
        self.map_to_action = other_self.map_to_action

        # We'll have a separate one of these for each domain
        #self.img_to_features_w = other_self.img_to_features_w

        # TODO: Check that statefulness is not an issue in sharing modules
        # These have no parameters so no point sharing
        #self.map_transform_w_to_r = other_self.map_transform_w_to_r
        #self.map_transform_r_to_w = other_self.map_transform_r_to_w

    def both_domain_parameters(self, other_self):
        # This function iterates and yields parameters from this module and the other module, but does not yield
        # shared parameters twice.
        # First yield all of the other module's parameters
        for p in other_self.parameters():
            yield p
        # Then yield all the parameters from this module that are not shared with the other one
        for p in self.img_to_features_w.parameters():
            yield p
        return

    def get_iter(self):
        return int(self.iter.data[0])

    def inc_iter(self):
        self.iter += 1

    def load_img_feature_weights(self):
        if self.params.get("load_feature_net"):
            filename = self.params.get("feature_net_filename")
            weights = load_pytorch_model(None, filename)
            prefix = self.params.get("feature_net_tensor_name")
            if prefix:
                weights = find_state_subdict(weights, prefix)
            # TODO: This breaks OOP conventions
            self.img_to_features_w.img_to_features.load_state_dict(weights)
            print(
                f"Loaded pretrained weights from file {filename} with prefix {prefix}"
            )

    def init_weights(self):
        self.img_to_features_w.init_weights()
        self.load_img_feature_weights()
        self.map_accumulator_w.init_weights()
        self.sentence_embedding.init_weights()
        self.map_to_action.init_weights()
        self.map_processor_a_w.init_weights()
        self.map_processor_b_r.init_weights()

    def reset(self):
        self.tensor_store.reset()
        self.sentence_embedding.reset()
        self.img_to_features_w.reset()
        self.map_accumulator_w.reset()
        self.map_transform_w_to_r.reset()
        self.map_transform_r_to_w.reset()
        self.load_img_feature_weights()
        self.prev_instruction = None

    def set_env_context(self, context):
        print("Set env context to: " + str(context))
        self.env_id = context["env_id"]

    def save_viz(self, images_in, instruction):
        # Save incoming images
        imsave(
            os.path.join(get_viz_dir_for_rollout(),
                         "fpv_" + str(self.seq_step) + ".png"), images_in)
        #self.tensor_store.keep_input("fpv_img", images_in)
        # Save all of these tensors from the tensor store as images
        save_tensors_as_images(self.tensor_store, [
            "images_w", "fpv_img", "fpv_features", "map_F_W", "map_M_W",
            "map_S_W", "map_R_W", "map_R_R", "map_G_R", "map_G_W"
        ], str(self.seq_step))

        # Save action as image
        action = self.tensor_store.get_inputs_batch(
            "action")[-1].data.cpu().squeeze().numpy()
        action_fname = get_viz_dir_for_rollout() + "action_" + str(
            self.seq_step) + ".png"
        Presenter().save_action(action, action_fname, "")

        instruction_fname = get_viz_dir_for_rollout() + "instruction.txt"
        with open(instruction_fname, "w") as fp:
            fp.write(instruction)

    def get_action(self, state, instruction):
        """
        Given a DroneState (from PomdpInterface) and instruction, produce a numpy 4D action (x, y, theta, pstop)
        :param state: DroneState object with the raw image from the simulator
        :param instruction: Tokenized instruction given the corpus
        #TODO: Absorb corpus within model
        :return:
        """
        # TODO: Simplify this
        self.eval()
        images_np_pure = state.image
        state_np = state.state

        #print("Act: " + debug_untokenize_instruction(instruction))

        images_np = standardize_image(images_np_pure)
        image_fpv = Variable(none_padded_seq_to_tensor([images_np]))
        state = Variable(none_padded_seq_to_tensor([state_np]))
        # Add the batch dimension

        first_step = True
        if instruction == self.prev_instruction:
            first_step = False
        self.prev_instruction = instruction
        instruction_str = debug_untokenize_instruction(instruction)

        # TODO: Move this to PomdpInterface (for now it's here because this is already visualizing the maps)
        if first_step:
            if self.rviz is not None:
                self.rviz.publish_instruction_text(
                    "instruction", debug_untokenize_instruction(instruction))

        img_in_t = image_fpv
        img_in_t.volatile = True

        instr_len = [len(instruction)] if instruction is not None else None
        instruction = torch.LongTensor(instruction).unsqueeze(0)
        instruction = cuda_var(instruction, self.is_cuda, self.cuda_device)

        state.volatile = True

        if self.is_cuda:
            if img_in_t is not None:
                img_in_t = img_in_t.cuda(self.cuda_device)
            state = state.cuda(self.cuda_device)

        step_enc = None
        plan_now = None

        self.seq_step += 1

        action = self(img_in_t,
                      state,
                      instruction,
                      instr_len,
                      plan=plan_now,
                      pos_enc=step_enc)

        passive_mode_debug_projections = True
        if passive_mode_debug_projections:
            self.show_landmark_locations(loop=False, states=state)
            self.reset()

        # Run auxiliary objectives for debugging purposes (e.g. to compute classification predictions)
        if self.params.get("run_auxiliaries_at_test_time"):
            _, _ = self.aux_losses.calculate_aux_loss(self.tensor_store,
                                                      reduce_average=True)
            overlaid = self.get_overlaid_classification_results(
                map_not_features=False)

        # Save materials for analysis and presentation
        if self.params["write_figures"]:
            self.save_viz(images_np_pure, instruction_str)

        output_action = action.squeeze().data.cpu().numpy()
        stop_prob = output_action[3]
        output_stop = 1 if stop_prob > self.params["stop_p"] else 0
        output_action[3] = output_stop

        return output_action

    def get_overlaid_classification_results(self, map_not_features=False):
        if map_not_features:
            predictions_name = "aux_class_map_predictions"
        else:
            predictions_name = "aux_class_predictions"
        predictions = self.tensor_store.get_latest_input(predictions_name)
        if predictions is None:
            return None
        predictions = predictions[0].detach()
        # Get the 3 channels corresponding to no landmark, banana and gorilla
        predictions = predictions[[0, 3, 24], :, :]
        images = self.tensor_store.get_latest_input("images")[0].detach()
        overlaid = Presenter().overlaid_image(images,
                                              predictions,
                                              gray_bg=True)
        return overlaid

    def deterministic_action(self, action_mean, action_std, stop_prob):
        batch_size = action_mean.size(0)
        action = Variable(
            empty_float_tensor((batch_size, 4), self.is_cuda,
                               self.cuda_device))
        action[:, 0:3] = action_mean[:, 0:3]
        action[:, 3] = stop_prob
        return action

    # This is called before beginning an execution sequence
    def start_sequence(self):
        self.seq_step = 0
        self.reset()
        print("RESETTED!")
        return

    # TODO: Move this somewhere and standardize
    def cam_poses_from_states(self, states):
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]

        pos_variance = 0
        rot_variance = 0
        if self.params.get("use_pos_noise"):
            pos_variance = self.params["noisy_pos_variance"]
        if self.params.get("use_rot_noise"):
            rot_variance = self.params["noisy_rot_variance"]

        pose = Pose(cam_pos, cam_rot)
        if self.params.get("use_pos_noise") or self.params.get(
                "use_rot_noise"):
            pose = get_noisy_poses_torch(pose,
                                         pos_variance,
                                         rot_variance,
                                         cuda=self.is_cuda,
                                         cuda_device=self.cuda_device)
        return pose

    def forward(self,
                images,
                states,
                instructions,
                instr_lengths,
                has_obs=None,
                plan=None,
                save_maps_only=False,
                pos_enc=None,
                noisy_poses=None,
                halfway=False):
        """
        :param images: BxCxHxW batch of images (observations)
        :param states: BxK batch of drone states
        :param instructions: BxM LongTensor where M is the maximum length of any instruction
        :param instr_lengths: list of len B of integers, indicating length of each instruction
        :param has_obs: list of booleans of length B indicating whether the given element in the sequence has an observation
        :param save_maps_only: If true, will not compute actions (full model), but return the semantic maps that
            were built along the way in response to the images. This is ugly, but allows code reuse
        :param halfway: If true, only run the image-to-feature stage and return None (used for the Wasserstein critic)
        :return:
        """
        cam_poses = self.cam_poses_from_states(states)
        g_poses = None  #[None for pose in cam_poses]
        self.prof.tick("out")

        #str_instr = debug_untokenize_instruction(instructions[0].data[:instr_lengths[0]])
        #print("Trn: " + str_instr)

        # Calculate the instruction embedding
        if instructions is not None:
            # TODO: Take batch of instructions and their lengths, return batch of embeddings. Store the last one as internal state
            sent_embeddings = self.sentence_embedding(instructions,
                                                      instr_lengths)
            self.tensor_store.keep_inputs("sentence_embed", sent_embeddings)
        else:
            sent_embeddings = self.sentence_embedding.get()

        self.prof.tick("embed")

        # Extract and project features onto the egocentric frame for each image
        features_w, coverages_w = self.img_to_features_w(images,
                                                         cam_poses,
                                                         sent_embeddings,
                                                         self.tensor_store,
                                                         show="")

        # If we're running the model halfway, return now. This is to compute just enough features for the Wasserstein critic, and no more.
        if halfway:
            return None

        # Don't back-prop into resnet if we're freezing these features (TODO: instead set requires grad to false)
        if self.params.get("freeze_feature_net"):
            features_w = features_w.detach()

        self.prof.tick("img_to_map_frame")
        self.tensor_store.keep_inputs("images", images)
        self.tensor_store.keep_inputs("map_F_w", features_w)
        self.tensor_store.keep_inputs("map_M_w", coverages_w)

        if run_metadata.IS_ROLLOUT:
            Presenter().show_image(features_w.data[0, 0:3],
                                   "F",
                                   torch=True,
                                   scale=8,
                                   waitkey=1)

        # Accumulate the egocentric features in a global map
        maps_s_w = self.map_accumulator_w(features_w,
                                          coverages_w,
                                          add_mask=has_obs,
                                          show="acc" if IMG_DBG else "")
        map_poses_w = g_poses
        self.tensor_store.keep_inputs("map_S_W", maps_s_w)
        self.prof.tick("map_accumulate")

        Presenter().show_image(maps_s_w.data[0],
                               f"{self.name}_S_map_W",
                               torch=True,
                               scale=4,
                               waitkey=1)

        # Do grounding of objects in the map chosen to do so
        maps_r_w, map_poses_r_w = self.map_processor_a_w(maps_s_w,
                                                         sent_embeddings,
                                                         map_poses_w,
                                                         show="")
        self.tensor_store.keep_inputs("map_R_W", maps_r_w)
        Presenter().show_image(maps_r_w.data[0],
                               f"{self.name}_R_map_W",
                               torch=True,
                               scale=4,
                               waitkey=1)
        self.prof.tick("map_proc_gnd")

        # Transform to drone's reference frame
        self.map_transform_w_to_r.set_maps(maps_r_w, map_poses_r_w)
        maps_r_r, map_poses_r_r = self.map_transform_w_to_r.get_maps(cam_poses)
        self.tensor_store.keep_inputs("map_R_R", maps_r_r)
        self.prof.tick("transform_w_to_r")

        # Predict goal location
        maps_g_r, map_poses_g_r = self.map_processor_b_r(
            maps_r_r, sent_embeddings, map_poses_r_r)
        self.tensor_store.keep_inputs("map_G_R", maps_g_r)

        # Transform back to map frame
        self.map_transform_r_to_w.set_maps(maps_g_r, map_poses_g_r)
        maps_g_w, _ = self.map_transform_r_to_w.get_maps(None)
        self.tensor_store.keep_inputs("map_G_W", maps_g_w)
        self.prof.tick("map_proc_b")

        # Show and publish to RVIZ
        Presenter().show_image(maps_g_w.data[0],
                               f"{self.name}_G_map_W",
                               torch=True,
                               scale=8,
                               waitkey=1)
        if self.rviz:
            self.rviz.publish_map(
                "goal_map", maps_g_w[0].data.cpu().numpy().transpose(1, 2, 0),
                self.params["world_size_m"])

        # Output the final action given the processed map
        action_pred = self.map_to_action(maps_g_r, sent_embeddings)
        out_action = self.deterministic_action(action_pred[:, 0:3], None,
                                               action_pred[:, 3])
        self.tensor_store.keep_inputs("action", out_action)
        self.prof.tick("map_to_action")

        return out_action

    # TODO: The below two methods seem to do the same thing
    def maybe_cuda(self, tensor):
        if self.is_cuda:
            return tensor.cuda()
        else:
            return tensor

    def cuda_var(self, tensor):
        return cuda_var(tensor, self.is_cuda, self.cuda_device)

    def unbatch(self, batch):
        # TODO: Carefully consider this line. This is necessary to reset state between batches (e.g. delete all tensors in the tensor store)
        self.reset()
        # Get rid of the batch dimension for everything
        images = self.maybe_cuda(batch["images"])[0]
        seq_len = images.shape[0]
        instructions = self.maybe_cuda(batch["instr"])[0][:seq_len]
        instr_lengths = batch["instr_len"][0]
        states = self.maybe_cuda(batch["states"])[0]
        actions = self.maybe_cuda(batch["actions"])[0]

        # Auxiliary labels
        lm_pos_fpv = batch["lm_pos_fpv"][0]
        lm_pos_map = batch["lm_pos_map"][0]
        lm_indices = batch["lm_indices"][0]
        goal_pos_map = batch["goal_loc"][0]

        # TODO: Get rid of this. We will have lm_mentioned booleans and lm_mentioned_idx integers and that's it.
        TEMPLATES = True
        if TEMPLATES:
            lm_mentioned_tplt = batch["lm_mentioned_tplt"][0]
            side_mentioned_tplt = batch["side_mentioned_tplt"][0]
            side_mentioned_tplt = self.cuda_var(side_mentioned_tplt)
            lm_mentioned_tplt = self.cuda_var(lm_mentioned_tplt)
            lang_lm_mentioned = None
        else:
            lm_mentioned_tplt = None
            side_mentioned_tplt = None
            lang_lm_mentioned = batch["lang_lm_mentioned"][0]
        lm_mentioned = batch["lm_mentioned"][0]
        # This is the first-timestep metadata
        metadata = batch["md"][0]

        lm_pos_map = [
            torch.from_numpy(
                transformations.pos_m_to_px(
                    p.numpy(), self.params["global_map_size"],
                    self.params["world_size_m"], self.params["world_size_px"]))
            if p is not None else None for p in lm_pos_map
        ]

        goal_pos_map = torch.from_numpy(
            transformations.pos_m_to_px(goal_pos_map.numpy(),
                                        self.params["global_map_size"],
                                        self.params["world_size_m"],
                                        self.params["world_size_px"]))

        lm_pos_map = [
            self.cuda_var(s.long()) if s is not None else None
            for s in lm_pos_map
        ]
        lm_pos_fpv_features = [
            self.cuda_var(
                (s /
                 self.img_to_features_w.img_to_features.get_downscale_factor()
                 ).long()) if s is not None else None for s in lm_pos_fpv
        ]
        lm_pos_fpv_img = [
            self.cuda_var(s.long()) if s is not None else None
            for s in lm_pos_fpv
        ]
        lm_indices = [
            self.cuda_var(s) if s is not None else None for s in lm_indices
        ]
        goal_pos_map = self.cuda_var(goal_pos_map)
        if not TEMPLATES:
            lang_lm_mentioned = self.cuda_var(lang_lm_mentioned)
        lm_mentioned = [
            self.cuda_var(s) if s is not None else None for s in lm_mentioned
        ]

        obs_mask = [True for _ in range(seq_len)]
        plan_mask = [True for _ in range(seq_len)]
        pos_enc = None

        # TODO: Figure out how to keep these properly. Perhaps as a whole batch is best
        self.tensor_store.keep_inputs("lm_pos_fpv_img", lm_pos_fpv_img)
        self.tensor_store.keep_inputs("lm_pos_fpv_features",
                                      lm_pos_fpv_features)
        self.tensor_store.keep_inputs("lm_pos_map", lm_pos_map)
        self.tensor_store.keep_inputs("lm_indices", lm_indices)
        self.tensor_store.keep_inputs("goal_pos_map", goal_pos_map)
        if not TEMPLATES:
            self.tensor_store.keep_inputs("lang_lm_mentioned",
                                          lang_lm_mentioned)
        else:
            self.tensor_store.keep_inputs("lm_mentioned_tplt",
                                          lm_mentioned_tplt)
            self.tensor_store.keep_inputs("side_mentioned_tplt",
                                          side_mentioned_tplt)
        self.tensor_store.keep_inputs("lm_mentioned", lm_mentioned)

        # ----------------------------------------------------------------------------
        # Optional Auxiliary Inputs
        # ----------------------------------------------------------------------------
        #if self.aux_losses.input_required("lm_pos_map"):
        self.tensor_store.keep_inputs("lm_pos_map", lm_pos_map)
        #if self.aux_losses.input_required("lm_indices"):
        self.tensor_store.keep_inputs("lm_indices", lm_indices)
        #if self.aux_losses.input_required("lm_mentioned"):
        self.tensor_store.keep_inputs("lm_mentioned", lm_mentioned)

        return images, instructions, instr_lengths, states, actions, \
               lm_pos_fpv_img, lm_pos_fpv_features, lm_pos_map, lm_indices, goal_pos_map, \
               lm_mentioned, lm_mentioned_tplt, side_mentioned_tplt, lang_lm_mentioned, \
               metadata, obs_mask, plan_mask, pos_enc

    def show_landmark_locations(self, loop=True, states=None):
        # Show landmark locations in first-person images
        img_all = self.tensor_store.get("images")
        img_w_all = self.tensor_store.get("images_w")
        import rollout.run_metadata as md
        if md.IS_ROLLOUT:
            # TODO: Discard this and move this to PomdpInterface or something
            # (it's got nothing to do with the model)
            # load landmark positions from configs
            from data_io.env import load_env_config
            from learning.datasets.aux_data_providers import get_landmark_locations_airsim
            from learning.models.semantic_map.pinhole_camera_inv import PinholeCameraProjection
            projector = PinholeCameraProjection(
                map_size_px=self.params["global_map_size"],
                world_size_px=self.params["world_size_px"],
                world_size_m=self.params["world_size_m"],
                img_x=self.params["img_w"],
                img_y=self.params["img_h"],
                cam_fov=self.params["cam_h_fov"],
                #TODO: Handle correctly
                domain="sim",
                use_depth=False)
            conf_json = load_env_config(md.ENV_ID)
            landmark_names, landmark_indices, landmark_pos = get_landmark_locations_airsim(
                conf_json)
            cam_poses = self.cam_poses_from_states(states)
            cam_pos = cam_poses.position[0]
            cam_rot = cam_poses.orientation[0]
            lm_pos_map_all = []
            lm_pos_img_all = []
            for i, landmark_in_world in enumerate(landmark_pos):
                lm_pos_img, landmark_in_cam, status = projector.world_point_to_image(
                    cam_pos, cam_rot, landmark_in_world)
                lm_pos_map = torch.from_numpy(
                    transformations.pos_m_to_px(
                        landmark_in_world[np.newaxis, :],
                        self.params["global_map_size"],
                        self.params["world_size_m"],
                        self.params["world_size_px"]))
                lm_pos_map_all += [lm_pos_map[0]]
                if lm_pos_img is not None:
                    lm_pos_img_all += [lm_pos_img]

            lm_pos_img_all = [lm_pos_img_all]
            lm_pos_map_all = [lm_pos_map_all]

        else:
            lm_pos_img_all = self.tensor_store.get("lm_pos_fpv_img")
            lm_pos_map_all = self.tensor_store.get("lm_pos_map")

        print("Plotting landmark points")

        for i in range(len(img_all)):
            p = Presenter()
            overlay_fpv = p.overlay_pts_on_image(img_all[i][0],
                                                 lm_pos_img_all[i])
            overlay_map = p.overlay_pts_on_image(img_w_all[i][0],
                                                 lm_pos_map_all[i])
            p.show_image(overlay_fpv, "landmarks_on_fpv_img", scale=8)
            p.show_image(overlay_map, "landmarks_on_map", scale=20)

            if not loop:
                break

    def calc_tensor_statistics(self, prefix, tensor):
        stats = {}
        stats[f"{prefix}_mean"] = torch.mean(tensor).item()
        stats[f"{prefix}_l2"] = torch.norm(tensor).item()
        stats[f"{prefix}_stddev"] = torch.std(tensor).item()
        return stats

    def get_activation_statistics(self, keys):
        stats = {}
        from utils.dict_tools import dict_merge
        for key in keys:
            t = self.tensor_store.get_inputs_batch(key)
            t_stats = self.calc_tensor_statistics(key, t)
            stats = dict_merge(stats, t_stats)
        return stats

    # Forward pass for training (with batch optimizations)
    def sup_loss_on_batch(self, batch, eval, halfway=False):
        self.prof.tick("out")

        action_loss_total = Variable(
            empty_float_tensor([1], self.is_cuda, self.cuda_device))

        if batch is None:
            print("Skipping None Batch")
            return action_loss_total

        images, instructions, instr_lengths, states, action_labels, \
        lm_pos_fpv_img, lm_pos_fpv_features, lm_pos_map, lm_indices, goal_pos_map, \
        lm_mentioned, lm_mentioned_tplt, side_mentioned_tplt, lang_lm_mentioned, \
        metadata, obs_mask, plan_mask, pos_enc = self.unbatch(batch)

        # ----------------------------------------------------------------------------
        self.prof.tick("inputs")

        pred_actions = self(images,
                            states,
                            instructions,
                            instr_lengths,
                            has_obs=obs_mask,
                            plan=plan_mask,
                            pos_enc=pos_enc,
                            halfway=halfway)

        # Debugging landmark locations
        if False:
            self.show_landmark_locations()

        # Don't compute any losses - those will not be used. All we care about are the intermediate activations
        if halfway:
            return None, self.tensor_store

        action_losses, _ = self.action_loss(action_labels,
                                            pred_actions,
                                            batchreduce=False)

        self.prof.tick("call")

        action_losses = self.action_loss.batch_reduce_loss(action_losses)
        action_loss = self.action_loss.reduce_loss(action_losses)

        action_loss_total = action_loss

        self.prof.tick("loss")

        aux_losses, aux_metrics = self.aux_losses.calculate_aux_loss(
            self.tensor_store, reduce_average=True)
        aux_loss = self.aux_losses.combine_losses(aux_losses, self.aux_weights)

        #overlaid = self.get_overlaid_classification_results()
        #Presenter().show_image(overlaid, "classification", scale=2)

        prefix = f"{self.model_name}/{'eval' if eval else 'train'}"
        act_prefix = f"{self.model_name}_activations/{'eval' if eval else 'train'}"

        # Mean, stddev, norm of maps
        act_stats = self.get_activation_statistics(
            ["map_S_W", "map_R_W", "map_G_W"])
        self.writer.add_dict(act_prefix, act_stats, self.get_iter())

        self.writer.add_dict(prefix, get_current_meters(), self.get_iter())
        self.writer.add_dict(prefix, aux_losses, self.get_iter())
        self.writer.add_dict(prefix, aux_metrics, self.get_iter())
        self.writer.add_scalar(prefix + "/action_loss",
                               action_loss_total.data.cpu().item(),
                               self.get_iter())
        # TODO: Log value here
        self.writer.add_scalar(prefix + "/goal_accuracy",
                               self.goal_acc_meter.get(), self.get_iter())

        self.prof.tick("auxiliaries")

        total_loss = action_loss_total + aux_loss

        self.inc_iter()

        self.prof.tick("summaries")
        self.prof.loop()
        self.prof.print_stats(1)

        return total_loss, self.tensor_store

    def get_dataset(self,
                    data=None,
                    envs=None,
                    dataset_names=None,
                    dataset_prefix=None,
                    eval=False):
        # TODO: Maybe use eval here
        #if self.fpv:
        data_sources = []
        # If we're running auxiliary objectives, we need to include the data sources for the auxiliary labels
        #if self.use_aux_class_features or self.use_aux_class_on_map or self.use_aux_grounding_features or self.use_aux_grounding_on_map:
        #if self.use_aux_goal_on_map:
        data_sources.append(aup.PROVIDER_LM_POS_DATA)
        data_sources.append(aup.PROVIDER_GOAL_POS)
        #data_sources.append(aup.PROVIDER_LANDMARKS_MENTIONED)
        data_sources.append(aup.PROVIDER_LANG_TEMPLATE)

        #if self.use_rot_noise or self.use_pos_noise:
        #    data_sources.append(aup.PROVIDER_POSE_NOISE)

        return SegmentDataset(data=data,
                              env_list=envs,
                              dataset_names=dataset_names,
                              dataset_prefix=dataset_prefix,
                              aux_provider_names=data_sources,
                              segment_level=True)
Exemplo n.º 18
0
class ModelTrajectoryTopDown(ModuleWithAuxiliaries):

    def __init__(self, run_name="", model_class=MODEL_RSS,
                 aux_class_features=False, aux_grounding_features=False,
                 aux_class_map=False, aux_grounding_map=False, aux_goal_map=False,
                 aux_lang=False, aux_traj=False, rot_noise=False, pos_noise=False):

        super(ModelTrajectoryTopDown, self).__init__()
        self.model_name = "sm_trajectory" + str(model_class)
        self.model_class = model_class
        print("Init model of type: ", str(model_class))
        self.run_name = run_name
        self.writer = LoggingSummaryWriter(log_dir="runs/" + run_name)

        self.params = get_current_parameters()["Model"]
        self.aux_weights = get_current_parameters()["AuxWeights"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        # Auxiliary Objectives
        self.use_aux_class_features = aux_class_features
        self.use_aux_grounding_features = aux_grounding_features
        self.use_aux_class_on_map = aux_class_map
        self.use_aux_grounding_on_map = aux_grounding_map
        self.use_aux_goal_on_map = aux_goal_map
        self.use_aux_lang = aux_lang
        self.use_aux_traj_on_map = aux_traj
        self.use_aux_reg_map = self.aux_weights["regularize_map"]

        self.use_rot_noise = rot_noise
        self.use_pos_noise = pos_noise


        # Path-pred FPV model definition
        # --------------------------------------------------------------------------------------------------------------

        self.img_to_features_w = FPVToGlobalMap(
            source_map_size=self.params["global_map_size"], world_size_px=self.params["world_size_px"], world_size=self.params["world_size_m"],
            res_channels=self.params["resnet_channels"], map_channels=self.params["feature_channels"],
            img_w=self.params["img_w"], img_h=self.params["img_h"], img_dbg=IMG_DBG)

        self.map_accumulator_w = LeakyIntegratorGlobalMap(source_map_size=self.params["global_map_size"], world_in_map_size=self.params["world_size_px"])

        # Pre-process the accumulated map to do language grounding if necessary - in the world reference frame
        if self.use_aux_grounding_on_map and not self.use_aux_grounding_features:
            self.map_processor_a_w = LangFilterMapProcessor(
                source_map_size=self.params["global_map_size"],
                world_size=self.params["world_size_px"],
                embed_size=self.params["emb_size"],
                in_channels=self.params["feature_channels"],
                out_channels=self.params["relevance_channels"],
                spatial=False, cat_out=True)
        else:
            self.map_processor_a_w = IdentityMapProcessor(source_map_size=self.params["global_map_size"], world_size=self.params["world_size_px"])

        if self.use_aux_goal_on_map:
            self.map_processor_b_r = LangFilterMapProcessor(source_map_size=self.params["local_map_size"],
                                                            world_size=self.params["world_size_px"],
                                                            embed_size=self.params["emb_size"],
                                                            in_channels=self.params["relevance_channels"],
                                                            out_channels=self.params["goal_channels"],
                                                            spatial=True, cat_out=True)
        else:
            self.map_processor_b_r = IdentityMapProcessor(source_map_size=self.params["local_map_size"],
                                                          world_size=self.params["world_size_px"])

        pred_channels = self.params["goal_channels"] + self.params["relevance_channels"]

        # Common
        # --------------------------------------------------------------------------------------------------------------

        # Sentence Embedding
        self.sentence_embedding = SentenceEmbeddingSimple(
            self.params["word_emb_size"], self.params["emb_size"], self.params["emb_layers"])

        self.map_transform_w_to_r = MapTransformerBase(source_map_size=self.params["global_map_size"],
                                                       dest_map_size=self.params["local_map_size"],
                                                       world_size=self.params["world_size_px"])
        self.map_transform_r_to_w = MapTransformerBase(source_map_size=self.params["local_map_size"],
                                                       dest_map_size=self.params["global_map_size"],
                                                       world_size=self.params["world_size_px"])

        # Batch select is used to drop and forget semantic maps at those timesteps that we're not planning in
        self.batch_select = MapBatchSelect()
        # Since we only have path predictions for some timesteps (the ones not dropped above), we use this to fill
        # in the missing pieces by reorienting the past trajectory prediction into the frame of the current timestep
        self.map_batch_fill_missing = MapBatchFillMissing(self.params["local_map_size"], self.params["world_size_px"])

        # Passing true to freeze will freeze these weights regardless of whether they've been explicitly reloaded or not
        enable_weight_saving(self.sentence_embedding, "sentence_embedding", alwaysfreeze=False)

        # Output an action given the global semantic map
        if self.params["map_to_action"] == "downsample2":
            self.map_to_action = EgoMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                other_features_size=self.params["emb_size"])

        elif self.params["map_to_action"] == "cropped":
            self.map_to_action = CroppedMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                other_features_size=self.params["emb_size"]
            )

        # Don't freeze the trajectory to action weights, because it will be pre-trained during path-prediction training
        # and finetuned on all timesteps end-to-end
        enable_weight_saving(self.map_to_action, "map_to_action", alwaysfreeze=False, neverfreeze=True)

        # Auxiliary Objectives
        # --------------------------------------------------------------------------------------------------------------

        # We add all auxiliaries that are necessary. The first argument is the auxiliary name, followed by parameters,
        # followed by a variable number of input names. ModuleWithAuxiliaries will automatically collect these inputs
        # that have been saved with keep_auxiliary_input() during execution
        if aux_class_features:
            self.add_auxiliary(ClassAuxiliary2D("aux_class", None,  self.params["feature_channels"], self.params["num_landmarks"], self.params["dropout"],
                                                "fpv_features", "lm_pos_fpv", "lm_indices"))
        if aux_grounding_features:
            self.add_auxiliary(ClassAuxiliary2D("aux_ground", None, self.params["relevance_channels"], 2, self.params["dropout"],
                                                "fpv_features_g", "lm_pos_fpv", "lm_mentioned"))
        if aux_class_map:
            self.add_auxiliary(ClassAuxiliary2D("aux_class_map", self.params["world_size_px"], self.params["feature_channels"], self.params["num_landmarks"], self.params["dropout"],
                                                "map_s_w_select", "lm_pos_map_select", "lm_indices_select"))
        if aux_grounding_map:
            self.add_auxiliary(ClassAuxiliary2D("aux_grounding_map", self.params["world_size_px"], self.params["relevance_channels"], 2, self.params["dropout"],
                                                "map_a_w_select", "lm_pos_map_select", "lm_mentioned_select"))
        if aux_goal_map:
            self.add_auxiliary(GoalAuxiliary2D("aux_goal_map", self.params["goal_channels"], self.params["world_size_px"],
                                               "map_b_w", "goal_pos_map"))
        # RSS model uses templated data for landmark and side prediction
        if self.use_aux_lang and self.params["templates"]:
            self.add_auxiliary(ClassAuxiliary("aux_lang_lm", self.params["emb_size"], self.params["num_landmarks"], 1,
                                                "sentence_embed", "lm_mentioned_tplt"))
            self.add_auxiliary(ClassAuxiliary("aux_lang_side", self.params["emb_size"], self.params["num_sides"], 1,
                                                "sentence_embed", "side_mentioned_tplt"))
        # CoRL model uses alignment-model groundings
        elif self.use_aux_lang:
            # one output for each landmark, 2 classes per output. This is for finetuning, so use the embedding that is going to be fine-tuned
            self.add_auxiliary(ClassAuxiliary("aux_lang_lm_nl", self.params["emb_size"], 2, self.params["num_landmarks"],
                                                "sentence_embed", "lang_lm_mentioned"))
        if self.use_aux_traj_on_map:
            self.add_auxiliary(PathAuxiliary2D("aux_path", "map_b_r_select", "traj_gt_r_select"))

        if self.use_aux_reg_map:
            self.add_auxiliary(FeatureRegularizationAuxiliary2D("aux_regularize_features", None, "l1",
                                                                "map_s_w_select", "lm_pos_map_select"))

        self.goal_good_criterion = GoalPredictionGoodCriterion(ok_distance=3.2)
        self.goal_acc_meter = MovingAverageMeter(10)

        self.print_auxiliary_info()

        self.action_loss = ActionLoss()

        self.env_id = None
        self.prev_instruction = None
        self.seq_step = 0

    # TODO: Try to hide these in a superclass or something. They take up a lot of space:
    def cuda(self, device=None):
        ModuleWithAuxiliaries.cuda(self, device)
        self.sentence_embedding.cuda(device)
        self.map_accumulator_w.cuda(device)
        self.map_processor_a_w.cuda(device)
        self.map_processor_b_r.cuda(device)
        self.img_to_features_w.cuda(device)
        self.map_to_action.cuda(device)
        self.action_loss.cuda(device)
        self.map_batch_fill_missing.cuda(device)
        self.map_transform_w_to_r.cuda(device)
        self.map_transform_r_to_w.cuda(device)
        self.batch_select.cuda(device)
        self.map_batch_fill_missing.cuda(device)
        return self

    def get_iter(self):
        return int(self.iter.data[0])

    def inc_iter(self):
        self.iter += 1

    def init_weights(self):
        self.img_to_features_w.init_weights()
        self.map_accumulator_w.init_weights()
        self.sentence_embedding.init_weights()
        self.map_to_action.init_weights()
        self.map_processor_a_w.init_weights()
        self.map_processor_b_r.init_weights()

    def reset(self):
        # TODO: This is error prone. Create a class StatefulModule, iterate submodules and reset all stateful modules
        super(ModelTrajectoryTopDown, self).reset()
        self.sentence_embedding.reset()
        self.img_to_features_w.reset()
        self.map_accumulator_w.reset()
        self.map_processor_a_w.reset()
        self.map_processor_b_r.reset()
        self.map_transform_w_to_r.reset()
        self.map_transform_r_to_w.reset()
        self.map_batch_fill_missing.reset()
        self.prev_instruction = None

    def setEnvContext(self, context):
        print("Set env context to: " + str(context))
        self.env_id = context["env_id"]

    def save_viz(self, images_in):
        imsave(get_viz_dir() + "fpv_" + str(self.seq_step) + ".png", images_in)
        features_cam = self.get_inputs_batch("fpv_features")[-1, 0, 0:3]
        save_tensor_as_img(features_cam, "F_c", self.env_id)
        feature_map_torch = self.get_inputs_batch("f_w")[-1, 0, 0:3]
        save_tensor_as_img(feature_map_torch, "F_w", self.env_id)
        coverage_map_torch = self.get_inputs_batch("m_w")[-1, 0, 0:3]
        save_tensor_as_img(coverage_map_torch, "M_w", self.env_id)
        semantic_map_torch = self.get_inputs_batch("map_s_w_select")[-1, 0, 0:3]
        save_tensor_as_img(semantic_map_torch, "S_w", self.env_id)
        relmap_torch = self.get_inputs_batch("map_a_w_select")[-1, 0, 0:3]
        save_tensor_as_img(relmap_torch, "R_w", self.env_id)
        relmap_r_torch = self.get_inputs_batch("map_a_r_select")[-1, 0, 0:3]
        save_tensor_as_img(relmap_r_torch, "R_r", self.env_id)
        goalmap_torch = self.get_inputs_batch("map_b_w_select")[-1, 0, 0:3]
        save_tensor_as_img(goalmap_torch, "G_w", self.env_id)
        goalmap_r_torch = self.get_inputs_batch("map_b_r_select")[-1, 0, 0:3]
        save_tensor_as_img(goalmap_r_torch, "G_r", self.env_id)

        action = self.get_inputs_batch("action")[-1].data.cpu().squeeze().numpy()
        action_fname = get_viz_dir() + "action_" + str(self.seq_step) + ".png"
        Presenter().save_action(action, action_fname, "")

    def get_action(self, state, instruction):
        """
        Given a DroneState (from PomdpInterface) and instruction, produce a numpy 4D action (x, y, theta, pstop)
        :param state: DroneState object with the raw image from the simulator
        :param instruction: Tokenized instruction, given the corpus
        #TODO: Absorb corpus within model
        :return: 4D numpy action (x, y, theta, pstop)
        """
        # TODO: Simplify this
        self.eval()
        images_np_pure = state.image
        state_np = state.state

        #print("Act: " + debug_untokenize_instruction(instruction))

        images_np = standardize_image(images_np_pure)
        image_fpv = Variable(none_padded_seq_to_tensor([images_np]))
        state = Variable(none_padded_seq_to_tensor([state_np]))
        # Add the batch dimension

        first_step = True
        if instruction == self.prev_instruction:
            first_step = False
        self.prev_instruction = instruction

        img_in_t = image_fpv
        img_in_t.volatile = True

        instr_len = [len(instruction)] if instruction is not None else None
        instruction = torch.LongTensor(instruction).unsqueeze(0)
        instruction = cuda_var(instruction, self.is_cuda, self.cuda_device)

        state.volatile = True

        if self.is_cuda:
            if img_in_t is not None:
                img_in_t = img_in_t.cuda(self.cuda_device)
            state = state.cuda(self.cuda_device)

        step_enc = None
        plan_now = None

        self.seq_step += 1

        action = self(img_in_t, state, instruction, instr_len, plan=plan_now, pos_enc=step_enc)

        # Save materials for paper and presentation
        if False:
            self.save_viz(images_np_pure)

        output_action = action.squeeze().data.cpu().numpy()
        stop_prob = output_action[3]
        output_stop = 1 if stop_prob > 0.5 else 0
        output_action[3] = output_stop

        return output_action

    def deterministic_action(self, action_mean, action_std, stop_prob):
        batch_size = action_mean.size(0)
        action = Variable(empty_float_tensor((batch_size, 4), self.is_cuda, self.cuda_device))
        action[:, 0:3] = action_mean[:, 0:3]
        action[:, 3] = stop_prob
        return action

    def sample_action(self, action_mean, action_std, stop_prob):
        action = torch.normal(action_mean, action_std)
        stop = torch.bernoulli(stop_prob)
        return action, stop
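
    # --- Illustrative sketch (not part of the original class) ---
    # sample_action() above draws the (x, y, theta) velocities from a Normal and
    # the stop decision from a Bernoulli. The same sampling, written with
    # torch.distributions and made-up values:
    #
    #   import torch
    #   from torch.distributions import Normal, Bernoulli
    #
    #   action_mean = torch.zeros(1, 3)      # (x, y, theta) velocity means
    #   action_std = torch.ones(1, 3) * 0.1  # corresponding standard deviations
    #   stop_prob = torch.tensor([0.2])      # probability of stopping
    #
    #   velocities = Normal(action_mean, action_std).sample()       # (1, 3)
    #   stop = Bernoulli(stop_prob).sample()                        # 0.0 or 1.0
    #   action = torch.cat([velocities, stop.unsqueeze(1)], dim=1)  # (1, 4)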

    # This is called before beginning an execution sequence
    def start_sequence(self):
        self.seq_step = 0
        self.reset()
        print("RESETTED!")
        return

    # TODO: Move this somewhere and standardize
    def cam_poses_from_states(self, states):
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]

        pos_variance = 0
        rot_variance = 0
        if self.use_pos_noise:
            pos_variance = self.params["noisy_pos_variance"]
        if self.use_rot_noise:
            rot_variance = self.params["noisy_rot_variance"]

        pose = Pose(cam_pos, cam_rot)
        if self.use_pos_noise or self.use_rot_noise:
            pose = get_noisy_poses_torch(pose, pos_variance, rot_variance, cuda=self.is_cuda, cuda_device=self.cuda_device)
        return pose
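
    # --- Illustrative sketch (not part of the original class) ---
    # cam_poses_from_states() assumes the raw drone state vector stores the
    # camera position at indices 9:12 and its orientation quaternion at 12:16.
    # Extracting them from a batch of state vectors looks roughly like this
    # (a plain tuple stands in for the project's Pose class):
    #
    #   import torch
    #
    #   states = torch.randn(5, 16)   # batch of 5 raw drone state vectors
    #   cam_pos = states[:, 9:12]     # (x, y, z) camera position
    #   cam_rot = states[:, 12:16]    # 4D orientation quaternion
    #   pose = (cam_pos, cam_rot)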

    def forward(self, images, states, instructions, instr_lengths, has_obs=None, plan=None, save_maps_only=False, pos_enc=None, noisy_poses=None):
        """
        :param images: BxCxHxW batch of images (observations)
        :param states: BxK batch of drone states
        :param instructions: BxM LongTensor where M is the maximum length of any instruction
        :param instr_lengths: list of len B of integers, indicating length of each instruction
        :param has_obs: list of booleans of length B indicating whether the given element in the sequence has an observation
        :param save_maps_only: If true, will not compute actions (full model), but return the semantic maps that
            were built along the way in response to the images. This is ugly, but allows code reuse
        :return: Bx4 batch of predicted actions (x, y, theta velocities and stop probability)
        """
        cam_poses = self.cam_poses_from_states(states)
        g_poses = None  # [None for pose in cam_poses]
        self.prof.tick("out")

        #print("Trn: " + debug_untokenize_instruction(instructions[0].data[:instr_lengths[0]]))

        # Calculate the instruction embedding
        if instructions is not None:
            # TODO: Take batch of instructions and their lengths, return batch of embeddings. Store the last one as internal state
            sent_embeddings = self.sentence_embedding(instructions, instr_lengths)
            self.keep_inputs("sentence_embed", sent_embeddings)
        else:
            sent_embeddings = self.sentence_embedding.get()

        self.prof.tick("embed")

        # Extract and project features onto the egocentric frame for each image
        features_w, coverages_w = self.img_to_features_w(images, cam_poses, sent_embeddings, self, show="")
        self.prof.tick("img_to_map_frame")
        self.keep_inputs("f_w", features_w)
        self.keep_inputs("m_w", coverages_w)

        # Accumulate the egocentric features in a global map
        maps_w = self.map_accumulator_w(features_w, coverages_w, add_mask=has_obs, show="acc" if IMG_DBG else "")
        map_poses_w = g_poses

        # TODO: Maybe keep maps_w if necessary
        #self.keep_inputs("map_sm_local", maps_m)
        self.prof.tick("map_accumulate")

        # Throw away those timesteps that don't correspond to planning timesteps
        maps_w_select, map_poses_w_select, cam_poses_select, noisy_poses_select, _, sent_embeddings_select, pos_enc = \
            self.batch_select(maps_w, map_poses_w, cam_poses, noisy_poses, None, sent_embeddings, pos_enc, plan)

        # Only process the maps on planning timesteps
        if len(maps_w_select) > 0:
            self.keep_inputs("map_s_w_select", maps_w_select)
            self.prof.tick("batch_select")

            # Process the map via the two map processors
            # Do grounding of objects in the map chosen to do so
            maps_w_select, map_poses_w_select = self.map_processor_a_w(maps_w_select, sent_embeddings_select, map_poses_w_select, show="")
            self.keep_inputs("map_a_w_select", maps_w_select)

            self.prof.tick("map_proc_gnd")

            self.map_transform_w_to_r.set_maps(maps_w_select, map_poses_w_select)
            maps_m_select, map_poses_m_select = self.map_transform_w_to_r.get_maps(cam_poses_select)

            self.keep_inputs("map_a_r_select", maps_w_select)
            self.prof.tick("transform_w_to_r")

            self.keep_inputs("map_a_r_perturbed_select", maps_m_select)

            self.prof.tick("map_perturb")

            # Include positional encoding for path prediction
            if pos_enc is not None:
                sent_embeddings_pp = torch.cat([sent_embeddings_select, pos_enc.unsqueeze(1)], dim=1)
            else:
                sent_embeddings_pp = sent_embeddings_select

            # Process the map via the two map processors (e.g. predict the trajectory that we'll be taking)
            maps_m_select, map_poses_m_select = self.map_processor_b_r(maps_m_select, sent_embeddings_pp, map_poses_m_select)

            self.keep_inputs("map_b_r_select", maps_m_select)

            if True:
                self.map_transform_r_to_w.set_maps(maps_m_select, map_poses_m_select)
                maps_b_w_select, _ = self.map_transform_r_to_w.get_maps(None)
                self.keep_inputs("map_b_w_select", maps_b_w_select)

            self.prof.tick("map_proc_b")

        else:
            maps_m_select = None

        maps_m, map_poses_m = self.map_batch_fill_missing(maps_m_select, cam_poses, plan, show="")
        self.keep_inputs("map_b_r", maps_m)
        self.prof.tick("map_fill_missing")

        # Keep global maps for auxiliary objectives if necessary
        if self.input_required("map_b_w"):
            maps_b, _ = self.map_processor_b_r.get_maps(g_poses)
            self.keep_inputs("map_b_w", maps_b)

        self.prof.tick("keep_global_maps")

        if run_metadata.IS_ROLLOUT:
            pass
            #Presenter().show_image(maps_m.data[0, 0:3], "plan_map_now", torch=True, scale=4, waitkey=1)
            #Presenter().show_image(maps_w.data[0, 0:3], "sm_map_now", torch=True, scale=4, waitkey=1)
        self.prof.tick("viz")

        # Output the final action given the processed map
        action_pred = self.map_to_action(maps_m, sent_embeddings)
        out_action = self.deterministic_action(action_pred[:, 0:3], None, action_pred[:, 3])

        self.keep_inputs("action", out_action)
        self.prof.tick("map_to_action")

        return out_action
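
    # --- Illustrative sketch (not part of the original class) ---
    # The LangFilter map processors used in forward() ground the map in language
    # by predicting 1x1 convolution weights from the sentence embedding and
    # applying them to the feature map. A minimal version of that idea, with
    # made-up channel sizes:
    #
    #   import torch
    #   import torch.nn as nn
    #   import torch.nn.functional as F
    #
    #   emb_size, in_ch, out_ch = 120, 32, 3
    #   weight_predictor = nn.Linear(emb_size, out_ch * in_ch)
    #
    #   sent_embedding = torch.randn(1, emb_size)
    #   feature_map = torch.randn(1, in_ch, 64, 64)
    #
    #   w = weight_predictor(sent_embedding).view(out_ch, in_ch, 1, 1)
    #   grounded_map = F.conv2d(feature_map, w)   # (1, out_ch, 64, 64)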

    # TODO: The below two methods seem to do the same thing
    def maybe_cuda(self, tensor):
        if self.is_cuda:
            return tensor.cuda()
        else:
            return tensor

    def cuda_var(self, tensor):
        return cuda_var(tensor, self.is_cuda, self.cuda_device)

    # Forward pass for training (with batch optimizations)
    def sup_loss_on_batch(self, batch, eval):
        self.prof.tick("out")

        action_loss_total = Variable(empty_float_tensor([1], self.is_cuda, self.cuda_device))

        if batch is None:
            print("Skipping None Batch")
            return action_loss_total

        images = self.maybe_cuda(batch["images"])

        instructions = self.maybe_cuda(batch["instr"])
        instr_lengths = batch["instr_len"]
        states = self.maybe_cuda(batch["states"])
        actions = self.maybe_cuda(batch["actions"])

        # Auxiliary labels
        lm_pos_fpv = batch["lm_pos_fpv"]
        lm_pos_map = batch["lm_pos_map"]
        lm_indices = batch["lm_indices"]
        goal_pos_map = batch["goal_loc"]

        TEMPLATES = True
        if TEMPLATES:
            lm_mentioned_tplt = batch["lm_mentioned_tplt"]
            side_mentioned_tplt = batch["side_mentioned_tplt"]
        else:
            lm_mentioned = batch["lm_mentioned"]
            lang_lm_mentioned = batch["lang_lm_mentioned"]

        # stops = self.maybe_cuda(batch["stops"])
        masks = self.maybe_cuda(batch["masks"])
        # This is the first-timestep metadata
        metadata = batch["md"]

        seq_len = images.size(1)
        batch_size = images.size(0)
        count = 0
        correct_goal_count = 0
        goal_count = 0

        # Loop thru batch
        for b in range(batch_size):
            seg_idx = -1

            self.reset()

            self.prof.tick("out")
            b_seq_len = len_until_nones(metadata[b])

            # TODO: Generalize this
            # Slice the data according to the sequence length
            b_metadata = metadata[b][:b_seq_len]
            b_images = images[b][:b_seq_len]
            b_instructions = instructions[b][:b_seq_len]
            b_instr_len = instr_lengths[b][:b_seq_len]
            b_states = states[b][:b_seq_len]
            b_actions = actions[b][:b_seq_len]
            b_lm_pos_fpv = lm_pos_fpv[b][:b_seq_len]
            b_lm_pos_map = lm_pos_map[b][:b_seq_len]
            b_lm_indices = lm_indices[b][:b_seq_len]
            b_goal_pos = goal_pos_map[b][:b_seq_len]
            if not TEMPLATES:
                b_lang_lm_mentioned = lang_lm_mentioned[b][:b_seq_len]
                b_lm_mentioned = lm_mentioned[b][:b_seq_len]

            b_lm_pos_map = [self.cuda_var(s.long()) if s is not None else None for s in b_lm_pos_map]
            b_lm_pos_fpv = [self.cuda_var((s / RESNET_FACTOR).long()) if s is not None else None for s in b_lm_pos_fpv]
            b_lm_indices = [self.cuda_var(s) if s is not None else None for s in b_lm_indices]
            b_goal_pos = self.cuda_var(b_goal_pos)
            if not TEMPLATES:
                b_lang_lm_mentioned = self.cuda_var(b_lang_lm_mentioned)
                b_lm_mentioned = [self.cuda_var(s) if s is not None else None for s in b_lm_mentioned]

            # TODO: Figure out how to keep these properly. Perhaps as a whole batch is best
            # TODO: Introduce a key-value store (encapsulate instead of inherit)
            self.keep_inputs("lm_pos_fpv", b_lm_pos_fpv)
            self.keep_inputs("lm_pos_map", b_lm_pos_map)
            self.keep_inputs("lm_indices", b_lm_indices)
            self.keep_inputs("goal_pos_map", b_goal_pos)
            if not TEMPLATES:
                self.keep_inputs("lang_lm_mentioned", b_lang_lm_mentioned)
                self.keep_inputs("lm_mentioned", b_lm_mentioned)

            # TODO: Abstract all of these if-elses in a modular way once we know which ones are necessary
            if TEMPLATES:
                b_lm_mentioned_tplt = lm_mentioned_tplt[b][:b_seq_len]
                b_side_mentioned_tplt = side_mentioned_tplt[b][:b_seq_len]
                b_side_mentioned_tplt = self.cuda_var(b_side_mentioned_tplt)
                b_lm_mentioned_tplt = self.cuda_var(b_lm_mentioned_tplt)
                self.keep_inputs("lm_mentioned_tplt", b_lm_mentioned_tplt)
                self.keep_inputs("side_mentioned_tplt", b_side_mentioned_tplt)

                b_lm_mentioned = b_lm_mentioned_tplt


            b_obs_mask = [True for _ in range(b_seq_len)]
            b_plan_mask = [True for _ in range(b_seq_len)]
            b_plan_mask_t_cpu = torch.Tensor(b_plan_mask) == True
            b_plan_mask_t = self.maybe_cuda(b_plan_mask_t_cpu)
            b_pos_enc = None

            # ----------------------------------------------------------------------------
            # Optional Auxiliary Inputs
            # ----------------------------------------------------------------------------
            if self.input_required("lm_pos_map_select"):
                b_lm_pos_map_select = [lm_pos for i,lm_pos in enumerate(b_lm_pos_map) if b_plan_mask[i]]
                self.keep_inputs("lm_pos_map_select", b_lm_pos_map_select)
            if self.input_required("lm_indices_select"):
                b_lm_indices_select = [lm_idx for i,lm_idx in enumerate(b_lm_indices) if b_plan_mask[i]]
                self.keep_inputs("lm_indices_select", b_lm_indices_select)
            if self.input_required("lm_mentioned_select"):
                b_lm_mentioned_select = [lm_m for i,lm_m in enumerate(b_lm_mentioned) if b_plan_mask[i]]
                self.keep_inputs("lm_mentioned_select", b_lm_mentioned_select)

            # ----------------------------------------------------------------------------

            self.prof.tick("inputs")

            actions = self(b_images, b_states, b_instructions, b_instr_len,
                           has_obs=b_obs_mask, plan=b_plan_mask, pos_enc=b_pos_enc)

            action_losses, _ = self.action_loss(b_actions, actions, batchreduce=False)

            self.prof.tick("call")

            action_losses = self.action_loss.batch_reduce_loss(action_losses)
            action_loss = self.action_loss.reduce_loss(action_losses)

            action_loss_total = action_loss_total + action_loss  # Accumulate across batch elements; averaged over count below
            count += b_seq_len

            self.prof.tick("loss")

        action_loss_avg = action_loss_total / (count + 1e-9)

        self.prof.tick("out")

        # Doing this at the end (outside of the segment loop)
        aux_losses = self.calculate_aux_loss(reduce_average=True)
        aux_loss = self.combine_aux_losses(aux_losses, self.aux_weights)

        prefix = self.model_name + ("/eval" if eval else "/train")

        self.writer.add_dict(prefix, get_current_meters(), self.get_iter())
        self.writer.add_dict(prefix, aux_losses, self.get_iter())
        self.writer.add_scalar(prefix + "/action_loss", action_loss_avg.data.cpu()[0], self.get_iter())
        # TODO: Log value here
        self.writer.add_scalar(prefix + "/goal_accuracy", self.goal_acc_meter.get(), self.get_iter())

        self.prof.tick("auxiliaries")

        total_loss = action_loss_avg + aux_loss

        self.inc_iter()

        self.prof.tick("summaries")
        self.prof.loop()
        self.prof.print_stats(1)

        return total_loss

    def get_dataset(self, data=None, envs=None, dataset_name=None, eval=False):
        # TODO: Maybe use eval here
        #if self.fpv:
        data_sources = []
        # If we're running auxiliary objectives, we need to include the data sources for the auxiliary labels
        #if self.use_aux_class_features or self.use_aux_class_on_map or self.use_aux_grounding_features or self.use_aux_grounding_on_map:
        #if self.use_aux_goal_on_map:
        data_sources.append(aup.PROVIDER_LM_POS_DATA)
        data_sources.append(aup.PROVIDER_GOAL_POS)
        #data_sources.append(aup.PROVIDER_LANDMARKS_MENTIONED)
        data_sources.append(aup.PROVIDER_LANG_TEMPLATE)

        #if self.use_rot_noise or self.use_pos_noise:
        #    data_sources.append(aup.PROVIDER_POSE_NOISE)

        return SegmentDataset(data=data, env_list=envs, dataset_name=dataset_name, aux_provider_names=data_sources, segment_level=True)
Exemplo n.º 19
0
class IdentityIntegratorMap(MapTransformerBase):

    def __init__(self, source_map_size, world_size_px, world_size_m):
        super(IdentityIntegratorMap, self).__init__(source_map_size, world_size_px, world_size_m)
        self.map_size = source_map_size
        self.world_size = world_size_px
        self.world_size_m = world_size_m
        self.child_transformer = MapTransformerBase(source_map_size, world_size_px, world_size_m)

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.map_memory = MapTransformerBase(source_map_size, world_size_px, world_size_m)

        self.last_observation = None

        self.dbg_t = None
        self.seq = 0

    def init_weights(self):
        pass

    def reset(self):
        super(IdentityIntegratorMap, self).reset()
        self.map_memory.reset()
        self.child_transformer.reset()
        self.seq = 0
        self.last_observation = None

    def cuda(self, device=None):
        MapTransformerBase.cuda(self, device)
        self.child_transformer.cuda(device)
        self.map_memory.cuda(device)
        return self

    def dbg_write_extra(self, map, pose):
        if DebugWriter().should_write():
            map = map[0:1, 0:3]
            self.seq += 1
            # Initialize a transformer module
            if pose is not None:
                if self.dbg_t is None:
                    self.dbg_t = MapTransformerBase(self.map_size, self.world_size, self.world_size_m).to(map.device)

                # Transform the prediction to the global frame and write out to disk.
                self.dbg_t.set_map(map, pose)
                map_global, _ = self.dbg_t.get_map(None)
            else:
                map_global = map
            DebugWriter().write_img(map_global[0], "gif_overlaid", args={"world_size": self.world_size, "name": "identity_integrator"})

    def forward(self, images, cam_poses, add_mask=None, show=""):
        #show="li"
        self.prof.tick(".")
        batch_size = len(cam_poses)

        assert add_mask is None or add_mask[0], "The first observation in a sequence needs to be used!"

        all_maps_out_r = []

        self.prof.tick("maps_to_global")

        # For each timestep, take the latest map that was available, transformed into this timestep
        # Apply at most one transformation to any map to avoid cascading errors!
        for i in range(batch_size):

            if add_mask is None or add_mask[i]:
                this_obs = (images[i:i+1], cam_poses[i:i+1])
                self.last_observation = this_obs
            else:
                last_obs = self.last_observation
                assert last_obs is not None, "The first observation in a sequence needs to be used!"

                self.child_transformer.set_map(last_obs[0], last_obs[1])
                this_obs = self.child_transformer.get_map(cam_poses[i:i+1])

            all_maps_out_r.append(this_obs[0])

            if show != "":
                Presenter().show_image(this_obs[0].data[0, 0:3], show, torch=True, scale=8, waitkey=50)

        self.prof.tick("integrate")

        # Concatenate the per-timestep maps (all already in their local frames)
        all_maps_r = torch.cat(all_maps_out_r, dim=0)

        # Write gifs for debugging
        self.dbg_write_extra(all_maps_r, None)

        self.set_maps(all_maps_r, cam_poses)

        self.prof.tick("maps_to_local")
        self.prof.loop()
        self.prof.print_stats(10)

        return all_maps_r, cam_poses
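
    # --- Illustrative sketch (not part of the original class) ---
    # The integration rule in forward() simply carries the most recent
    # observation forward (re-oriented into the current frame by the child
    # transformer). Ignoring the geometric transform, the carry-forward logic is:
    #
    #   import torch
    #
    #   obs = [torch.randn(1, 3, 32, 32) for _ in range(4)]  # per-timestep maps
    #   add_mask = [True, False, True, False]                # timesteps with observations
    #
    #   last, out = None, []
    #   for o, use in zip(obs, add_mask):
    #       last = o if use else last   # keep the newest available observation
    #       out.append(last)
    #   maps = torch.cat(out, dim=0)    # (4, 3, 32, 32)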

    def forward_deprecated(self, images, cam_poses, add_mask=None, show=""):
        #show="li"
        self.prof.tick(".")
        batch_size = len(cam_poses)

        assert add_mask is None or add_mask[0], "The first observation in a sequence needs to be used!"

        # Step 1: All local maps to global:
        #  TODO: Allow inputting global maps when new projector is ready
        self.child_transformer.set_maps(images, cam_poses)
        observations_g, _ = self.child_transformer.get_maps(None)

        all_maps_out_g = []

        self.prof.tick("maps_to_global")

        # TODO: Draw past trajectory on an extra channel of the semantic map
        # Step 2: Integrate serially in the global frame
        for i in range(batch_size):

            # If we don't have a map yet, initialize the map to this observation
            if self.map_memory.latest_maps is None:
                self.map_memory.set_map(observations_g[i:i+1], None)

            # Allow masking of observations
            if add_mask is None or add_mask[i]:
                # Use the map from this frame
                map_g = observations_g[i:i+1]
                self.map_memory.set_map(map_g, None)
            else:
                # Use the latest available map oriented in global frame
                map_g, _ = self.map_memory.get_map(None)

            if show != "":
                Presenter().show_image(map_g.data[0, 0:3], show, torch=True, scale=8, waitkey=50)

            all_maps_out_g.append(map_g)

        self.prof.tick("integrate")

        # Step 3: Convert all maps to local frame
        all_maps_g = torch.cat(all_maps_out_g, dim=0)

        # Write gifs for debugging
        self.dbg_write_extra(all_maps_g, None)

        self.child_transformer.set_maps(all_maps_g, None)
        maps_r, _ = self.child_transformer.get_maps(cam_poses)
        self.set_maps(maps_r, cam_poses)

        self.prof.tick("maps_to_local")
        self.prof.loop()
        self.prof.print_stats(10)

        return maps_r, cam_poses
Exemplo n.º 20
0
class FPVToEgoMap(MapTransformerBase):
    def __init__(self,
                 source_map_size, world_size_px,
                 world_size, img_w, img_h,
                 embed_size, map_channels, gnd_channels, res_channels=32,
                 lang_filter=False, img_dbg=False):
        super(FPVToEgoMap, self).__init__(source_map_size, world_size_px)

        self.image_debug = img_dbg
        self.use_lang_filter = lang_filter

        # Process images using a resnet to get a feature map
        if self.image_debug:
            self.img_to_features = nn.MaxPool2d(8)
        else:
            # Provide enough padding so that the map is scaled down by powers of 2.
            self.img_to_features = ImgToFeatures(res_channels, map_channels)

        if self.use_lang_filter:
            self.lang_filter = MapLangSemanticFilter(embed_size, map_channels, gnd_channels)

        # Project feature maps to the global frame
        self.map_projection = PinholeCameraProjectionModule(
            source_map_size, world_size_px, world_size, source_map_size / 2, img_w, img_h)

        self.grid_sampler = GridSampler()

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)

        self.actual_images = None

    def cuda(self, device=None):
        MapTransformerBase.cuda(self, device)
        self.map_projection.cuda(device)
        self.grid_sampler.cuda(device)
        self.img_to_features.cuda(device)
        if self.use_lang_filter:
            self.lang_filter.cuda(device)

    def init_weights(self):
        if not self.image_debug:
            self.img_to_features.init_weights()

    def reset(self):
        self.actual_images = None
        super(FPVToEgoMap, self).reset()

    def forward_fpv_features(self, images, sentence_embeds, parent=None):
        """
        Compute the first-person image features given the first-person images
        If grounding loss is enabled, will also return sentence_embedding conditioned image features
        :param images: images to compute features on
        :param sentence_embeds: sentence embeddings for each image
        :param parent:
        :return: features_fpv_vis - the visual features extracted using the ResNet
                 features_fpv_gnd - the grounded visual features obtained after applying a 1x1 language-conditioned conv
        """
        # Extract image features. If they've been precomputed ahead of time, just grab them by the provided index
        features_fpv_vis = self.img_to_features(images)

        if parent is not None:
            parent.keep_inputs("fpv_features", features_fpv_vis)
        self.prof.tick("feat")

        # If required, pre-process image features by grounding them in language
        if self.use_lang_filter:
            self.lang_filter.precompute_conv_weights(sentence_embeds)
            features_gnd = self.lang_filter(features_fpv_vis)
            if parent is not None:
                parent.keep_inputs("fpv_features_g", features_gnd)
            self.prof.tick("gnd")
            return features_fpv_vis, features_gnd

        return features_fpv_vis, None

    def forward(self, images, poses, sentence_embeds, parent=None, show=""):

        self.prof.tick("out")

        features_fpv_vis_only, features_fpv_gnd_only = self.forward_fpv_features(images, sentence_embeds, parent)

        # If we have grounding features, the overall features are a concatenation of grounded and non-grounded features
        if features_fpv_gnd_only is not None:
            features_fpv_all = torch.cat([features_fpv_gnd_only, features_fpv_vis_only], dim=1)
        else:
            features_fpv_all = features_fpv_vis_only

        # Project first-person view features on to the map in egocentric frame
        grid_maps = self.map_projection(poses)
        self.prof.tick("proj_map")
        features_r = self.grid_sampler(features_fpv_all, grid_maps)

        # Obtain an ego-centric map mask of where we have new information
        ones_size = list(features_fpv_all.size())
        ones_size[1] = 1
        tmp_ones = empty_float_tensor(ones_size, self.is_cuda, self.cuda_device).fill_(1.0)
        new_coverages = self.grid_sampler(tmp_ones, grid_maps)

        # Make sure that new_coverage is a 0/1 mask (grid_sampler applies bilinear interpolation)
        new_coverages = new_coverages - torch.min(new_coverages)
        new_coverages = new_coverages / torch.max(new_coverages)

        self.prof.tick("gsample")

        if show != "":
            Presenter().show_image(images.data[0, 0:3], show + "_img", torch=True, scale=1, waitkey=1)
            Presenter().show_image(features_r.data[0, 0:3], show, torch=True, scale=6, waitkey=1)
            Presenter().show_image(new_coverages.data[0], show + "_covg", torch=True, scale=6, waitkey=1)

        self.prof.loop()
        self.prof.print_stats(10)

        return features_r, new_coverages
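
# --- Illustrative sketch (not from the original repository; sizes and names below are made up) ---
# The coverage mask above is obtained by pushing a tensor of ones through the same
# camera-projection grid as the features and then min-max normalizing the result,
# since bilinear interpolation and out-of-view cells yield fractional values.
# A self-contained version of that idea with a random stand-in projection grid:

import torch
import torch.nn.functional as F

features = torch.randn(1, 8, 24, 32)         # fake first-person feature map
ones = torch.ones(1, 1, 24, 32)              # source of the coverage mask
grid = torch.rand(1, 64, 64, 2) * 2.4 - 1.2  # stand-in grid; some samples fall outside [-1, 1]

features_map = F.grid_sample(features, grid)  # projected features (zeros outside the view)
coverage = F.grid_sample(ones, grid)          # fractional coverage at the view border

# Squash the interpolated values back into a [0, 1] mask
coverage = coverage - torch.min(coverage)
coverage = coverage / (torch.max(coverage) + 1e-9)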
Exemplo n.º 21
0
    def train_epoch(self,
                    env_list_common=None,
                    env_list_sim=None,
                    data_list_real=None,
                    data_list_sim=None,
                    eval=False):

        if eval:
            self.model_real.eval()
            self.model_sim.eval()
            self.model_critic.eval()
            inference_type = "eval"
            epoch_num = self.train_epoch_num
            self.test_epoch_num += 1
        else:
            self.model_real.train()
            self.model_sim.train()
            self.model_critic.train()
            inference_type = "train"
            epoch_num = self.train_epoch_num
            self.train_epoch_num += 1

        dataset_real_common = self.model_real.get_dataset(
            data=data_list_real,
            envs=env_list_common,
            domain="real",
            dataset_names=self.real_dataset_names,
            dataset_prefix="supervised",
            eval=eval)
        dataset_sim_common = self.model_real.get_dataset(
            data=data_list_real,
            envs=env_list_common,
            domain="sim",
            dataset_names=self.sim_dataset_names,
            dataset_prefix="supervised",
            eval=eval)
        dataset_real_halfway = self.model_real.get_dataset(
            data=data_list_real,
            envs=env_list_common,
            domain="real",
            dataset_names=self.real_dataset_names,
            dataset_prefix="supervised",
            eval=eval,
            halfway_only=True)
        dataset_sim_halfway = self.model_real.get_dataset(
            data=data_list_real,
            envs=env_list_common,
            domain="sim",
            dataset_names=self.sim_dataset_names,
            dataset_prefix="supervised",
            eval=eval,
            halfway_only=True)
        dataset_sim = self.model_sim.get_dataset(
            data=data_list_sim,
            envs=env_list_sim,
            domain="sim",
            dataset_names=self.sim_dataset_names,
            dataset_prefix="supervised",
            eval=eval)

        print("Beginning supervised epoch:")
        print("   Sim dataset names: ", self.sim_dataset_names)
        print("   Dataset sizes: ")
        print("   dataset_real_common  ", len(dataset_real_common))
        print("   dataset_sim_common  ", len(dataset_sim_common))
        print("   dataset_real_halfway  ", len(dataset_real_halfway))
        print("   dataset_sim_halfway  ", len(dataset_sim_halfway))
        print("   dataset_sim  ", len(dataset_sim))
        print("   env_list_sim_len ", len(env_list_sim))
        print("   env_list_common_len ", len(env_list_common))
        if len(dataset_sim) == 0 or len(dataset_sim_common) == 0:
            print("Missing data! Waiting for RL to generate it?")
            return 0

        dual_model_loader = DualDataloader(
            dataset_a=dataset_real_common,
            dataset_b=dataset_sim_common,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.model_common_loaders,
            pin_memory=False,
            timeout=0,
            drop_last=False,
            joint_length="max")

        sim_loader = DataLoader(dataset=dataset_sim,
                                collate_fn=dataset_sim.collate_fn,
                                batch_size=self.batch_size,
                                shuffle=True,
                                num_workers=self.model_sim_loaders,
                                pin_memory=False,
                                timeout=0,
                                drop_last=False)
        sim_iterator = iter(sim_loader)

        dual_critic_loader = DualDataloader(dataset_a=dataset_real_halfway,
                                            dataset_b=dataset_sim_halfway,
                                            batch_size=self.batch_size,
                                            shuffle=True,
                                            num_workers=self.critic_loaders,
                                            pin_memory=False,
                                            timeout=0,
                                            drop_last=False,
                                            joint_length="infinite")
        dual_critic_iterator = iter(dual_critic_loader)

        wloss_before_updates_writer = LoggingSummaryWriter(
            log_dir=
            f"{get_logging_dir()}/runs/{self.run_name}/discriminator_before_updates"
        )
        wloss_after_updates_writer = LoggingSummaryWriter(
            log_dir=
            f"{get_logging_dir()}/runs/{self.run_name}/discriminator_after_updates"
        )

        samples_real = len(dataset_real_common)
        samples_common = len(dataset_sim_common)
        samples_sim = len(dataset_sim)
        if samples_real == 0 or samples_sim == 0 or samples_common == 0:
            print(
                f"DATASET HAS NO DATA: REAL: {samples_real > 0}, SIM: {samples_sim > 0}, COMMON: {samples_common}"
            )
            return -1.0

        num_batches = len(dual_model_loader)

        epoch_loss = 0
        count = 0
        critic_elapsed_iterations = 0

        prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)

        prof.tick("out")

        # Alternate training critic and model
        for real_batch, sim_batch in dual_model_loader:
            if real_batch is None or sim_batch is None:
                print("none")
                continue

            prof.tick("load_model_data")
            # Train the model for model_steps in a row, then train the critic, and repeat
            critic_batch_num = 0

            if count % self.model_steps == 0 and not eval and not self.disable_wloss:
                #print("\nTraining critic\n")
                # Train the critic for self.critic_steps steps
                for cstep in range(self.critic_steps):
                    # Each batch is actually a single rollout (we batch the rollout data across the sequence)
                    # To collect a batch of rollouts, we need to keep iterating
                    real_store = KeyTensorStore()
                    sim_store = KeyTensorStore()
                    for b in range(self.critic_batch_size):
                        # Get the next non-None batch
                        real_c_batch, sim_c_batch = None, None
                        while real_c_batch is None or sim_c_batch is None:
                            real_c_batch, sim_c_batch = next(
                                dual_critic_iterator)
                        prof.tick("critic_load_data")
                        # When training the critic, we don't backprop into the model, so we don't need gradients here
                        with torch.no_grad():
                            real_loss, real_store_b = self.model_real.sup_loss_on_batch(
                                real_c_batch, eval=eval, halfway=True)
                            sim_loss, sim_store_b = self.model_sim.sup_loss_on_batch(
                                sim_c_batch, eval=eval, halfway=True)
                        prof.tick("critic_features")
                        real_store.append(real_store_b)
                        sim_store.append(sim_store_b)
                        prof.tick("critic_store_append")

                    # Forward the critic
                    # The real_store and sim_store should really be a batch of multiple rollouts
                    wdist_loss_a, critic_store = self.model_critic.calc_domain_loss(
                        real_store, sim_store)

                    prof.tick("critic_domain_loss")

                    # Store the first and last critic loss
                    if cstep == 0:
                        wdist_loss_before_updates = wdist_loss_a.detach().cpu()
                    if cstep == self.critic_steps - 1:
                        wdist_loss_after_updates = wdist_loss_a.detach().cpu()

                    if self.model_oracle_critic:
                        wdist_loss_oracle, oracle_store = self.model_oracle_critic.calc_domain_loss(
                            real_store, sim_store)
                        wdist_loss_a += wdist_loss_oracle

                    # Update the critic
                    critic_batch_num += 1
                    self.optim_critic.zero_grad()
                    # The Wasserstein distance is a supremum over Lipschitz-constrained critics, so the critic is updated to maximize its estimate
                    (-wdist_loss_a).backward()
                    self.optim_critic.step()
                    #sys.stdout.write(f"\r    Critic batch: {critic_batch_num}/{critic_steps} d_loss: {wdist_loss_a.data.item()}")
                    #sys.stdout.flush()
                    prof.tick("critic_backward")

                # Write the Wasserstein loss before and after the critic updates
                prefix = "pvn_critic" + ("/eval" if eval else "/train")
                wloss_before_updates_writer.add_scalar(
                    f"{prefix}/w_score_before_updates",
                    wdist_loss_before_updates.item(),
                    self.model_critic.get_iter())
                wloss_after_updates_writer.add_scalar(
                    f"{prefix}/w_score_before_updates",
                    wdist_loss_after_updates.item(),
                    self.model_critic.get_iter())

                critic_elapsed_iterations += 1

                # Clean up GPU memory
                del wdist_loss_a
                del critic_store
                del real_store
                del sim_store
                del real_store_b
                del sim_store_b
                prof.tick("del")

            # Forward the model on the bi-domain data
            disable_losses = [
                "visitation_dist"
            ] if self.params.get("disable_real_loss") else []
            real_loss, real_store = self.model_real.sup_loss_on_batch(
                real_batch,
                eval,
                halfway=False,
                grad_noise=False,
                disable_losses=disable_losses)
            sim_loss, sim_store = self.model_sim.sup_loss_on_batch(
                sim_batch, eval, halfway=False)
            prof.tick("model_forward")

            # Forward the model K times on simulation only data
            for b in range(self.sim_steps_per_common_step):
                # Get the next non-None batch
                sim_batch = None
                while sim_batch is None:
                    try:
                        sim_batch = next(sim_iterator)
                    except StopIteration as e:
                        sim_iterator = iter(sim_loader)
                        print("retry")
                        continue
                prof.tick("load_model_data")
                sim_loss_b, _ = self.model_sim.sup_loss_on_batch(sim_batch,
                                                                 eval,
                                                                 halfway=False)
                sim_loss = (sim_loss + sim_loss_b) if sim_loss else sim_loss_b
                #print(f"  Model forward common sim, loss: {sim_loss_b.detach().cpu().item()}")
                prof.tick("model_forward")

            prof.tick("model_forward")

            # TODO: Reconsider this
            sim_loss = sim_loss / max(self.model_batch_size,
                                      self.sim_steps_per_common_step)
            real_loss = real_loss / max(self.model_batch_size,
                                        self.sim_steps_per_common_step)
            total_loss = real_loss + sim_loss

            if not self.disable_wloss:
                # Forward the critic
                wdist_loss_b, critic_store = self.model_critic.calc_domain_loss(
                    real_store, sim_store)
                # Increase the iteration on the oracle without running it so that the Tensorboard plots align
                if self.model_oracle_critic:
                    self.model_oracle_critic.iter += 1

                prof.tick("model_domain_loss")
                # Minimize the average real/sim losses together with the critic's estimated domain loss
                total_loss = total_loss + wdist_loss_b

            # Grad step
            if not eval and total_loss.requires_grad:
                self.optim_models.zero_grad()
                total_loss.backward()
                self.optim_models.step()
                prof.tick("model_backward")

            print(
                f"Batch: {count}/{num_batches} r_loss: {real_loss.data.item() if real_loss else None} s_loss: {sim_loss.data.item()}"
            )

            # Get losses as floats
            epoch_loss += total_loss.data.item()
            count += 1

            self.train_segment += 0 if eval else 1
            self.test_segment += 1 if eval else 0

            prof.loop()
            prof.print_stats(self.model_steps)

            if self.iterations_per_epoch and count > self.iterations_per_epoch:
                break

        print("")
        epoch_loss /= (count + 1e-15)

        return epoch_loss
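The simulation-only inner loop above repeatedly pulls batches from sim_iterator, re-creating the iterator whenever the underlying DataLoader is exhausted and skipping None batches. A minimal, self-contained sketch of that cycling pattern (the toy TensorDataset and batch size below are illustrative assumptions, not part of the trainer above):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the simulation data loader used by the trainer above.
sim_loader = DataLoader(TensorDataset(torch.arange(10).float()), batch_size=4)

def cycle_batches(loader):
    # Yield batches forever, restarting the iterator on StopIteration and
    # skipping batches that the collate function returned as None.
    iterator = iter(loader)
    while True:
        try:
            batch = next(iterator)
        except StopIteration:
            iterator = iter(loader)
            continue
        if batch is not None:
            yield batch

sim_batches = cycle_batches(sim_loader)
first_batch = next(sim_batches)  # a list holding one tensor of up to 4 elements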
Example No. 22
class ModelGSFPV(ModuleWithAuxiliaries):
    def __init__(self,
                 run_name="",
                 aux_class_features=False,
                 aux_grounding_features=False,
                 aux_lang=False,
                 recurrence=False):

        super(ModelGSFPV, self).__init__()
        self.model_name = "gs_fpv" + "_mem" if recurrence else ""
        self.run_name = run_name
        self.writer = LoggingSummaryWriter(log_dir="runs/" + run_name)

        self.params = get_current_parameters()["Model"]
        self.aux_weights = get_current_parameters()["AuxWeights"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        # Auxiliary Objectives
        self.use_aux_class_features = aux_class_features
        self.use_aux_grounding_features = aux_grounding_features
        self.use_aux_lang = aux_lang
        self.use_recurrence = recurrence

        self.img_to_features_w = FPVToFPVMap(self.params["img_w"],
                                             self.params["img_h"],
                                             self.params["resnet_channels"],
                                             self.params["feature_channels"])

        self.lang_filter_gnd = MapLangSemanticFilter(
            self.params["emb_size"], self.params["feature_channels"],
            self.params["relevance_channels"])

        self.lang_filter_goal = MapLangSpatialFilter(
            self.params["emb_size"], self.params["relevance_channels"],
            self.params["goal_channels"])

        self.map_downsample = DownsampleResidual(
            self.params["map_to_act_channels"], 2)

        self.recurrence = RecurrentEmbedding(
            self.params["gs_fpv_feature_map_size"],
            self.params["gs_fpv_recurrence_size"])

        # Sentence Embedding
        self.sentence_embedding = SentenceEmbeddingSimple(
            self.params["word_emb_size"], self.params["emb_size"],
            self.params["emb_layers"])

        in_features_size = self.params[
            "gs_fpv_feature_map_size"] + self.params["emb_size"]
        if self.use_recurrence:
            in_features_size += self.params["gs_fpv_recurrence_size"]

        self.features_to_action = DenseMlpBlock2(in_features_size,
                                                 self.params["mlp_hidden"], 4)

        # Auxiliary Objectives
        # --------------------------------------------------------------------------------------------------------------

        self.add_auxiliary(
            ClassAuxiliary2D("aux_class", None,
                             self.params["feature_channels"],
                             self.params["num_landmarks"], "fpv_features",
                             "lm_pos_fpv", "lm_indices"))
        self.add_auxiliary(
            ClassAuxiliary2D("aux_ground", None,
                             self.params["relevance_channels"], 2,
                             "fpv_features_g", "lm_pos_fpv", "lm_mentioned"))
        if self.params["templates"]:
            self.add_auxiliary(
                ClassAuxiliary("aux_lang_lm", self.params["emb_size"],
                               self.params["num_landmarks"], 1,
                               "sentence_embed", "lm_mentioned_tplt"))
            self.add_auxiliary(
                ClassAuxiliary("aux_lang_side", self.params["emb_size"],
                               self.params["num_sides"], 1, "sentence_embed",
                               "side_mentioned_tplt"))
        else:
            self.add_auxiliary(
                ClassAuxiliary("aux_lang_lm_nl", self.params["emb_size"], 2,
                               self.params["num_landmarks"], "sentence_embed",
                               "lang_lm_mentioned"))

        self.action_loss = ActionLoss()

        self.env_id = None
        self.prev_instruction = None
        self.seq_step = 0

    # TODO: Try to hide these in a superclass or something. They take up a lot of space:
    def cuda(self, device=None):
        ModuleWithAuxiliaries.cuda(self, device)
        self.sentence_embedding.cuda(device)
        self.img_to_features_w.cuda(device)
        self.lang_filter_gnd.cuda(device)
        self.lang_filter_goal.cuda(device)
        self.action_loss.cuda(device)
        self.recurrence.cuda(device)
        return self

    def get_iter(self):
        return int(self.iter.data[0])

    def inc_iter(self):
        self.iter += 1

    def init_weights(self):
        self.img_to_features_w.init_weights()
        self.lang_filter_gnd.init_weights()
        self.lang_filter_goal.init_weights()
        self.sentence_embedding.init_weights()

    def reset(self):
        # TODO: This is error prone. Create a class StatefulModule, iterate submodules and reset all stateful modules
        super(ModelGSFPV, self).reset()
        self.sentence_embedding.reset()
        self.img_to_features_w.reset()
        self.recurrence.reset()
        self.prev_instruction = None
        print("GS_FPV_MEM_RESET")

    def setEnvContext(self, context):
        print("Set env context to: " + str(context))
        self.env_id = context["env_id"]

    def start_segment_rollout(self, *args):
        self.reset()

    def get_action(self, state, instruction):
        """
        Given a DroneState (from PomdpInterface) and instruction, produce a numpy 4D action (x, y, theta, pstop)
        :param state: DroneState object with the raw image from the simulator
        :param instruction: Tokenized instruction given the corpus
        #TODO: Absorb corpus within model
        :return:
        """
        # TODO: Simplify this
        self.eval()
        images_np_pure = state.image
        state_np = state.state

        #print("Act: " + debug_untokenize_instruction(instruction))

        images_np = standardize_image(images_np_pure)
        image_fpv = Variable(none_padded_seq_to_tensor([images_np]))
        state = Variable(none_padded_seq_to_tensor([state_np]))
        self.prev_instruction = instruction

        img_in_t = image_fpv
        img_in_t.volatile = True

        instr_len = [len(instruction)] if instruction is not None else None
        instruction = torch.LongTensor(instruction).unsqueeze(0)
        instruction = cuda_var(instruction, self.is_cuda, self.cuda_device)

        state.volatile = True

        if self.is_cuda:
            img_in_t = img_in_t.cuda(self.cuda_device)
            state = state.cuda(self.cuda_device)

        self.seq_step += 1

        action = self(img_in_t, state, instruction, instr_len)

        output_action = action.squeeze().data.cpu().numpy()
        print("action: ", output_action)

        stop_prob = output_action[3]
        output_stop = 1 if stop_prob > self.params["stop_threshold"] else 0
        output_action[3] = output_stop

        return output_action

    def deterministic_action(self, action_mean, action_std, stop_prob):
        batch_size = action_mean.size(0)
        action = Variable(
            empty_float_tensor((batch_size, 4), self.is_cuda,
                               self.cuda_device))
        action[:, 0:3] = action_mean[:, 0:3]
        action[:, 3] = stop_prob
        return action

    def sample_action(self, action_mean, action_std, stop_prob):
        action = torch.normal(action_mean, action_std)
        stop = torch.bernoulli(stop_prob)
        return action, stop

    # This is called before beginning an execution sequence
    def start_sequence(self):
        self.seq_step = 0
        self.reset()
        print("RESETTED!")
        return

    # TODO: Move this somewhere and standardize
    def cam_poses_from_states(self, states):
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]
        pose = Pose(cam_pos, cam_rot)
        return pose

    def forward(self, images, states, instructions, instr_lengths):
        """
        :param images: BxCxHxW batch of images (observations)
        :param states: BxK batch of drone states
        :param instructions: BxM LongTensor where M is the maximum length of any instruction
        :param instr_lengths: list of len B of integers, indicating length of each instruction
        :return:
        """
        cam_poses = self.cam_poses_from_states(states)
        self.prof.tick("out")

        #print("Trn: " + debug_untokenize_instruction(instructions[0].data[:instr_lengths[0]]))

        # Calculate the instruction embedding
        if instructions is not None:
            # TODO: Take batch of instructions and their lengths, return batch of embeddings. Store the last one as internal state
            sent_embeddings = self.sentence_embedding(instructions,
                                                      instr_lengths)
            self.keep_inputs("sentence_embed", sent_embeddings)
        else:
            sent_embeddings = self.sentence_embedding.get()

        self.prof.tick("embed")

        seq_size = len(images)

        # Extract and project features onto the egocentric frame for each image
        fpv_features = self.img_to_features_w(images,
                                              cam_poses,
                                              sent_embeddings,
                                              self,
                                              show="")

        self.keep_inputs("fpv_features", fpv_features)
        self.prof.tick("img_to_map_frame")

        self.lang_filter_gnd.precompute_conv_weights(sent_embeddings)
        self.lang_filter_goal.precompute_conv_weights(sent_embeddings)

        gnd_features = self.lang_filter_gnd(fpv_features)
        goal_features = self.lang_filter_goal(gnd_features)

        self.keep_inputs("fpv_features_g", gnd_features)
        visual_features = torch.cat([gnd_features, goal_features], dim=1)

        lstm_in_features = visual_features.view([seq_size, 1, -1])

        catlist = [lstm_in_features.view([seq_size, -1]), sent_embeddings]

        if self.use_recurrence:
            memory_features = self.recurrence(lstm_in_features)
            catlist.append(memory_features[:, 0, :])

        action_features = torch.cat(catlist, dim=1)

        # Output the final action given the processed map
        action_pred = self.features_to_action(action_features)
        action_pred[:, 3] = torch.sigmoid(action_pred[:, 3])
        out_action = self.deterministic_action(action_pred[:, 0:3], None,
                                               action_pred[:, 3])
        self.prof.tick("map_to_action")

        return out_action

    def maybe_cuda(self, tensor):
        if self.is_cuda:
            return tensor.cuda()
        else:
            return tensor

    def cuda_var(self, tensor):
        return cuda_var(tensor, self.is_cuda, self.cuda_device)

    # Forward pass for training (with batch optimizations)
    def sup_loss_on_batch(self, batch, eval):
        self.prof.tick("out")

        action_loss_total = Variable(
            empty_float_tensor([1], self.is_cuda, self.cuda_device))

        if batch is None:
            print("Skipping None Batch")
            return action_loss_total

        images = self.maybe_cuda(batch["images"])

        instructions = self.maybe_cuda(batch["instr"])
        instr_lengths = batch["instr_len"]
        states = self.maybe_cuda(batch["states"])
        actions = self.maybe_cuda(batch["actions"])

        # Auxiliary labels
        lm_pos_fpv = batch["lm_pos_fpv"]
        lm_indices = batch["lm_indices"]
        lm_mentioned = batch["lm_mentioned"]
        lang_lm_mentioned = batch["lang_lm_mentioned"]

        templates = get_current_parameters()["Environment"]["Templates"]
        if templates:
            lm_mentioned_tplt = batch["lm_mentioned_tplt"]
            side_mentioned_tplt = batch["side_mentioned_tplt"]

        # stops = self.maybe_cuda(batch["stops"])
        masks = self.maybe_cuda(batch["masks"])
        metadata = batch["md"]

        seq_len = images.size(1)
        batch_size = images.size(0)
        count = 0
        correct_goal_count = 0
        goal_count = 0

        # Loop thru batch
        for b in range(batch_size):
            seg_idx = -1

            self.reset()

            self.prof.tick("out")
            b_seq_len = len_until_nones(metadata[b])

            # TODO: Generalize this
            # Slice the data according to the sequence length
            b_metadata = metadata[b][:b_seq_len]
            b_images = images[b][:b_seq_len]
            b_instructions = instructions[b][:b_seq_len]
            b_instr_len = instr_lengths[b][:b_seq_len]
            b_states = states[b][:b_seq_len]
            b_actions = actions[b][:b_seq_len]
            b_lm_pos_fpv = lm_pos_fpv[b][:b_seq_len]
            b_lm_indices = lm_indices[b][:b_seq_len]
            b_lm_mentioned = lm_mentioned[b][:b_seq_len]

            b_lm_pos_fpv = [
                self.cuda_var(
                    (s / RESNET_FACTOR).long()) if s is not None else None
                for s in b_lm_pos_fpv
            ]
            b_lm_indices = [
                self.cuda_var(s) if s is not None else None
                for s in b_lm_indices
            ]
            b_lm_mentioned = [
                self.cuda_var(s) if s is not None else None
                for s in b_lm_mentioned
            ]

            # TODO: Figure out how to keep these properly. Perhaps as a whole batch is best
            # TODO: Introduce a key-value store (encapsulate instead of inherit)
            self.keep_inputs("lm_pos_fpv", b_lm_pos_fpv)
            self.keep_inputs("lm_indices", b_lm_indices)
            self.keep_inputs("lm_mentioned", b_lm_mentioned)

            # TODO: Abstract all of these if-elses in a modular way once we know which ones are necessary
            if templates:
                b_lm_mentioned_tplt = lm_mentioned_tplt[b][:b_seq_len]
                b_side_mentioned_tplt = side_mentioned_tplt[b][:b_seq_len]
                b_side_mentioned_tplt = self.cuda_var(b_side_mentioned_tplt)
                b_lm_mentioned_tplt = self.cuda_var(b_lm_mentioned_tplt)
                self.keep_inputs("lm_mentioned_tplt", b_lm_mentioned_tplt)
                self.keep_inputs("side_mentioned_tplt", b_side_mentioned_tplt)
            else:
                b_lang_lm_mentioned = self.cuda_var(
                    lang_lm_mentioned[b][:b_seq_len])
                self.keep_inputs("lang_lm_mentioned", b_lang_lm_mentioned)

            # ----------------------------------------------------------------------------

            self.prof.tick("inputs")

            actions = self(b_images, b_states, b_instructions, b_instr_len)

            action_losses, _ = self.action_loss(b_actions,
                                                actions,
                                                batchreduce=False)

            self.prof.tick("call")
            action_losses = self.action_loss.batch_reduce_loss(action_losses)
            action_loss = self.action_loss.reduce_loss(action_losses)
            action_loss_total = action_loss
            count += b_seq_len

            self.prof.tick("loss")

        action_loss_avg = action_loss_total / (count + 1e-9)

        self.prof.tick("out")

        # Doing this at the end (outside of the per-segment loop)
        aux_losses = self.calculate_aux_loss(reduce_average=True)
        aux_loss = self.combine_aux_losses(aux_losses, self.aux_weights)

        prefix = self.model_name + ("/eval" if eval else "/train")

        self.writer.add_dict(prefix, get_current_meters(), self.get_iter())
        self.writer.add_dict(prefix, aux_losses, self.get_iter())
        self.writer.add_scalar(prefix + "/action_loss",
                               action_loss_avg.data.cpu()[0], self.get_iter())

        self.prof.tick("auxiliaries")

        total_loss = action_loss_avg + aux_loss

        self.inc_iter()

        self.prof.tick("summaries")
        self.prof.loop()
        self.prof.print_stats(1)

        return total_loss

    def get_dataset(self,
                    data=None,
                    envs=None,
                    dataset_names=None,
                    dataset_prefix=None,
                    eval=False):
        # TODO: Maybe use eval here
        #if self.fpv:
        data_sources = []
        data_sources.append(aup.PROVIDER_LM_POS_DATA)
        data_sources.append(aup.PROVIDER_LANDMARKS_MENTIONED)

        templates = get_current_parameters()["Environment"]["Templates"]
        if templates:
            data_sources.append(aup.PROVIDER_LANG_TEMPLATE)

        return SegmentDataset(data=data,
                              env_list=envs,
                              dataset_names=dataset_names,
                              dataset_prefix=dataset_prefix,
                              aux_provider_names=data_sources,
                              segment_level=True)
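One detail from ModelGSFPV.__init__ above is worth spelling out: Python's conditional expression binds more loosely than string concatenation, so the unparenthesized form "gs_fpv" + "_mem" if recurrence else "" evaluates to an empty model name whenever recurrence is False. The constructor above uses the parenthesized form; a tiny sketch of the difference:

recurrence = False

# Unparenthesized: the whole concatenation becomes the "if" branch.
name_wrong = "gs_fpv" + "_mem" if recurrence else ""
assert name_wrong == ""

# Parenthesized: only the suffix is conditional, as intended.
name_right = "gs_fpv" + ("_mem" if recurrence else "")
assert name_right == "gs_fpv"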
Example No. 23
class ModelMisra2017(ModuleWithAuxiliaries):
    def __init__(self, run_name=""):

        super(ModelMisra2017, self).__init__()
        self.model_name = "misra2017"
        self.run_name = run_name
        self.writer = LoggingSummaryWriter(log_dir="runs/" + run_name)

        self.params = get_current_parameters()["Model"]
        self.trajectory_len = get_current_parameters(
        )["Setup"]["trajectory_length"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        # CNN over images - using what is essentially SimpleImage currently
        self.image_module = ImageCnnEmnlp(
            image_emb_size=self.params["image_emb_dim"],
            input_num_channels=3 *
            5,  # 3 channels per image - 5 images in history
            image_height=self.params["img_h"],
            image_width=self.params["img_w"])

        # LSTM to embed text
        self.text_module = TextSimpleModule(
            emb_dim=self.params["word_emb_dim"],
            hidden_dim=self.params["emb_size"],
            vocab_size=self.params["vocab_size"])

        # Action module to embed previous action+block
        self.action_module = ActionSimpleModule(
            num_actions=self.params["num_actions"],
            action_emb_size=self.params["action_emb_dim"])

        # Put it all together
        self.final_module = IncrementalMultimodalEmnlp(
            image_module=self.image_module,
            text_module=self.text_module,
            action_module=self.action_module,
            input_embedding_size=self.params["lstm_emb_dim"] +
            self.params["image_emb_dim"] + self.params["action_emb_dim"],
            output_hidden_size=self.params["h1_hidden_dim"],
            blocks_hidden_size=self.params["blocks_hidden_dim"],
            directions_hidden_size=self.params["action_hidden_dim"],
            max_episode_length=self.trajectory_len)

        self.action_loss = ActionLoss()

        self.env_id = None
        self.prev_instruction = None
        self.seq_step = 0
        self.model_state = None
        self.image_emb_seq = None
        self.state_feature = None

    # TODO: Try to hide these in a superclass or something. They take up a lot of space:
    def cuda(self, device=None):
        ModuleWithAuxiliaries.cuda(self, device)
        self.image_module.cuda(device)
        self.text_module.cuda(device)
        self.final_module.cuda(device)
        self.action_module.cuda(device)
        self.action_loss.cuda(device)
        return self

    def get_iter(self):
        return int(self.iter.data[0])

    def inc_iter(self):
        self.iter += 1

    def init_weights(self):
        self.final_module.init_weights()

    def reset(self):
        # TODO: This is error prone. Create a class StatefulModule, iterate submodules and reset all stateful modules
        super(ModelMisra2017, self).reset()
        self.seq_step = 0
        self.model_state = None
        pass

    def setEnvContext(self, context):
        print("Set env context to: " + str(context))
        self.env_id = context["env_id"]

    def get_action(self, state, instruction):
        """
        Given a DroneState (from PomdpInterface) and instruction, produce a numpy 4D action (x, y, theta, pstop)
        :param state: DroneState object with the raw image from the simulator
        :param instruction: Tokenized instruction given the corpus
        #TODO: Absorb corpus within model
        :return:
        """
        # TODO: Simplify this
        self.eval()
        images_np_pure = state.image
        state_np = state.state

        #print("Act: " + debug_untokenize_instruction(instruction))

        images_np = standardize_image(images_np_pure)
        image_fpv = Variable(none_padded_seq_to_tensor([images_np]))
        state = Variable(none_padded_seq_to_tensor([state_np]))
        # Add the batch dimension

        first_step = True
        if instruction == self.prev_instruction:
            first_step = False
        self.prev_instruction = instruction

        img_in_t = image_fpv
        img_in_t.volatile = True

        instr_len = [len(instruction)] if instruction is not None else None
        instruction = torch.LongTensor(instruction).unsqueeze(0)
        instruction = cuda_var(instruction, self.is_cuda, self.cuda_device)

        state.volatile = True

        if self.is_cuda:
            img_in_t = img_in_t.cuda(self.cuda_device)
            state = state.cuda(self.cuda_device)

        self.seq_step += 1

        action = self(img_in_t, instruction, instr_len)

        output_action = action.squeeze().data.cpu().numpy()
        stop_prob = output_action[3]
        output_stop = 1 if stop_prob > 0.5 else 0
        output_action[3] = output_stop

        #print("action: ", output_action)

        return output_action

    def deterministic_action(self, action_mean, action_std, stop_prob):
        batch_size = action_mean.size(0)
        action = Variable(
            empty_float_tensor((batch_size, 4), self.is_cuda,
                               self.cuda_device))
        action[:, 0:3] = action_mean[:, 0:3]
        action[:, 3] = stop_prob
        return action

    def sample_action(self, action_mean, action_std, stop_prob):
        action = torch.normal(action_mean, action_std)
        stop = torch.bernoulli(stop_prob)
        return action, stop

    # This is called before beginning an execution sequence
    def start_sequence(self):
        self.seq_step = 0
        self.reset()
        print("RESETTED!")
        return

    # TODO: Move this somewhere and standardize
    def cam_poses_from_states(self, states):
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]
        pose = Pose(cam_pos, cam_rot)
        return pose

    def instructions_to_dipandrew(self, instructions, instr_lengths):
        out = []
        for i in range(len(instructions)):
            instr_i = instructions[i:i + 1, 0:instr_lengths[i]]
            out.append(instr_i)
        return out

    def forward(self, images, instructions, instr_lengths):

        seq_len = len(images)

        instr_dipandrew = self.instructions_to_dipandrew(
            instructions, instr_lengths)

        # Add sequence dimension, since we're treating batches as sequences
        images = images.unsqueeze(0)

        all_actions = []
        for i in range(seq_len):
            time_in = np.asarray([self.seq_step])
            time_in = Variable(
                self.maybe_cuda(torch.from_numpy(time_in).long()))
            action_i, self.model_state = self.final_module(
                images[0:1, i:i + 1], instr_dipandrew[i], time_in,
                self.model_state)

            self.seq_step += 1
            all_actions.append(action_i)

        actions = torch.cat(all_actions, dim=0)
        return actions

    def maybe_cuda(self, tensor):
        if self.is_cuda:
            return tensor.cuda()
        else:
            return tensor

    def cuda_var(self, tensor):
        return cuda_var(tensor, self.is_cuda, self.cuda_device)

    # Forward pass for training (with batch optimizations)
    def sup_loss_on_batch(self, batch, eval):
        self.prof.tick("out")

        action_loss_total = Variable(
            empty_float_tensor([1], self.is_cuda, self.cuda_device))

        if batch is None:
            print("Skipping None Batch")
            return action_loss_total

        images = self.maybe_cuda(batch["images"])
        instructions = self.maybe_cuda(batch["instr"])
        instr_lengths = batch["instr_len"]
        actions = self.maybe_cuda(batch["actions"])

        metadata = batch["md"]

        batch_size = images.size(0)
        count = 0

        # Loop thru batch
        for b in range(batch_size):
            self.reset()
            self.prof.tick("out")
            b_seq_len = len_until_nones(metadata[b])

            # TODO: Generalize this
            # Slice the data according to the sequence length
            b_metadata = metadata[b][:b_seq_len]
            b_images = images[b][:b_seq_len]
            b_instructions = instructions[b][:b_seq_len]
            b_instr_len = instr_lengths[b][:b_seq_len]
            b_actions = actions[b][:b_seq_len]

            # ----------------------------------------------------------------------------

            self.prof.tick("inputs")

            actions = self(b_images, b_instructions, b_instr_len)

            action_losses, _ = self.action_loss(b_actions,
                                                actions,
                                                batchreduce=False)

            self.prof.tick("call")
            action_losses = self.action_loss.batch_reduce_loss(action_losses)
            action_loss = self.action_loss.reduce_loss(action_losses)
            action_loss_total = action_loss
            count += b_seq_len

            self.prof.tick("loss")

        action_loss_avg = action_loss_total / (count + 1e-9)

        self.prof.tick("out")

        prefix = self.model_name + ("/eval" if eval else "/train")

        self.writer.add_dict(prefix, get_current_meters(), self.get_iter())
        self.writer.add_scalar(prefix + "/action_loss",
                               action_loss_avg.data.cpu()[0], self.get_iter())

        total_loss = action_loss_avg

        self.inc_iter()

        self.prof.loop()
        self.prof.print_stats(1)

        return total_loss

    def get_dataset(self, data=None, envs=None, dataset_name=None, eval=False):
        # TODO: Maybe use eval here
        #if self.fpv:
        return SegmentDataset(data=data,
                              env_list=envs,
                              dataset_name=dataset_name,
                              aux_provider_names=[],
                              segment_level=True)
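ModelMisra2017.forward above processes the sequence one timestep at a time, threading self.model_state through final_module so that state carries over between steps. A simplified sketch of that incremental, stateful pattern, using a plain nn.GRUCell as a stand-in for the multimodal module (the sizes and the GRU cell itself are assumptions for illustration, not the actual IncrementalMultimodalEmnlp):

import torch
import torch.nn as nn

class IncrementalToy(nn.Module):
    # Processes a sequence step by step, carrying hidden state between steps.
    def __init__(self, in_size=8, hidden_size=16):
        super().__init__()
        self.cell = nn.GRUCell(in_size, hidden_size)
        self.head = nn.Linear(hidden_size, 4)  # 4D action, like (x, y, theta, pstop)
        self.state = None                      # analogous to self.model_state above

    def reset(self):
        self.state = None

    def forward(self, obs_seq):
        # obs_seq: (seq_len, in_size); one action is produced per timestep
        actions = []
        for t in range(obs_seq.shape[0]):
            self.state = self.cell(obs_seq[t:t + 1], self.state)
            actions.append(self.head(self.state))
        return torch.cat(actions, dim=0)

toy = IncrementalToy()
out = toy(torch.randn(5, 8))  # -> tensor of shape (5, 4)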
Example No. 24
    def train_epoch(self, train_data=None, train_envs=None, eval=False):
        if eval:
            self.model.eval()
            inference_type = "eval"
            epoch_num = self.train_epoch_num
            self.test_epoch_num += 1
        else:
            self.model.train()
            inference_type = "train"
            epoch_num = self.train_epoch_num
            self.train_epoch_num += 1

        dataset = self.model.get_dataset(data=train_data,
                                         envs=train_envs,
                                         dataset_name="supervised",
                                         eval=eval)
        # TODO: Get rid of this:
        if hasattr(dataset, "set_word2token"):
            dataset.set_word2token(self.token2word, self.word2token)

        dataloader = DataLoader(dataset,
                                collate_fn=dataset.collate_fn,
                                batch_size=self.batch_size,
                                shuffle=True,
                                num_workers=self.num_loaders,
                                pin_memory=False,
                                timeout=0,
                                drop_last=False)

        num_samples = len(dataset)
        if num_samples == 0:
            print("DATASET HAS NO DATA!")
            return -1.0

        num_batches = int(
            (num_samples + self.batch_size - 1) / self.batch_size)

        epoch_loss = 0
        count = 0

        prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)

        prof.tick("out")

        #try:
        for batch in dataloader:

            if batch is None:
                #print("None batch!")
                continue

            prof.tick("batch_load")
            # Zero gradients before each segment and initialize zero segment loss
            self.optim.zero_grad()

            #try:
            if True:

                batch_loss = self.model.sup_loss_on_batch(batch, eval)

                if type(batch_loss) == int:
                    print("Ding")

                prof.tick("forward")

                # Backprop and step
                if not eval:
                    batch_loss.backward()

                    prof.tick("backward")

                    # This is SLOW! Don't do it often
                    # TODO: Get rid of tensorboard
                    #if self.batch_num % 20 == 0 and hasattr(self.model, "writer"):
                    #    params = self.model.named_parameters()
                    #    self.write_grad_summaries(self.model.writer, params, self.batch_num)

                    self.batch_num += 1
                    self.optim.step()

                    prof.tick("optim")

                # Get losses as floats
                epoch_loss += batch_loss.data[0]
                count += 1

                sys.stdout.write("\r Batch:" + str(count) + " / " +
                                 str(num_batches) + " loss: " +
                                 str(batch_loss.data[0]))
                sys.stdout.flush()

                self.train_segment += 0 if eval else 1
                self.test_segment += 1 if eval else 0

                prof.tick("rep")

            prof.loop()
            prof.print_stats(10)
            #except Exception as e:
            #    print("Exception encountered during batch update")
            #    print(e)

        #except Exception as e:
        #    print("Error during epoch training")
        #    print(e)
        #    return

        if hasattr(self.model, "write_eoe_summaries"):
            self.model.write_eoe_summaries(inference_type, epoch_num)

        print("")
        epoch_loss /= (count + 1e-15)

        if hasattr(self.model, "writer"):
            self.model.writer.add_scalar(
                self.name + "/" + inference_type + "_epoch_loss", epoch_loss,
                epoch_num)

        return epoch_loss
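The loop above tolerates None batches from the DataLoader (the "if batch is None: continue" check), which presumes a collate function that can decline to build a batch when no usable samples remain. The real SegmentDataset.collate_fn is not shown in this excerpt; the following is only a minimal sketch of that convention, with the sample-filtering rule assumed for illustration:

from torch.utils.data.dataloader import default_collate

def collate_skip_invalid(samples):
    # Drop samples that the dataset returned as None (e.g. unusable segments);
    # return None if nothing remains, so the training loop can skip the batch.
    samples = [s for s in samples if s is not None]
    if len(samples) == 0:
        return None
    return default_collate(samples)

# Usage sketch (dataset and batch size are placeholders):
# loader = torch.utils.data.DataLoader(dataset, batch_size=8,
#                                      collate_fn=collate_skip_invalid)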
Example No. 25
class PVN_Stage1_Bidomain_Original(nn.Module):
    def __init__(self, run_name="", domain="sim"):

        super(PVN_Stage1_Bidomain_Original, self).__init__()
        self.model_name = "pvn_stage1"
        self.run_name = run_name
        self.domain = domain
        self.writer = LoggingSummaryWriter(
            log_dir=f"{get_logging_dir()}/runs/{run_name}/{self.domain}")
        #self.writer = DummySummaryWriter()

        self.root_params = get_current_parameters()["ModelPVN"]
        self.params = self.root_params["Stage1"]
        self.use_aux = self.root_params["UseAux"]
        self.aux_weights = self.root_params["AuxWeights"]

        if self.params.get("weight_override"):
            aux_weights_override_name = "AuxWeightsRealOverride" if self.domain == "real" else "AuxWeightsSimOverride"
            aux_weights_override = self.root_params.get(
                aux_weights_override_name)
            if aux_weights_override:
                print(
                    f"Overriding auxiliary weights for domain: {self.domain}")
                self.aux_weights = dict_merge(self.aux_weights,
                                              aux_weights_override)

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        self.tensor_store = KeyTensorStore()
        self.losses = AuxiliaryLosses()

        # Auxiliary Objectives
        self.do_perturb_maps = self.params["perturb_maps"]
        print("Perturbing maps: ", self.do_perturb_maps)

        # Path-pred FPV model definition
        # --------------------------------------------------------------------------------------------------------------

        self.num_feature_channels = self.params[
            "feature_channels"]  # + params["relevance_channels"]
        self.num_map_channels = self.params["pathpred_in_channels"]

        self.img_to_features_w = FPVToGlobalMap(
            source_map_size=self.params["global_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"],
            res_channels=self.params["resnet_channels"],
            map_channels=self.params["feature_channels"],
            img_w=self.params["img_w"],
            img_h=self.params["img_h"],
            cam_h_fov=self.params["cam_h_fov"],
            domain=domain,
            img_dbg=IMG_DBG)

        self.map_accumulator_w = LeakyIntegratorGlobalMap(
            source_map_size=self.params["global_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])

        self.add_init_pos_to_coverage = AddDroneInitPosToCoverage(
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"],
            map_size_px=self.params["local_map_size"])

        # Pre-process the accumulated map to do language grounding if necessary - in the world reference frame
        self.map_processor_grounding = LangFilterMapProcessor(
            embed_size=self.params["emb_size"],
            in_channels=self.params["feature_channels"],
            out_channels=self.params["relevance_channels"],
            spatial=False,
            cat_out=False)

        ratio_prior_channels = self.params["feature_channels"]

        # Process the global accumulated map
        self.path_predictor_lingunet = RatioPathPredictor(
            self.params["lingunet"],
            prior_channels_in=self.params["feature_channels"],
            posterior_channels_in=self.params["pathpred_in_channels"],
            dual_head=self.params["predict_confidence"],
            compute_prior=self.params["compute_prior"],
            use_prior=self.params["use_prior_only"],
            oob=self.params["clip_observability"])

        print("UNet Channels: " + str(self.num_map_channels))
        print("Feature Channels: " + str(self.num_feature_channels))

        # TODO:O Verify that config has the same randomization parameters (yaw, pos, etc)
        self.second_transform = self.do_perturb_maps or self.params[
            "predict_in_start_frame"]

        # Sentence Embedding
        self.sentence_embedding = SentenceEmbeddingSimple(
            self.params["word_emb_size"], self.params["emb_size"],
            self.params["emb_layers"], self.params["emb_dropout"])

        self.map_transform_local_to_local = MapTransformer(
            source_map_size=self.params["local_map_size"],
            dest_map_size=self.params["local_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])

        self.map_transform_global_to_local = MapTransformer(
            source_map_size=self.params["global_map_size"],
            dest_map_size=self.params["local_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])

        self.map_transform_local_to_global = MapTransformer(
            source_map_size=self.params["local_map_size"],
            dest_map_size=self.params["global_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])

        self.map_transform_s_to_p = self.map_transform_local_to_local
        self.map_transform_w_to_s = self.map_transform_global_to_local
        self.map_transform_w_to_r = self.map_transform_global_to_local
        self.map_transform_r_to_s = self.map_transform_local_to_local
        self.map_transform_r_to_w = self.map_transform_local_to_global
        self.map_transform_p_to_w = self.map_transform_local_to_global
        self.map_transform_p_to_r = self.map_transform_local_to_local

        # Batch select is used to drop and forget semantic maps at those timesteps that we're not planning in
        self.batch_select = MapBatchSelect()
        # Since we only have path predictions for some timesteps (the ones not dropped above), we use this to fill
        # in the missing pieces by reorienting the past trajectory prediction into the frame of the current timestep
        self.map_batch_fill_missing = MapBatchFillMissing(
            self.params["local_map_size"], self.params["world_size_px"],
            self.params["world_size_m"])

        self.spatialsoftmax = SpatialSoftmax2d()
        self.visitation_softmax = VisitationSoftmax()

        #TODO:O Use CroppedMapToActionTriplet in Wrapper as Stage2
        # Auxiliary Objectives
        # --------------------------------------------------------------------------------------------------------------

        # We add all auxiliaries that are necessary. The first argument is the auxiliary name, followed by parameters,
        # followed by variable number of names of inputs. ModuleWithAuxiliaries will automatically collect these inputs
        # that have been saved with keep_auxiliary_input() during execution
        if self.use_aux["class_features"]:
            self.losses.add_auxiliary(
                ClassAuxiliary2D("class_features",
                                 self.params["feature_channels"],
                                 self.params["num_landmarks"], 0,
                                 "fpv_features", "lm_pos_fpv", "lm_indices"))
        if self.use_aux["grounding_features"]:
            self.losses.add_auxiliary(
                ClassAuxiliary2D("grounding_features",
                                 self.params["relevance_channels"], 2, 0,
                                 "fpv_features_g", "lm_pos_fpv",
                                 "lm_mentioned"))
        if self.use_aux["class_map"]:
            self.losses.add_auxiliary(
                ClassAuxiliary2D("class_map", self.params["feature_channels"],
                                 self.params["num_landmarks"], 0, "S_W_select",
                                 "lm_pos_map_select", "lm_indices_select"))
        if self.use_aux["grounding_map"]:
            self.losses.add_auxiliary(
                ClassAuxiliary2D("grounding_map",
                                 self.params["relevance_channels"], 2, 0,
                                 "R_W_select", "lm_pos_map_select",
                                 "lm_mentioned_select"))
        # CoRL model uses alignment-model groundings
        if self.use_aux["lang"]:
            # one output for each landmark, 2 classes per output. This is for fine-tuning, so use the embedding that is going to be fine-tuned
            self.losses.add_auxiliary(
                ClassAuxiliary("lang", self.params["emb_size"], 2,
                               self.params["num_landmarks"], "sentence_embed",
                               "lang_lm_mentioned"))

        if self.use_aux["regularize_map"]:
            self.losses.add_auxiliary(
                FeatureRegularizationAuxiliary2D("regularize_map", "l1",
                                                 "S_W_select"))

        lossfunc = self.params["path_loss_function"]
        # The same auxiliary inputs are used whether or not observability is clipped
        self.losses.add_auxiliary(
            PathAuxiliary2D("visitation_dist", lossfunc,
                            self.params["clip_observability"],
                            "log_v_dist_s_select",
                            "v_dist_s_ground_truth_select", "SM_S_select"))

        self.goal_good_criterion = GoalPredictionGoodCriterion(
            ok_distance=self.params["world_size_px"] * 0.1)
        self.goal_acc_meter = MovingAverageMeter(10)
        self.visible_goal_acc_meter = MovingAverageMeter(10)
        self.invisible_goal_acc_meter = MovingAverageMeter(10)
        self.visible_goal_frac_meter = MovingAverageMeter(10)

        self.losses.print_auxiliary_info()

        self.total_goals = 0
        self.correct_goals = 0

        self.env_id = None
        self.env_img = None
        self.seg_idx = None
        self.prev_instruction = None
        self.seq_step = 0

        self.should_save_path_overlays = False

    def make_picklable(self):
        self.writer = DummySummaryWriter()

    def steal_cross_domain_modules(self, other_self):
        self.iter = other_self.iter
        self.losses = other_self.losses
        self.sentence_embedding = other_self.sentence_embedding
        self.map_accumulator_w = other_self.map_accumulator_w
        self.map_processor_grounding = other_self.map_processor_grounding
        self.path_predictor_lingunet = other_self.path_predictor_lingunet
        #self.img_to_features_w = other_self.img_to_features_w

    def both_domain_parameters(self, other_self):
        # This function iterates and yields parameters from this module and the other module, but does not yield
        # shared parameters twice.
        # First yield all of the other module's parameters
        for p in other_self.parameters():
            yield p
        # Then yield all the parameters from the this module that are not shared with the other one
        for p in self.img_to_features_w.parameters():
            yield p
        return

    def get_iter(self):
        return int(self.iter.data[0])

    def inc_iter(self):
        self.iter += 1

    def load_state_dict(self, state_dict, strict=True):
        super(PVN_Stage1_Bidomain_Original,
              self).load_state_dict(state_dict, strict)

    def init_weights(self):
        self.img_to_features_w.init_weights()
        self.map_accumulator_w.init_weights()
        self.sentence_embedding.init_weights()
        self.map_processor_grounding.init_weights()
        self.path_predictor_lingunet.init_weights()

    def reset(self):
        # TODO: This is error prone. Create a class StatefulModule, iterate submodules and reset all stateful modules
        self.tensor_store.reset()
        self.sentence_embedding.reset()
        self.img_to_features_w.reset()
        self.map_accumulator_w.reset()
        self.map_batch_fill_missing.reset()
        self.prev_instruction = None

    def setEnvContext(self, context):
        print("Set env context to: " + str(context))
        self.env_id = context["env_id"]
        self.env_img = env.load_env_img(self.env_id, 256, 256)
        self.env_img = self.env_img[:, :, [2, 1, 0]]

    def set_save_path_overlays(self, save_path_overlays):
        self.should_save_path_overlays = save_path_overlays

    #TODO:O Figure out what to do with save_ground_truth_overlays

    def print_metrics(self):
        print(f"Model {self.model_name}:{self.domain} metrics:")
        print(
            f"   Goal accuracy: {float(self.correct_goals) / self.total_goals}"
        )

    def goal_visible(self, masks, goal_pos):
        goal_mask = masks.detach()[0, 0, :, :]
        goal_pos = goal_pos[0].long().detach()
        visible = bool(
            (goal_mask[goal_pos[0], goal_pos[1]] > 0.5).detach().cpu().item())
        return visible

    # This is called before beginning an execution sequence
    def start_sequence(self):
        self.seq_step = 0
        self.reset()
        return

    def cam_poses_from_states(self, states):
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]
        pose = Pose(cam_pos, cam_rot)
        return pose

    def forward(self,
                images,
                states,
                instructions,
                instr_lengths,
                plan=None,
                noisy_start_poses=None,
                start_poses=None,
                firstseg=None,
                select_only=True,
                halfway=False,
                grad_noise=False,
                rl=False):
        """
        :param images: BxCxHxW batch of images (observations)
        :param states: BxK batch of drone states
        :param instructions: BxM LongTensor where M is the maximum length of any instruction
        :param instr_lengths: list of len B of integers, indicating length of each instruction
        :param plan: list of B booleans indicating True for timesteps where we do planning and False otherwise
        :param noisy_start_poses: list of noisy start poses (for data-augmentation). These define the path-prediction frame at training time
        :param start_poses: list of drone start poses (these should be equal in practice)
        :param firstseg: list of booleans indicating True if a new segment starts at that timestep
        :param select_only: boolean indicating whether to only compute visitation distributions for planning timesteps (default True)
        :param rl: boolean indicating if we're doing reinforcement learning. If yes, output more than the visitation distribution
        :return:
        """
        cam_poses = self.cam_poses_from_states(states)
        g_poses = None  # None pose is a placeholder for the canonical global pose.
        self.prof.tick("out")

        self.tensor_store.keep_inputs("fpv", images)

        # Calculate the instruction embedding
        if instructions is not None:
            # TODO: Take batch of instructions and their lengths, return batch of embeddings. Store the last one as internal state
            # TODO: There's an assumption here that there's only a single instruction in the batch and it doesn't change
            # UNCOMMENT THE BELOW LINE TO REVERT BACK TO GENERAL CASE OF SEPARATE INSTRUCTION PER STEP
            if self.params["ignore_instruction"]:
                # If we're ignoring instructions, just feed in an instruction that consists of a single zero-token
                sent_embeddings = self.sentence_embedding(
                    torch.zeros_like(instructions[0:1, 0:1]),
                    torch.ones_like(instr_lengths[0:1]))
            else:
                sent_embeddings = self.sentence_embedding(
                    instructions[0:1], instr_lengths[0:1])
            self.tensor_store.keep_inputs("sentence_embed", sent_embeddings)
        else:
            sent_embeddings = self.sentence_embedding.get()

        self.prof.tick("embed")

        # Extract and project features onto the egocentric frame for each image
        F_W, M_W = self.img_to_features_w(images,
                                          cam_poses,
                                          sent_embeddings,
                                          self.tensor_store,
                                          show="",
                                          halfway=halfway)

        # For training the critic, this is as far as we need to proceed with the computation.
        # self.img_to_features_w has stored computed feature maps inside the tensor store, which will then be retrieved by the critic
        if halfway == True:  # Note: halfway can also be a string (e.g. "observability"), so compare to True exactly rather than truthiness
            return None, None

        self.tensor_store.keep_inputs("F_w", F_W)
        self.tensor_store.keep_inputs("M_w", M_W)
        self.prof.tick("img_to_map_frame")

        # Accumulate the egocentric features in a global map
        reset_mask = firstseg if self.params["clear_history"] else None

        # Consider the space very near the drone and right under it as observed - draw ones on the observability mask
        # If we treat that space as unobserved, then there's going to be a gap in the visitation distribution, which
        # makes training with RL more difficult, as there is no reward feedback if the drone doesn't cross that gap.
        if self.params.get("cover_init_pos", False):
            StartMasks_R = self.add_init_pos_to_coverage.get_init_pos_masks(
                M_W.shape[0], M_W.device)
            StartMasks_W, _ = self.map_transform_r_to_w(
                StartMasks_R, cam_poses, None)
            M_W = self.add_init_pos_to_coverage(M_W, StartMasks_W)

        S_W, SM_W = self.map_accumulator_w(F_W,
                                           M_W,
                                           reset_mask=reset_mask,
                                           show="acc" if IMG_DBG else "")
        S_W_poses = g_poses
        self.prof.tick("map_accumulate")

        # If we're training Stage 2 with imitation learning from ground truth visitation distributions, we want to
        # compute observability masks with the same code that's used in Stage 1 to avoid mistakes.
        if halfway == "observability":
            map_uncoverage_w = 1 - SM_W
            return map_uncoverage_w

        # Throw away those timesteps that don't correspond to planning timesteps
        S_W_select, SM_W_select, S_W_poses_select, cam_poses_select, noisy_start_poses_select, start_poses_select, sent_embeddings_select = \
            self.batch_select(S_W, SM_W, S_W_poses, cam_poses, noisy_start_poses, start_poses, sent_embeddings, plan)

        #maps_m_prior_select, maps_m_posterior_select = None, None

        # Only process the maps on planning timesteps
        if len(S_W_select) == 0:
            return None

        self.tensor_store.keep_inputs("S_W_select", S_W_select)
        self.prof.tick("batch_select")

        # Process the map via the two map processors
        # Ground objects in the map, if the grounding map auxiliary is enabled
        if self.use_aux["grounding_map"]:
            R_W_select, RS_W_poses_select = self.map_processor_grounding(
                S_W_select, sent_embeddings_select, S_W_poses_select, show="")
            self.tensor_store.keep_inputs("R_W_select", R_W_select)
            self.prof.tick("map_proc_gnd")
            # Concatenate grounding map and semantic map along channel dimension
            RS_W_select = torch.cat([S_W_select, R_W_select], 1)

        else:
            RS_W_select = S_W_select
            RS_W_poses_select = S_W_poses_select

        s_poses_select = start_poses_select if self.params[
            "predict_in_start_frame"] else cam_poses_select
        RS_S_select, RS_S_poses_select = self.map_transform_w_to_s(
            RS_W_select, RS_W_poses_select, s_poses_select)
        SM_S_select, SM_S_poses_select = self.map_transform_w_to_s(
            SM_W_select, S_W_poses_select, s_poses_select)

        assert SM_S_poses_select == RS_S_poses_select, "Masks and maps should have the same pose in start frame"

        self.tensor_store.keep_inputs("RS_S_select", RS_S_select)
        self.tensor_store.keep_inputs("SM_S_select", SM_S_select)
        self.prof.tick("transform_w_to_s")

        # Data augmentation for trajectory prediction
        map_poses_clean_select = None
        # TODO: Figure out if we can just swap out start poses for noisy poses and get rid of separate noisy poses
        if self.do_perturb_maps:
            assert noisy_start_poses_select is not None, "Noisy poses must be provided if we're perturbing maps"
            RS_P_select, RS_P_poses_select = self.map_transform_s_to_p(
                RS_S_select, RS_S_poses_select, noisy_start_poses_select)
        else:
            RS_P_select, RS_P_poses_select = RS_S_select, RS_S_poses_select

        self.tensor_store.keep_inputs("RS_perturbed_select", RS_P_select)
        self.prof.tick("map_perturb")

        sent_embeddings_pp = sent_embeddings_select

        # Run lingunet on the map to predict visitation distribution scores (pre-softmax)
        # ---------
        log_v_dist_p_select, v_dist_p_poses_select = self.path_predictor_lingunet(
            RS_P_select,
            sent_embeddings_pp,
            RS_P_poses_select,
            tensor_store=self.tensor_store)
        # ---------

        self.prof.tick("pathpred")

        # TODO: Shouldn't we be transforming probability distributions instead of scores? Otherwise OOB space will have weird values
        # Transform distributions back to world reference frame and keep them (these are the model outputs)
        both_inner_w, v_dist_w_poses_select = self.map_transform_p_to_w(
            log_v_dist_p_select.inner_distribution, v_dist_p_poses_select,
            None)
        log_v_dist_w_select = Partial2DDistribution(
            both_inner_w, log_v_dist_p_select.outer_prob_mass)
        self.tensor_store.keep_inputs("log_v_dist_w_select",
                                      log_v_dist_w_select)

        # Transform distributions back to start reference frame and keep them (for auxiliary objective)
        both_inner_s, v_dist_s_poses_select = self.map_transform_p_to_r(
            log_v_dist_p_select.inner_distribution, v_dist_p_poses_select,
            start_poses_select)
        log_v_dist_s_select = Partial2DDistribution(
            both_inner_s, log_v_dist_p_select.outer_prob_mass)
        self.tensor_store.keep_inputs("log_v_dist_s_select",
                                      log_v_dist_s_select)

        # Using a prime interval means the visualization alternates between sim and real domain iterations
        if self.get_iter() % 23 == 0:
            lsfm = SpatialSoftmax2d()
            for i in range(S_W_select.shape[0]):
                Presenter().show_image(S_W_select.detach().cpu()[i, 0:3],
                                       f"{self.domain}_s_w_select",
                                       scale=4,
                                       waitkey=1)
                Presenter().show_image(lsfm(
                    log_v_dist_s_select.inner_distribution).detach().cpu()[i],
                                       f"{self.domain}_v_dist_s_select",
                                       scale=4,
                                       waitkey=1)
                Presenter().show_image(lsfm(
                    log_v_dist_p_select.inner_distribution).detach().cpu()[i],
                                       f"{self.domain}_v_dist_p_select",
                                       scale=4,
                                       waitkey=1)
                Presenter().show_image(RS_P_select.detach().cpu()[i, 0:3],
                                       f"{self.domain}_rs_p_select",
                                       scale=4,
                                       waitkey=1)
                break

        self.prof.tick("transform_back")

        # If we're predicting the trajectory only on some timesteps, then for each timestep k, use the map from
        # timestep k if predicting on timestep k. otherwise use the map from timestep j - the last timestep
        # that had a trajectory prediction, rotated in the frame of timestep k.
        if select_only:
            # If we're just pre-training the trajectory prediction, don't waste time on generating the missing maps
            log_v_dist_w = log_v_dist_w_select
            v_dist_w_poses = v_dist_w_poses_select
        else:
            raise NotImplementedError("select_only must be True")

        return_list = [log_v_dist_w, v_dist_w_poses]
        if rl:
            internals_for_rl = {
                "map_coverage_w": SM_W,
                "map_uncoverage_w": 1 - SM_W
            }
            return_list.append(internals_for_rl)

        return tuple(return_list)

    def maybe_cuda(self, tensor):
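        # Move the tensor to the same device (CPU or GPU) as the model parameters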
        return tensor.to(next(self.parameters()).device)

    def cuda_var(self, tensor):
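        # Identical to maybe_cuda; kept under a separate name because it is used below for moving label tensors to the model device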
        return tensor.to(next(self.parameters()).device)

    def unbatch(self, batch, halfway=False):
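        # Unpack a single segment's worth of inputs and labels from the batch and move them to the model device.
        # When halfway=True, only the core inputs are prepared and the auxiliary labels are skipped.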
        # Inputs
        images = self.maybe_cuda(batch["images"][0])
        seq_len = len(images)
        instructions = self.maybe_cuda(batch["instr"][0][:seq_len])
        instr_lengths = batch["instr_len"][0][:seq_len]
        states = self.maybe_cuda(batch["states"][0])

        if not halfway:

            plan_mask = batch["plan_mask"][
                0]  # True for every timestep that we do visitation prediction
            firstseg_mask = batch["firstseg_mask"][
                0]  # True for every timestep that is a new instruction segment

            # Labels (including for auxiliary losses)
            lm_pos_fpv = batch["lm_pos_fpv"][
                0]  # All object 2D coordinates in the first-person image
            lm_pos_map_m = batch["lm_pos_map"][
                0]  # All object 2D coordinates in the semantic map
            lm_indices = batch["lm_indices"][0]  # All object class indices
            goal_pos_map_m = batch["goal_loc"][
                0]  # Goal location in the world in meters_and_metrics
            lm_mentioned = batch["lm_mentioned"][
                0]  # 1/0 labels whether object was mentioned/not mentioned in template instruction
            # TODO: We're taking the FIRST label here. SINGLE SEGMENT ASSUMPTION
            lang_lm_mentioned = batch["lang_lm_mentioned"][0][
                0]  # integer labes as to which object was mentioned
            start_poses = batch["start_poses"][0]
            noisy_start_poses = get_noisy_poses_torch(
                start_poses.numpy(),
                self.params["pos_variance"],
                self.params["rot_variance"],
                cuda=False,
                cuda_device=None)

            # Ground truth visitation distributions (in start and global frames)
            v_dist_w_ground_truth_select = self.maybe_cuda(
                batch["traj_ground_truth"][0])
            start_poses_select = self.batch_select.one(
                start_poses, plan_mask, v_dist_w_ground_truth_select.device)
            v_dist_s_ground_truth_select, poses_s = self.map_transform_w_to_s(
                v_dist_w_ground_truth_select, None, start_poses_select)
            #self.tensor_store.keep_inputs("v_dist_w_ground_truth_select", v_dist_w_ground_truth_select)
            self.tensor_store.keep_inputs("v_dist_s_ground_truth_select",
                                          v_dist_s_ground_truth_select)
            #Presenter().show_image(v_dist_s_ground_truth_select.detach().cpu()[0,0], "v_dist_s_ground_truth_select", waitkey=1, scale=4)
            #Presenter().show_image(v_dist_w_ground_truth_select.detach().cpu()[0,0], "v_dist_w_ground_truth_select", waitkey=1, scale=4)

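            # Convert object and goal coordinates from meters to pixel coordinates in the global map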
            lm_pos_map_px = [
                torch.from_numpy(
                    transformations.pos_m_to_px(p.numpy(),
                                                self.params["global_map_size"],
                                                self.params["world_size_m"],
                                                self.params["world_size_px"]))
                if p is not None else None for p in lm_pos_map_m
            ]
            goal_pos_map_px = torch.from_numpy(
                transformations.pos_m_to_px(goal_pos_map_m.numpy(),
                                            self.params["global_map_size"],
                                            self.params["world_size_m"],
                                            self.params["world_size_px"]))

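            # Scale the first-person landmark coordinates down to the ResNet feature-map resolution and move all labels to the model device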
            resnet_factor = self.img_to_features_w.img_to_features.get_downscale_factor(
            )
            lm_pos_fpv = [
                self.cuda_var(
                    (s / resnet_factor).long()) if s is not None else None
                for s in lm_pos_fpv
            ]
            lm_indices = [
                self.cuda_var(s) if s is not None else None for s in lm_indices
            ]
            lm_mentioned = [
                self.cuda_var(s) if s is not None else None
                for s in lm_mentioned
            ]
            lang_lm_mentioned = self.cuda_var(lang_lm_mentioned)
            lm_pos_map_px = [
                self.cuda_var(s.long()) if s is not None else None
                for s in lm_pos_map_px
            ]
            goal_pos_map_px = self.cuda_var(goal_pos_map_px)

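            # Keep the labels in the tensor store so that the auxiliary losses can read them later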
            self.tensor_store.keep_inputs("lm_pos_fpv", lm_pos_fpv)
            self.tensor_store.keep_inputs("lm_pos_map", lm_pos_map_px)
            self.tensor_store.keep_inputs("lm_indices", lm_indices)
            self.tensor_store.keep_inputs("lm_mentioned", lm_mentioned)
            self.tensor_store.keep_inputs("lang_lm_mentioned",
                                          lang_lm_mentioned)
            self.tensor_store.keep_inputs("goal_pos_map", goal_pos_map_px)

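            # Keep only the labels for timesteps on which visitation prediction is actually run (plan_mask)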
            lm_pos_map_select = [
                lm_pos for i, lm_pos in enumerate(lm_pos_map_px)
                if plan_mask[i]
            ]
            lm_indices_select = [
                lm_idx for i, lm_idx in enumerate(lm_indices) if plan_mask[i]
            ]
            lm_mentioned_select = [
                lm_m for i, lm_m in enumerate(lm_mentioned) if plan_mask[i]
            ]
            goal_pos_map_select = [
                pos for i, pos in enumerate(goal_pos_map_px) if plan_mask[i]
            ]

            self.tensor_store.keep_inputs("lm_pos_map_select",
                                          lm_pos_map_select)
            self.tensor_store.keep_inputs("lm_indices_select",
                                          lm_indices_select)
            self.tensor_store.keep_inputs("lm_mentioned_select",
                                          lm_mentioned_select)
            self.tensor_store.keep_inputs("goal_pos_map_select",
                                          goal_pos_map_select)

        else:
            # In the halfway pass we won't need this extra label information
            noisy_poses, start_poses, noisy_start_poses = None, None, None
            plan_mask, firstseg_mask = None, None

        metadata = batch["md"][0][0]
        env_id = metadata["env_id"]
        self.tensor_store.set_flag("env_id", env_id)

        return images, states, instructions, instr_lengths, plan_mask, firstseg_mask, start_poses, noisy_start_poses, metadata

    # Forward pass for training
    def sup_loss_on_batch(self,
                          batch,
                          eval,
                          halfway=False,
                          grad_noise=False,
                          disable_losses=[]):
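        # One supervised step: unbatch the inputs, run the forward pass (intermediate outputs are kept in
        # the tensor store), optionally compute goal-prediction accuracy, then combine and log the losses.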
        self.prof.tick("out")
        self.reset()

        if batch is None:
            print("Skipping None Batch")
            zero = torch.zeros([1]).float().to(next(self.parameters()).device)
            return zero, self.tensor_store

        images, states, instructions, instr_len, plan_mask, firstseg_mask, \
         start_poses, noisy_start_poses, metadata = self.unbatch(batch, halfway=halfway)
        self.prof.tick("unbatch_inputs")

        # ----------------------------------------------------------------------------
        _ = self(images,
                 states,
                 instructions,
                 instr_len,
                 plan=plan_mask,
                 firstseg=firstseg_mask,
                 noisy_start_poses=start_poses if eval else noisy_start_poses,
                 start_poses=start_poses,
                 select_only=True,
                 halfway=halfway,
                 grad_noise=grad_noise)
        # ----------------------------------------------------------------------------

        if self.should_save_path_overlays:
            self.save_path_overlays(metadata)

        # If we run the model only halfway, we only need to calculate the features used by the Wasserstein loss.
        # To include more features in the Wasserstein critic, the forward pass has to run a bit further ("v2").
        if halfway and halfway != "v2":
            return None, self.tensor_store

        # The returned values are not used here - they're kept in the tensor store which is used as an input to a loss
        self.prof.tick("call")

        if not halfway:
            # Calculate goal-prediction accuracy:
            goal_pos = self.tensor_store.get_inputs_batch("goal_pos_map",
                                                          cat_not_stack=True)
            success_goal = self.goal_good_criterion(
                self.tensor_store.get_inputs_batch("log_v_dist_w_select",
                                                   cat_not_stack=True),
                goal_pos)
            acc = 1.0 if success_goal else 0.0
            self.goal_acc_meter.put(acc)
            goal_visible = self.goal_visible(
                self.tensor_store.get_inputs_batch("M_w", cat_not_stack=True),
                goal_pos)
            if goal_visible:
                self.visible_goal_acc_meter.put(acc)
            else:
                self.invisible_goal_acc_meter.put(acc)
            self.visible_goal_frac_meter.put(1.0 if goal_visible else 0.0)

            self.correct_goals += acc
            self.total_goals += 1

            self.prof.tick("goal_acc")

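        # For the v2 halfway pass only the Wasserstein-critic features are needed, so the visitation and language losses are disabled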
        if halfway == "v2":
            disable_losses = ["visitation_dist", "lang"]

        losses, metrics = self.losses.calculate_aux_loss(
            tensor_store=self.tensor_store,
            reduce_average=True,
            disable_losses=disable_losses)
        loss = self.losses.combine_losses(losses, self.aux_weights)

        self.prof.tick("calc_losses")

        prefix = self.model_name + ("/eval" if eval else "/train")
        iteration = self.get_iter()
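        # Log the running meters, individual loss terms, and metrics under a model-name/train-or-eval prefix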
        self.writer.add_dict(prefix, get_current_meters(), iteration)
        self.writer.add_dict(prefix, losses, iteration)
        self.writer.add_dict(prefix, metrics, iteration)

        if not halfway:
            self.writer.add_scalar(prefix + "/goal_accuracy",
                                   self.goal_acc_meter.get(), iteration)
            self.writer.add_scalar(prefix + "/visible_goal_accuracy",
                                   self.visible_goal_acc_meter.get(),
                                   iteration)
            self.writer.add_scalar(prefix + "/invisible_goal_accuracy",
                                   self.invisible_goal_acc_meter.get(),
                                   iteration)
            self.writer.add_scalar(prefix + "/visible_goal_fraction",
                                   self.visible_goal_frac_meter.get(),
                                   iteration)

        self.inc_iter()

        self.prof.tick("summaries")
        self.prof.loop()
        self.prof.print_stats(1)

        return loss, self.tensor_store

    def get_dataset(self,
                    data=None,
                    envs=None,
                    domain=None,
                    dataset_names=None,
                    dataset_prefix=None,
                    eval=False,
                    halfway_only=False):
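        # Build the SegmentDataset, attaching the auxiliary data providers required for the configured labels and losses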
        # TODO: Maybe use eval here
        data_sources = []
        # If we're running auxiliary objectives, we need to include the data sources for the auxiliary labels
        #if self.use_aux_class_features or self.use_aux_class_on_map or self.use_aux_grounding_features or self.use_aux_grounding_on_map:
        #if self.use_aux_goal_on_map:
        if not halfway_only:
            data_sources.append(aup.PROVIDER_LM_POS_DATA)
            data_sources.append(aup.PROVIDER_GOAL_POS)

            # Adding these in this order will compute poses with added noise and compute trajectory ground truth
            # in the reference frame of these noisy poses
            data_sources.append(aup.PROVIDER_START_POSES)

            if self.do_perturb_maps:
                print("PERTURBING MAPS!")
                # TODO: The noisy poses from the provider are not actually used!! Those should replace states instead!
                data_sources.append(aup.PROVIDER_NOISY_POSES)
                # TODO: Think this through. Perhaps we actually want dynamic ground truth given a noisy start position
                if self.params["predict_in_start_frame"]:
                    data_sources.append(
                        aup.PROVIDER_TRAJECTORY_GROUND_TRUTH_STATIC)
                else:
                    data_sources.append(
                        aup.PROVIDER_TRAJECTORY_GROUND_TRUTH_DYNAMIC_NOISY)
            else:
                print("NOT Perturbing Maps!")
                data_sources.append(aup.PROVIDER_NOISY_POSES)
                if self.params["predict_in_start_frame"]:
                    data_sources.append(
                        aup.PROVIDER_TRAJECTORY_GROUND_TRUTH_STATIC)
                else:
                    data_sources.append(
                        aup.PROVIDER_TRAJECTORY_GROUND_TRUTH_DYNAMIC)

            data_sources.append(aup.PROVIDER_LANDMARKS_MENTIONED)

            templates = get_current_parameters()["Environment"]["templates"]
            if templates:
                data_sources.append(aup.PROVIDER_LANG_TEMPLATE)

        return SegmentDataset(data=data,
                              env_list=envs,
                              domain=domain,
                              dataset_names=dataset_names,
                              dataset_prefix=dataset_prefix,
                              aux_provider_names=data_sources,
                              segment_level=True)
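    # A minimal usage sketch (hedged: "train_envs", the "sim" domain string, and the batching helper
    # below are illustrative assumptions, not taken from this file):
    #
    #   dataset = model.get_dataset(envs=train_envs, domain="sim",
    #                               dataset_names=["simulator"],
    #                               dataset_prefix="supervised", eval=False)
    #   for batch in make_batches(dataset):  # hypothetical loader yielding batches in unbatch()'s format
    #       loss, tensor_store = model.sup_loss_on_batch(batch, eval=False)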