Example #1
class ModelGSFPV(nn.Module):
    def __init__(self,
                 run_name="",
                 aux_class_features=False,
                 aux_grounding_features=False,
                 aux_lang=False,
                 recurrence=False):

        super(ModelGSFPV, self).__init__()
        self.model_name = "gs_fpv" + ("_mem" if recurrence else "")
        self.run_name = run_name
        self.writer = LoggingSummaryWriter(log_dir="runs/" + run_name)

        self.params = get_current_parameters()["Model"]
        self.aux_weights = get_current_parameters()["AuxWeights"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
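        # The training iteration counter is stored as a non-trainable Parameter so that
        # it is saved and restored together with the rest of the model state.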
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        # Auxiliary Objectives
        self.use_aux_class_features = aux_class_features
        self.use_aux_grounding_features = aux_grounding_features
        self.use_aux_lang = aux_lang
        self.use_recurrence = recurrence

        self.img_to_features_w = FPVToFPVMap(self.params["img_w"],
                                             self.params["img_h"],
                                             self.params["resnet_channels"],
                                             self.params["feature_channels"])

        self.lang_filter_gnd = MapLangSemanticFilter(
            self.params["emb_size"], self.params["feature_channels"],
            self.params["relevance_channels"])

        self.lang_filter_goal = MapLangSpatialFilter(
            self.params["emb_size"], self.params["relevance_channels"],
            self.params["goal_channels"])

        self.map_downsample = DownsampleResidual(
            self.params["map_to_act_channels"], 2)

        self.recurrence = RecurrentEmbedding(
            self.params["gs_fpv_feature_map_size"],
            self.params["gs_fpv_recurrence_size"])

        # Sentence Embedding
        self.sentence_embedding = SentenceEmbeddingSimple(
            self.params["word_emb_size"], self.params["emb_size"],
            self.params["emb_layers"])

        in_features_size = self.params[
            "gs_fpv_feature_map_size"] + self.params["emb_size"]
        if self.use_recurrence:
            in_features_size += self.params["gs_fpv_recurrence_size"]

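        # The MLP head maps the concatenated visual / language (/ recurrent memory) features
        # to a 4-D action (x, y, theta, pstop), matching the format documented in get_action().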
        self.features_to_action = DenseMlpBlock2(in_features_size,
                                                 self.params["mlp_hidden"], 4)

        # Auxiliary Objectives
        # --------------------------------------------------------------------------------------------------------------

        self.add_auxiliary(
            ClassAuxiliary2D("aux_class", None,
                             self.params["feature_channels"],
                             self.params["num_landmarks"], "fpv_features",
                             "lm_pos_fpv", "lm_indices"))
        self.add_auxiliary(
            ClassAuxiliary2D("aux_ground", None,
                             self.params["relevance_channels"], 2,
                             "fpv_features_g", "lm_pos_fpv", "lm_mentioned"))
        if self.params["templates"]:
            self.add_auxiliary(
                ClassAuxiliary("aux_lang_lm", self.params["emb_size"],
                               self.params["num_landmarks"], 1,
                               "sentence_embed", "lm_mentioned_tplt"))
            self.add_auxiliary(
                ClassAuxiliary("aux_lang_side", self.params["emb_size"],
                               self.params["num_sides"], 1, "sentence_embed",
                               "side_mentioned_tplt"))
        else:
            self.add_auxiliary(
                ClassAuxiliary("aux_lang_lm_nl", self.params["emb_size"], 2,
                               self.params["num_landmarks"], "sentence_embed",
                               "lang_lm_mentioned"))

        self.action_loss = ActionLoss()

        self.env_id = None
        self.prev_instruction = None
        self.seq_step = 0

    # TODO: Try to hide these in a superclass or something. They take up a lot of space:
    def cuda(self, device=None):
        ModuleWithAuxiliaries.cuda(self, device)
        self.sentence_embedding.cuda(device)
        self.img_to_features_w.cuda(device)
        self.lang_filter_gnd.cuda(device)
        self.lang_filter_goal.cuda(device)
        self.action_loss.cuda(device)
        self.recurrence.cuda(device)
        return self

    def get_iter(self):
        return int(self.iter.data[0])

    def inc_iter(self):
        self.iter += 1

    def init_weights(self):
        self.img_to_features_w.init_weights()
        self.lang_filter_gnd.init_weights()
        self.lang_filter_goal.init_weights()
        self.sentence_embedding.init_weights()

    def reset(self):
        # TODO: This is error prone. Create a class StatefulModule, iterate submodules and reset all stateful modules
        super(ModelGSFPV, self).reset()
        self.sentence_embedding.reset()
        self.img_to_features_w.reset()
        self.recurrence.reset()
        self.prev_instruction = None
        print("GS_FPV_MEM_RESET")

    def setEnvContext(self, context):
        print("Set env context to: " + str(context))
        self.env_id = context["env_id"]

    def start_segment_rollout(self, *args):
        self.reset()

    def get_action(self, state, instruction):
        """
        Given a DroneState (from PomdpInterface) and instruction, produce a numpy 4D action (x, y, theta, pstop)
        :param state: DroneState object with the raw image from the simulator
        :param instruction: Instruction tokenized with the corpus vocabulary
        #TODO: Absorb corpus within model
        :return: numpy array of shape (4,): (x, y, theta, pstop)
        """
        # TODO: Simplify this
        self.eval()
        images_np_pure = state.image
        state_np = state.state

        #print("Act: " + debug_untokenize_instruction(instruction))

        images_np = standardize_image(images_np_pure)
        image_fpv = Variable(none_padded_seq_to_tensor([images_np]))
        state = Variable(none_padded_seq_to_tensor([state_np]))
        self.prev_instruction = instruction

        img_in_t = image_fpv
        img_in_t.volatile = True

        instr_len = [len(instruction)] if instruction is not None else None
        instruction = torch.LongTensor(instruction).unsqueeze(0)
        instruction = cuda_var(instruction, self.is_cuda, self.cuda_device)

        state.volatile = True

        if self.is_cuda:
            img_in_t = img_in_t.cuda(self.cuda_device)
            state = state.cuda(self.cuda_device)

        self.seq_step += 1

        action = self(img_in_t, state, instruction, instr_len)

        output_action = action.squeeze().data.cpu().numpy()
        print("action: ", output_action)

        stop_prob = output_action[3]
        output_stop = 1 if stop_prob > self.params["stop_threshold"] else 0
        output_action[3] = output_stop

        return output_action

    def deterministic_action(self, action_mean, action_std, stop_prob):
        batch_size = action_mean.size(0)
        action = Variable(
            empty_float_tensor((batch_size, 4), self.is_cuda,
                               self.cuda_device))
        action[:, 0:3] = action_mean[:, 0:3]
        action[:, 3] = stop_prob
        return action

    def sample_action(self, action_mean, action_std, stop_prob):
        action = torch.normal(action_mean, action_std)
        stop = torch.bernoulli(stop_prob)
        return action, stop

    # This is called before beginning an execution sequence
    def start_sequence(self):
        self.seq_step = 0
        self.reset()
        print("RESETTED!")
        return

    # TODO: Move this somewhere and standardize
    def cam_poses_from_states(self, states):
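        # The flat drone state vector is assumed to pack the camera position at
        # indices 9:12 and its rotation quaternion at indices 12:16.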
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]
        pose = Pose(cam_pos, cam_rot)
        return pose

    def forward(self, images, states, instructions, instr_lengths):
        """
        :param images: BxCxHxW batch of images (observations)
        :param states: BxK batch of drone states
        :param instructions: BxM LongTensor where M is the maximum length of any instruction
        :param instr_lengths: list of len B of integers, indicating length of each instruction
        :return: Bx4 batch of predicted actions (x, y, theta, pstop)
        """
        cam_poses = self.cam_poses_from_states(states)
        self.prof.tick("out")

        #print("Trn: " + debug_untokenize_instruction(instructions[0].data[:instr_lengths[0]]))

        # Calculate the instruction embedding
        if instructions is not None:
            # TODO: Take batch of instructions and their lengths, return batch of embeddings. Store the last one as internal state
            sent_embeddings = self.sentence_embedding(instructions,
                                                      instr_lengths)
            self.keep_inputs("sentence_embed", sent_embeddings)
        else:
            sent_embeddings = self.sentence_embedding.get()

        self.prof.tick("embed")

        seq_size = len(images)

        # Extract and project features onto the egocentric frame for each image
        fpv_features = self.img_to_features_w(images,
                                              cam_poses,
                                              sent_embeddings,
                                              self,
                                              show="")

        self.keep_inputs("fpv_features", fpv_features)
        self.prof.tick("img_to_map_frame")

        self.lang_filter_gnd.precompute_conv_weights(sent_embeddings)
        self.lang_filter_goal.precompute_conv_weights(sent_embeddings)

        gnd_features = self.lang_filter_gnd(fpv_features)
        goal_features = self.lang_filter_goal(gnd_features)

        self.keep_inputs("fpv_features_g", gnd_features)
        visual_features = torch.cat([gnd_features, goal_features], dim=1)

        lstm_in_features = visual_features.view([seq_size, 1, -1])

        catlist = [lstm_in_features.view([seq_size, -1]), sent_embeddings]

        if self.use_recurrence:
            memory_features = self.recurrence(lstm_in_features)
            catlist.append(memory_features[:, 0, :])

        action_features = torch.cat(catlist, dim=1)

        # Output the final action given the processed map
        action_pred = self.features_to_action(action_features)
        action_pred[:, 3] = torch.sigmoid(action_pred[:, 3])
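        # Only the stop logit is squashed into a probability; (x, y, theta) are left as raw
        # regression outputs, and get_action() later thresholds the stop probability to 0/1.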
        out_action = self.deterministic_action(action_pred[:, 0:3], None,
                                               action_pred[:, 3])
        self.prof.tick("map_to_action")

        return out_action

    def maybe_cuda(self, tensor):
        if self.is_cuda:
            return tensor.cuda()
        else:
            return tensor

    def cuda_var(self, tensor):
        return cuda_var(tensor, self.is_cuda, self.cuda_device)

    # Forward pass for training (with batch optimizations)
    def sup_loss_on_batch(self, batch, eval):
        self.prof.tick("out")

        action_loss_total = Variable(
            empty_float_tensor([1], self.is_cuda, self.cuda_device))

        if batch is None:
            print("Skipping None Batch")
            return action_loss_total

        images = self.maybe_cuda(batch["images"])

        instructions = self.maybe_cuda(batch["instr"])
        instr_lengths = batch["instr_len"]
        states = self.maybe_cuda(batch["states"])
        actions = self.maybe_cuda(batch["actions"])

        # Auxiliary labels
        lm_pos_fpv = batch["lm_pos_fpv"]
        lm_indices = batch["lm_indices"]
        lm_mentioned = batch["lm_mentioned"]
        lang_lm_mentioned = batch["lang_lm_mentioned"]

        templates = get_current_parameters()["Environment"]["Templates"]
        if templates:
            lm_mentioned_tplt = batch["lm_mentioned_tplt"]
            side_mentioned_tplt = batch["side_mentioned_tplt"]

        # stops = self.maybe_cuda(batch["stops"])
        masks = self.maybe_cuda(batch["masks"])
        metadata = batch["md"]

        seq_len = images.size(1)
        batch_size = images.size(0)
        count = 0
        correct_goal_count = 0
        goal_count = 0

        # Loop thru batch
        for b in range(batch_size):
            seg_idx = -1

            self.reset()

            self.prof.tick("out")
            b_seq_len = len_until_nones(metadata[b])

            # TODO: Generalize this
            # Slice the data according to the sequence length
            b_metadata = metadata[b][:b_seq_len]
            b_images = images[b][:b_seq_len]
            b_instructions = instructions[b][:b_seq_len]
            b_instr_len = instr_lengths[b][:b_seq_len]
            b_states = states[b][:b_seq_len]
            b_actions = actions[b][:b_seq_len]
            b_lm_pos_fpv = lm_pos_fpv[b][:b_seq_len]
            b_lm_indices = lm_indices[b][:b_seq_len]
            b_lm_mentioned = lm_mentioned[b][:b_seq_len]

            b_lm_pos_fpv = [
                self.cuda_var(
                    (s / RESNET_FACTOR).long()) if s is not None else None
                for s in b_lm_pos_fpv
            ]
            b_lm_indices = [
                self.cuda_var(s) if s is not None else None
                for s in b_lm_indices
            ]
            b_lm_mentioned = [
                self.cuda_var(s) if s is not None else None
                for s in b_lm_mentioned
            ]

            # TODO: Figure out how to keep these properly. Perhaps as a whole batch is best
            # TODO: Introduce a key-value store (encapsulate instead of inherit)
            self.keep_inputs("lm_pos_fpv", b_lm_pos_fpv)
            self.keep_inputs("lm_indices", b_lm_indices)
            self.keep_inputs("lm_mentioned", b_lm_mentioned)

            # TODO: Abstract all of these if-elses in a modular way once we know which ones are necessary
            if templates:
                b_lm_mentioned_tplt = lm_mentioned_tplt[b][:b_seq_len]
                b_side_mentioned_tplt = side_mentioned_tplt[b][:b_seq_len]
                b_side_mentioned_tplt = self.cuda_var(b_side_mentioned_tplt)
                b_lm_mentioned_tplt = self.cuda_var(b_lm_mentioned_tplt)
                self.keep_inputs("lm_mentioned_tplt", b_lm_mentioned_tplt)
                self.keep_inputs("side_mentioned_tplt", b_side_mentioned_tplt)
            else:
                b_lang_lm_mentioned = self.cuda_var(
                    lang_lm_mentioned[b][:b_seq_len])
                self.keep_inputs("lang_lm_mentioned", b_lang_lm_mentioned)

            # ----------------------------------------------------------------------------

            self.prof.tick("inputs")

            actions = self(b_images, b_states, b_instructions, b_instr_len)

            action_losses, _ = self.action_loss(b_actions,
                                                actions,
                                                batchreduce=False)

            self.prof.tick("call")
            action_losses = self.action_loss.batch_reduce_loss(action_losses)
            action_loss = self.action_loss.reduce_loss(action_losses)
            action_loss_total = action_loss
            count += b_seq_len

            self.prof.tick("loss")

        action_loss_avg = action_loss_total / (count + 1e-9)

        self.prof.tick("out")

        # Auxiliary losses are computed at the end, outside of the loop over the batch
        aux_losses = self.calculate_aux_loss(reduce_average=True)
        aux_loss = self.combine_aux_losses(aux_losses, self.aux_weights)

        prefix = self.model_name + ("/eval" if eval else "/train")

        self.writer.add_dict(prefix, get_current_meters(), self.get_iter())
        self.writer.add_dict(prefix, aux_losses, self.get_iter())
        self.writer.add_scalar(prefix + "/action_loss",
                               action_loss_avg.data.cpu()[0], self.get_iter())

        self.prof.tick("auxiliaries")

        total_loss = action_loss_avg + aux_loss

        self.inc_iter()

        self.prof.tick("summaries")
        self.prof.loop()
        self.prof.print_stats(1)

        return total_loss

    def get_dataset(self,
                    data=None,
                    envs=None,
                    dataset_names=None,
                    dataset_prefix=None,
                    eval=False):
        # TODO: Maybe use eval here
        #if self.fpv:
        data_sources = []
        data_sources.append(aup.PROVIDER_LM_POS_DATA)
        data_sources.append(aup.PROVIDER_LANDMARKS_MENTIONED)

        templates = get_current_parameters()["Environment"]["Templates"]
        if templates:
            data_sources.append(aup.PROVIDER_LANG_TEMPLATE)

        return SegmentDataset(data=data,
                              env_list=envs,
                              dataset_names=dataset_names,
                              dataset_prefix=dataset_prefix,
                              aux_provider_names=data_sources,
                              segment_level=True)
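
For context, a minimal supervised training loop wired around the class above might look like the following sketch. Only the calls on the model itself (cuda, init_weights, get_dataset, sup_loss_on_batch, get_iter) come from this example; the optimizer choice, the dataset arguments and make_loader are hypothetical placeholders standing in for whatever batching and collation the surrounding codebase provides for SegmentDataset.

import torch.optim as optim

# Hypothetical training sketch for ModelGSFPV; argument values and make_loader are placeholders.
model = ModelGSFPV(run_name="demo", recurrence=True).cuda()
model.init_weights()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

train_set = model.get_dataset(dataset_names=["train"], dataset_prefix="supervised")
for epoch in range(10):
    for batch in make_loader(train_set):          # placeholder batching/collation
        optimizer.zero_grad()
        loss = model.sup_loss_on_batch(batch, eval=False)
        loss.backward()
        optimizer.step()
    print("epoch", epoch, "iter", model.get_iter())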
Example #2
class ModelTrajectoryToAction(ModuleWithAuxiliaries):
    def __init__(self, run_name=""):

        super(ModelTrajectoryToAction, self).__init__()
        self.model_name = "lsvd_action"
        self.run_name = run_name
        self.writer = LoggingSummaryWriter(log_dir="runs/" + run_name)

        self.params = get_current_parameters()["ModelPVN"]
        self.aux_weights = get_current_parameters()["AuxWeights"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        # Common
        # --------------------------------------------------------------------------------------------------------------
        self.map_transform_w_to_s = MapTransformerBase(
            source_map_size=self.params["global_map_size"],
            dest_map_size=self.params["local_map_size"],
            world_size=self.params["world_size_px"])

        self.map_transform_r_to_w = MapTransformerBase(
            source_map_size=self.params["local_map_size"],
            dest_map_size=self.params["global_map_size"],
            world_size=self.params["world_size_px"])

        # Output an action given the global semantic map
        if self.params["map_to_action"] == "downsample2":
            self.map_to_action = EgoMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                other_features_size=self.params["emb_size"])

        elif self.params["map_to_action"] == "cropped":
            self.map_to_action = CroppedMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                manual=self.params["manual_rule"],
                path_only=self.params["action_in_path_only"],
                recurrence=self.params["action_recurrence"])

        self.spatialsoftmax = SpatialSoftmax2d()
        self.gt_fill_missing = MapBatchFillMissing(
            self.params["local_map_size"], self.params["world_size_px"])

        # Don't freeze the trajectory-to-action weights, because they will be pre-trained during path-prediction training
        # and fine-tuned on all timesteps end-to-end
        enable_weight_saving(self.map_to_action,
                             "map_to_action",
                             alwaysfreeze=False,
                             neverfreeze=True)

        self.action_loss = ActionLoss()

        self.env_id = None
        self.seg_idx = None
        self.prev_instruction = None
        self.seq_step = 0
        self.get_act_start_pose = None
        self.gt_labels = None

    # TODO: Try to hide these in a superclass or something. They take up a lot of space:
    def cuda(self, device=None):
        ModuleWithAuxiliaries.cuda(self, device)
        self.map_to_action.cuda(device)
        self.action_loss.cuda(device)
        self.map_transform_w_to_s.cuda(device)
        self.map_transform_r_to_w.cuda(device)
        self.gt_fill_missing.cuda(device)
        return self

    def get_iter(self):
        return int(self.iter.data[0])

    def inc_iter(self):
        self.iter += 1

    def init_weights(self):
        self.map_to_action.init_weights()

    def reset(self):
        # TODO: This is error prone. Create a class StatefulModule, iterate submodules and reset all stateful modules
        super(ModelTrajectoryToAction, self).reset()
        self.map_transform_w_to_s.reset()
        self.map_transform_r_to_w.reset()
        self.gt_fill_missing.reset()

    def setEnvContext(self, context):
        print("Set env context to: " + str(context))
        self.env_id = context["env_id"]

    def start_segment_rollout(self):
        import rollout.run_metadata as md
        m_size = self.params["local_map_size"]
        w_size = self.params["world_size_px"]
        self.gt_labels = get_top_down_ground_truth_static_global(
            md.ENV_ID, md.START_IDX, md.END_IDX, m_size, m_size, w_size,
            w_size)
        self.seg_idx = md.SEG_IDX
        self.gt_labels = self.maybe_cuda(self.gt_labels)
        if self.params["clear_history"]:
            self.start_sequence()

    def get_action(self, state, instruction):
        """
        Given a DroneState (from PomdpInterface) and instruction, produce a numpy 4D action (x, y, theta, pstop)
        :param state: DroneState object with the raw image from the simulator
        :param instruction: Instruction tokenized with the corpus vocabulary
        #TODO: Absorb corpus within model
        :return: numpy array of shape (4,): (x, y, theta, pstop)
        """
        prof = SimpleProfiler(print=True)
        prof.tick(".")
        # TODO: Simplify this
        self.eval()
        images_np_pure = state.image
        state_np = state.state
        state = Variable(none_padded_seq_to_tensor([state_np]))

        #print("Act: " + debug_untokenize_instruction(instruction))

        # Add the batch dimension

        first_step = True
        if instruction == self.prev_instruction:
            first_step = False
        self.prev_instruction = instruction
        if first_step:
            self.get_act_start_pose = self.cam_poses_from_states(state[0:1])

        self.seq_step += 1

        # This is for training the policy to mimic the ground-truth state distribution with oracle actions
        # b_traj_gt_w_select = b_traj_ground_truth[b_plan_mask_t[:, np.newaxis, np.newaxis, np.newaxis].expand_as(b_traj_ground_truth)].view([-1] + gtsize)
        traj_gt_w = Variable(self.gt_labels)
        b_poses = self.cam_poses_from_states(state)
        # TODO: These source and dest should go as arguments to get_maps (in forward pass not params)
        transformer = MapTransformerBase(
            source_map_size=self.params["global_map_size"],
            world_size=self.params["world_size_px"],
            dest_map_size=self.params["local_map_size"])
        self.maybe_cuda(transformer)
        transformer.set_maps(traj_gt_w, None)
        traj_gt_r, _ = transformer.get_maps(b_poses)
        self.clear_inputs("traj_gt_r_select")
        self.clear_inputs("traj_gt_w_select")
        self.keep_inputs("traj_gt_r_select", traj_gt_r)
        self.keep_inputs("traj_gt_w_select", traj_gt_w)

        action = self(traj_gt_r, firstseg=[self.seq_step == 1])

        output_action = action.squeeze().data.cpu().numpy()

        stop_prob = output_action[3]
        output_stop = 1 if stop_prob > self.params["stop_threshold"] else 0
        output_action[3] = output_stop

        return output_action

    def deterministic_action(self, action_mean, action_std, stop_prob):
        batch_size = action_mean.size(0)
        action = Variable(
            empty_float_tensor((batch_size, 4), self.is_cuda,
                               self.cuda_device))
        action[:, 0:3] = action_mean[:, 0:3]
        action[:, 3] = stop_prob
        return action

    def sample_action(self, action_mean, action_std, stop_prob):
        action = torch.normal(action_mean, action_std)
        stop = torch.bernoulli(stop_prob)
        return action, stop

    # This is called before beginning an execution sequence
    def start_sequence(self):
        self.seq_step = 0
        self.reset()
        print("RESETTED!")
        return

    def cam_poses_from_states(self, states):
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]
        pose = Pose(cam_pos, cam_rot)
        return pose

    def save(self, epoch):
        filename = self.params[
            "map_to_action_file"] + "_" + self.run_name + "_" + str(epoch)
        save_pytorch_model(self.map_to_action, filename)
        print("Saved action model to " + filename)

    def forward(self, traj_gt_r, firstseg=None):
        """
        :param traj_gt_r: batch of ground-truth trajectory maps, projected into the robot's egocentric reference frame
        :param firstseg: mask indicating which elements of the batch correspond to the first step of a segment
        :return: Bx4 batch of predicted actions (x, y, theta, pstop)
        """
        action_pred = self.map_to_action(traj_gt_r,
                                         None,
                                         fistseg_mask=firstseg)
        out_action = self.deterministic_action(action_pred[:, 0:3], None,
                                               action_pred[:, 3])
        self.keep_inputs("action", out_action)
        self.prof.tick("map_to_action")

        return out_action

    def maybe_cuda(self, tensor):
        if self.is_cuda:
            if False:
                if type(tensor) is Variable:
                    tensor.data.pin_memory()
                elif type(tensor) is Pose:
                    pass
                elif type(tensor) is torch.FloatTensor:
                    tensor.pin_memory()
            return tensor.cuda()
        else:
            return tensor

    def cuda_var(self, tensor):
        return cuda_var(tensor, self.is_cuda, self.cuda_device)

    # Forward pass for training (with batch optimizations)
    def sup_loss_on_batch(self, batch, eval):
        self.prof.tick("out")

        action_loss_total = Variable(
            empty_float_tensor([1], self.is_cuda, self.cuda_device))

        if batch is None:
            print("Skipping None Batch")
            return action_loss_total

        actions = self.maybe_cuda(batch["actions"])
        states = self.maybe_cuda(batch["states"])

        firstseg_mask = batch["firstseg_mask"]

        # Auxiliary labels
        traj_ground_truth_select = self.maybe_cuda(batch["traj_ground_truth"])
        # stops = self.maybe_cuda(batch["stops"])
        metadata = batch["md"]
        batch_size = actions.size(0)
        count = 0

        # Loop thru batch
        for b in range(batch_size):
            seg_idx = -1

            self.reset()

            self.prof.tick("out")
            b_seq_len = len_until_nones(metadata[b])

            # TODO: Generalize this
            # Slice the data according to the sequence length
            b_metadata = metadata[b][:b_seq_len]
            b_actions = actions[b][:b_seq_len]
            b_traj_ground_truth_select = traj_ground_truth_select[b]
            b_states = states[b][:b_seq_len]

            self.keep_inputs("traj_gt_global_select",
                             b_traj_ground_truth_select)

            #b_firstseg = get_obs_mask_segstart(b_metadata)
            b_firstseg = firstseg_mask[b][:b_seq_len]

            # ----------------------------------------------------------------------------
            # Optional Auxiliary Inputs
            # ----------------------------------------------------------------------------
            gtsize = list(b_traj_ground_truth_select.size())[1:]
            b_poses = self.cam_poses_from_states(b_states)
            # TODO: These source and dest should go as arguments to get_maps (in forward pass not params)
            transformer = MapTransformerBase(
                source_map_size=self.params["global_map_size"],
                world_size=self.params["world_size_px"],
                dest_map_size=self.params["local_map_size"])
            self.maybe_cuda(transformer)
            transformer.set_maps(b_traj_ground_truth_select, None)
            traj_gt_local_select, _ = transformer.get_maps(b_poses)
            self.keep_inputs("traj_gt_r_select", traj_gt_local_select)
            self.keep_inputs("traj_gt_w_select", b_traj_ground_truth_select)

            # ----------------------------------------------------------------------------

            self.prof.tick("inputs")

            actions = self(traj_gt_local_select, firstseg=b_firstseg)
            action_losses, _ = self.action_loss(b_actions,
                                                actions,
                                                batchreduce=False)
            action_losses = self.action_loss.batch_reduce_loss(action_losses)
            action_loss = self.action_loss.reduce_loss(action_losses)
            action_loss_total = action_loss
            count += b_seq_len

            self.prof.tick("loss")

        action_loss_avg = action_loss_total / (count + 1e-9)
        prefix = self.model_name + ("/eval" if eval else "/train")
        self.writer.add_scalar(prefix + "/action_loss",
                               action_loss_avg.data.cpu()[0], self.get_iter())

        self.prof.tick("out")

        self.writer.add_dict(prefix, get_current_meters(), self.get_iter())

        self.inc_iter()

        self.prof.tick("summaries")
        self.prof.loop()
        self.prof.print_stats(1)

        return action_loss_avg

    def get_dataset(self, data=None, envs=None, dataset_name=None, eval=False):
        # TODO: Maybe use eval here
        data_sources = []
        data_sources.append(aup.PROVIDER_TRAJECTORY_GROUND_TRUTH_STATIC)
        return SegmentDataset(data=data,
                              env_list=envs,
                              dataset_name=dataset_name,
                              aux_provider_names=data_sources,
                              segment_level=True)
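
Both get_action and sup_loss_on_batch above repeat the same pattern of re-projecting a world-frame ground-truth trajectory map into the robot's egocentric frame. A small helper capturing that pattern could look like the sketch below; it only rearranges calls already present in the code above (MapTransformerBase, maybe_cuda, the params keys) and is not part of the original class.

def traj_world_to_robot(model, traj_gt_w, poses):
    # Sketch of the world-to-robot-frame transform used in get_action() and
    # sup_loss_on_batch(): project a world-frame trajectory map into the
    # egocentric frame at the given camera poses.
    transformer = MapTransformerBase(
        source_map_size=model.params["global_map_size"],
        world_size=model.params["world_size_px"],
        dest_map_size=model.params["local_map_size"])
    model.maybe_cuda(transformer)
    transformer.set_maps(traj_gt_w, None)
    traj_gt_r, _ = transformer.get_maps(poses)
    return traj_gt_r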
Example #3
class ModelTrajectoryTopDown(ModuleWithAuxiliaries):

    def __init__(self, run_name="", model_class=MODEL_RSS,
                 aux_class_features=False, aux_grounding_features=False,
                 aux_class_map=False, aux_grounding_map=False, aux_goal_map=False,
                 aux_lang=False, aux_traj=False, rot_noise=False, pos_noise=False):

        super(ModelTrajectoryTopDown, self).__init__()
        self.model_name = "sm_trajectory" + str(model_class)
        self.model_class = model_class
        print("Init model of type: ", str(model_class))
        self.run_name = run_name
        self.writer = LoggingSummaryWriter(log_dir="runs/" + run_name)

        self.params = get_current_parameters()["Model"]
        self.aux_weights = get_current_parameters()["AuxWeights"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        # Auxiliary Objectives
        self.use_aux_class_features = aux_class_features
        self.use_aux_grounding_features = aux_grounding_features
        self.use_aux_class_on_map = aux_class_map
        self.use_aux_grounding_on_map = aux_grounding_map
        self.use_aux_goal_on_map = aux_goal_map
        self.use_aux_lang = aux_lang
        self.use_aux_traj_on_map = aux_traj
        self.use_aux_reg_map = self.aux_weights["regularize_map"]

        self.use_rot_noise = rot_noise
        self.use_pos_noise = pos_noise


        # Path-pred FPV model definition
        # --------------------------------------------------------------------------------------------------------------

        self.img_to_features_w = FPVToGlobalMap(
            source_map_size=self.params["global_map_size"], world_size_px=self.params["world_size_px"], world_size=self.params["world_size_m"],
            res_channels=self.params["resnet_channels"], map_channels=self.params["feature_channels"],
            img_w=self.params["img_w"], img_h=self.params["img_h"], img_dbg=IMG_DBG)

        self.map_accumulator_w = LeakyIntegratorGlobalMap(source_map_size=self.params["global_map_size"], world_in_map_size=self.params["world_size_px"])

        # Pre-process the accumulated map to do language grounding if necessary - in the world reference frame
        if self.use_aux_grounding_on_map and not self.use_aux_grounding_features:
            self.map_processor_a_w = LangFilterMapProcessor(
                source_map_size=self.params["global_map_size"],
                world_size=self.params["world_size_px"],
                embed_size=self.params["emb_size"],
                in_channels=self.params["feature_channels"],
                out_channels=self.params["relevance_channels"],
                spatial=False, cat_out=True)
        else:
            self.map_processor_a_w = IdentityMapProcessor(source_map_size=self.params["global_map_size"], world_size=self.params["world_size_px"])

        if self.use_aux_goal_on_map:
            self.map_processor_b_r = LangFilterMapProcessor(source_map_size=self.params["local_map_size"],
                                                            world_size=self.params["world_size_px"],
                                                            embed_size=self.params["emb_size"],
                                                            in_channels=self.params["relevance_channels"],
                                                            out_channels=self.params["goal_channels"],
                                                            spatial=True, cat_out=True)
        else:
            self.map_processor_b_r = IdentityMapProcessor(source_map_size=self.params["local_map_size"],
                                                          world_size=self.params["world_size_px"])

        pred_channels = self.params["goal_channels"] + self.params["relevance_channels"]

        # Common
        # --------------------------------------------------------------------------------------------------------------

        # Sentence Embedding
        self.sentence_embedding = SentenceEmbeddingSimple(
            self.params["word_emb_size"], self.params["emb_size"], self.params["emb_layers"])

        self.map_transform_w_to_r = MapTransformerBase(source_map_size=self.params["global_map_size"],
                                                       dest_map_size=self.params["local_map_size"],
                                                       world_size=self.params["world_size_px"])
        self.map_transform_r_to_w = MapTransformerBase(source_map_size=self.params["local_map_size"],
                                                       dest_map_size=self.params["global_map_size"],
                                                       world_size=self.params["world_size_px"])

        # Batch select is used to drop and forget semantic maps at those timesteps that we're not planning in
        self.batch_select = MapBatchSelect()
        # Since we only have path predictions for some timesteps (the ones not dropped above), we use this to fill
        # in the missing pieces by reorienting the past trajectory prediction into the frame of the current timestep
        self.map_batch_fill_missing = MapBatchFillMissing(self.params["local_map_size"], self.params["world_size_px"])

        # Passing true to freeze will freeze these weights regardless of whether they've been explicitly reloaded or not
        enable_weight_saving(self.sentence_embedding, "sentence_embedding", alwaysfreeze=False)

        # Output an action given the global semantic map
        if self.params["map_to_action"] == "downsample2":
            self.map_to_action = EgoMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                other_features_size=self.params["emb_size"])

        elif self.params["map_to_action"] == "cropped":
            self.map_to_action = CroppedMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                other_features_size=self.params["emb_size"]
            )

        # Don't freeze the trajectory-to-action weights, because they will be pre-trained during path-prediction training
        # and fine-tuned on all timesteps end-to-end
        enable_weight_saving(self.map_to_action, "map_to_action", alwaysfreeze=False, neverfreeze=True)

        # Auxiliary Objectives
        # --------------------------------------------------------------------------------------------------------------

        # We add all auxiliaries that are necessary. The first argument is the auxiliary name, followed by parameters,
        # followed by variable number of names of inputs. ModuleWithAuxiliaries will automatically collect these inputs
        # that have been saved with keep_auxiliary_input() during execution
        if aux_class_features:
            self.add_auxiliary(ClassAuxiliary2D("aux_class", None,  self.params["feature_channels"], self.params["num_landmarks"], self.params["dropout"],
                                                "fpv_features", "lm_pos_fpv", "lm_indices"))
        if aux_grounding_features:
            self.add_auxiliary(ClassAuxiliary2D("aux_ground", None, self.params["relevance_channels"], 2, self.params["dropout"],
                                                "fpv_features_g", "lm_pos_fpv", "lm_mentioned"))
        if aux_class_map:
            self.add_auxiliary(ClassAuxiliary2D("aux_class_map", self.params["world_size_px"], self.params["feature_channels"], self.params["num_landmarks"], self.params["dropout"],
                                                "map_s_w_select", "lm_pos_map_select", "lm_indices_select"))
        if aux_grounding_map:
            self.add_auxiliary(ClassAuxiliary2D("aux_grounding_map", self.params["world_size_px"], self.params["relevance_channels"], 2, self.params["dropout"],
                                                "map_a_w_select", "lm_pos_map_select", "lm_mentioned_select"))
        if aux_goal_map:
            self.add_auxiliary(GoalAuxiliary2D("aux_goal_map", self.params["goal_channels"], self.params["world_size_px"],
                                               "map_b_w", "goal_pos_map"))
        # RSS model uses templated data for landmark and side prediction
        if self.use_aux_lang and self.params["templates"]:
            self.add_auxiliary(ClassAuxiliary("aux_lang_lm", self.params["emb_size"], self.params["num_landmarks"], 1,
                                                "sentence_embed", "lm_mentioned_tplt"))
            self.add_auxiliary(ClassAuxiliary("aux_lang_side", self.params["emb_size"], self.params["num_sides"], 1,
                                                "sentence_embed", "side_mentioned_tplt"))
        # CoRL model uses alignment-model groundings
        elif self.use_aux_lang:
            # one output for each landmark, 2 classes per output. This is for fine-tuning, so use the embedding that's going to be fine-tuned
            self.add_auxiliary(ClassAuxiliary("aux_lang_lm_nl", self.params["emb_size"], 2, self.params["num_landmarks"],
                                                "sentence_embed", "lang_lm_mentioned"))
        if self.use_aux_traj_on_map:
            self.add_auxiliary(PathAuxiliary2D("aux_path", "map_b_r_select", "traj_gt_r_select"))

        if self.use_aux_reg_map:
            self.add_auxiliary(FeatureRegularizationAuxiliary2D("aux_regularize_features", None, "l1",
                                                                "map_s_w_select", "lm_pos_map_select"))

        self.goal_good_criterion = GoalPredictionGoodCriterion(ok_distance=3.2)
        self.goal_acc_meter = MovingAverageMeter(10)

        self.print_auxiliary_info()

        self.action_loss = ActionLoss()

        self.env_id = None
        self.prev_instruction = None
        self.seq_step = 0

    # TODO: Try to hide these in a superclass or something. They take up a lot of space:
    def cuda(self, device=None):
        ModuleWithAuxiliaries.cuda(self, device)
        self.sentence_embedding.cuda(device)
        self.map_accumulator_w.cuda(device)
        self.map_processor_a_w.cuda(device)
        self.map_processor_b_r.cuda(device)
        self.img_to_features_w.cuda(device)
        self.map_to_action.cuda(device)
        self.action_loss.cuda(device)
        self.map_batch_fill_missing.cuda(device)
        self.map_transform_w_to_r.cuda(device)
        self.map_transform_r_to_w.cuda(device)
        self.batch_select.cuda(device)
        self.map_batch_fill_missing.cuda(device)
        return self

    def get_iter(self):
        return int(self.iter.data[0])

    def inc_iter(self):
        self.iter += 1

    def init_weights(self):
        self.img_to_features_w.init_weights()
        self.map_accumulator_w.init_weights()
        self.sentence_embedding.init_weights()
        self.map_to_action.init_weights()
        self.map_processor_a_w.init_weights()
        self.map_processor_b_r.init_weights()

    def reset(self):
        # TODO: This is error prone. Create a class StatefulModule, iterate submodules and reset all stateful modules
        super(ModelTrajectoryTopDown, self).reset()
        self.sentence_embedding.reset()
        self.img_to_features_w.reset()
        self.map_accumulator_w.reset()
        self.map_processor_a_w.reset()
        self.map_processor_b_r.reset()
        self.map_transform_w_to_r.reset()
        self.map_transform_r_to_w.reset()
        self.map_batch_fill_missing.reset()
        self.prev_instruction = None

    def setEnvContext(self, context):
        print("Set env context to: " + str(context))
        self.env_id = context["env_id"]

    def save_viz(self, images_in):
        imsave(get_viz_dir() + "fpv_" + str(self.seq_step) + ".png", images_in)
        features_cam = self.get_inputs_batch("fpv_features")[-1, 0, 0:3]
        save_tensor_as_img(features_cam, "F_c", self.env_id)
        feature_map_torch = self.get_inputs_batch("f_w")[-1, 0, 0:3]
        save_tensor_as_img(feature_map_torch, "F_w", self.env_id)
        coverage_map_torch = self.get_inputs_batch("m_w")[-1, 0, 0:3]
        save_tensor_as_img(coverage_map_torch, "M_w", self.env_id)
        semantic_map_torch = self.get_inputs_batch("map_s_w_select")[-1, 0, 0:3]
        save_tensor_as_img(semantic_map_torch, "S_w", self.env_id)
        relmap_torch = self.get_inputs_batch("map_a_w_select")[-1, 0, 0:3]
        save_tensor_as_img(relmap_torch, "R_w", self.env_id)
        relmap_r_torch = self.get_inputs_batch("map_a_r_select")[-1, 0, 0:3]
        save_tensor_as_img(relmap_r_torch, "R_r", self.env_id)
        goalmap_torch = self.get_inputs_batch("map_b_w_select")[-1, 0, 0:3]
        save_tensor_as_img(goalmap_torch, "G_w", self.env_id)
        goalmap_r_torch = self.get_inputs_batch("map_b_r_select")[-1, 0, 0:3]
        save_tensor_as_img(goalmap_r_torch, "G_r", self.env_id)

        action = self.get_inputs_batch("action")[-1].data.cpu().squeeze().numpy()
        action_fname = self.get_viz_dir() + "action_" + str(self.seq_step) + ".png"
        Presenter().save_action(action, action_fname, "")

    def get_action(self, state, instruction):
        """
        Given a DroneState (from PomdpInterface) and instruction, produce a numpy 4D action (x, y, theta, pstop)
        :param state: DroneState object with the raw image from the simulator
        :param instruction: Instruction tokenized with the corpus vocabulary
        #TODO: Absorb corpus within model
        :return: numpy array of shape (4,): (x, y, theta, pstop)
        """
        # TODO: Simplify this
        self.eval()
        images_np_pure = state.image
        state_np = state.state

        #print("Act: " + debug_untokenize_instruction(instruction))

        images_np = standardize_image(images_np_pure)
        image_fpv = Variable(none_padded_seq_to_tensor([images_np]))
        state = Variable(none_padded_seq_to_tensor([state_np]))
        # Add the batch dimension

        first_step = True
        if instruction == self.prev_instruction:
            first_step = False
        self.prev_instruction = instruction

        img_in_t = image_fpv
        img_in_t.volatile = True

        instr_len = [len(instruction)] if instruction is not None else None
        instruction = torch.LongTensor(instruction).unsqueeze(0)
        instruction = cuda_var(instruction, self.is_cuda, self.cuda_device)

        state.volatile = True

        if self.is_cuda:
            if img_in_t is not None:
                img_in_t = img_in_t.cuda(self.cuda_device)
            state = state.cuda(self.cuda_device)

        step_enc = None
        plan_now = None

        self.seq_step += 1

        action = self(img_in_t, state, instruction, instr_len, plan=plan_now, pos_enc=step_enc)

        # Save materials for paper and presentation
        if False:
            self.save_viz(images_np_pure)

        output_action = action.squeeze().data.cpu().numpy()
        stop_prob = output_action[3]
        output_stop = 1 if stop_prob > 0.5 else 0
        output_action[3] = output_stop

        return output_action

    def deterministic_action(self, action_mean, action_std, stop_prob):
        batch_size = action_mean.size(0)
        action = Variable(empty_float_tensor((batch_size, 4), self.is_cuda, self.cuda_device))
        action[:, 0:3] = action_mean[:, 0:3]
        action[:, 3] = stop_prob
        return action

    def sample_action(self, action_mean, action_std, stop_prob):
        action = torch.normal(action_mean, action_std)
        stop = torch.bernoulli(stop_prob)
        return action, stop

    # This is called before beginning an execution sequence
    def start_sequence(self):
        self.seq_step = 0
        self.reset()
        print("RESETTED!")
        return

    # TODO: Move this somewhere and standardize
    def cam_poses_from_states(self, states):
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]

        pos_variance = 0
        rot_variance = 0
        if self.use_pos_noise:
            pos_variance = self.params["noisy_pos_variance"]
        if self.use_rot_noise:
            rot_variance = self.params["noisy_rot_variance"]

        pose = Pose(cam_pos, cam_rot)
        if self.use_pos_noise or self.use_rot_noise:
            pose = get_noisy_poses_torch(pose, pos_variance, rot_variance, cuda=self.is_cuda, cuda_device=self.cuda_device)
        return pose

    def forward(self, images, states, instructions, instr_lengths, has_obs=None, plan=None, save_maps_only=False, pos_enc=None, noisy_poses=None):
        """
        :param images: BxCxHxW batch of images (observations)
        :param states: BxK batch of drone states
        :param instructions: BxM LongTensor where M is the maximum length of any instruction
        :param instr_lengths: list of len B of integers, indicating length of each instruction
        :param has_obs: list of booleans of length B indicating whether the given element in the sequence has an observation
        :param save_maps_only: If true, will not compute actions (full model), but return the semantic maps that
            were built along the way in response to the images. This is ugly, but allows code reuse
        :return: Bx4 batch of predicted actions (x, y, theta, pstop)
        """
        cam_poses = self.cam_poses_from_states(states)
        g_poses = None  # [None for pose in cam_poses]
        self.prof.tick("out")

        #print("Trn: " + debug_untokenize_instruction(instructions[0].data[:instr_lengths[0]]))

        # Calculate the instruction embedding
        if instructions is not None:
            # TODO: Take batch of instructions and their lengths, return batch of embeddings. Store the last one as internal state
            sent_embeddings = self.sentence_embedding(instructions, instr_lengths)
            self.keep_inputs("sentence_embed", sent_embeddings)
        else:
            sent_embeddings = self.sentence_embedding.get()

        self.prof.tick("embed")

        # Extract and project features onto the egocentric frame for each image
        features_w, coverages_w = self.img_to_features_w(images, cam_poses, sent_embeddings, self, show="")
        self.prof.tick("img_to_map_frame")
        self.keep_inputs("f_w", features_w)
        self.keep_inputs("m_w", coverages_w)

        # Accumulate the egocentric features in a global map
        maps_w = self.map_accumulator_w(features_w, coverages_w, add_mask=has_obs, show="acc" if IMG_DBG else "")
        map_poses_w = g_poses

        # TODO: Maybe keep maps_w if necessary
        #self.keep_inputs("map_sm_local", maps_m)
        self.prof.tick("map_accumulate")

        # Throw away those timesteps that don't correspond to planning timesteps
        maps_w_select, map_poses_w_select, cam_poses_select, noisy_poses_select, _, sent_embeddings_select, pos_enc = \
            self.batch_select(maps_w, map_poses_w, cam_poses, noisy_poses, None, sent_embeddings, pos_enc, plan)

        # Only process the maps on planning timesteps
        if len(maps_w_select) > 0:
            self.keep_inputs("map_s_w_select", maps_w_select)
            self.prof.tick("batch_select")

            # Process the map via the two map_processors
            # Do grounding of objects in the map chosen to do so
            maps_w_select, map_poses_w_select = self.map_processor_a_w(maps_w_select, sent_embeddings_select, map_poses_w_select, show="")
            self.keep_inputs("map_a_w_select", maps_w_select)

            self.prof.tick("map_proc_gnd")

            self.map_transform_w_to_r.set_maps(maps_w_select, map_poses_w_select)
            maps_m_select, map_poses_m_select = self.map_transform_w_to_r.get_maps(cam_poses_select)

            self.keep_inputs("map_a_r_select", maps_w_select)
            self.prof.tick("transform_w_to_r")

            self.keep_inputs("map_a_r_perturbed_select", maps_m_select)

            self.prof.tick("map_perturb")

            # Include positional encoding for path prediction
            if pos_enc is not None:
                sent_embeddings_pp = torch.cat([sent_embeddings_select, pos_enc.unsqueeze(1)], dim=1)
            else:
                sent_embeddings_pp = sent_embeddings_select

            # Process the map via the two map_processors (e.g. predict the trajectory that we'll be taking)
            maps_m_select, map_poses_m_select = self.map_processor_b_r(maps_m_select, sent_embeddings_pp, map_poses_m_select)

            self.keep_inputs("map_b_r_select", maps_m_select)

            if True:
                self.map_transform_r_to_w.set_maps(maps_m_select, map_poses_m_select)
                maps_b_w_select, _ = self.map_transform_r_to_w.get_maps(None)
                self.keep_inputs("map_b_w_select", maps_b_w_select)

            self.prof.tick("map_proc_b")

        else:
            maps_m_select = None

        maps_m, map_poses_m = self.map_batch_fill_missing(maps_m_select, cam_poses, plan, show="")
        self.keep_inputs("map_b_r", maps_m)
        self.prof.tick("map_fill_missing")

        # Keep global maps for auxiliary objectives if necessary
        if self.input_required("map_b_w"):
            maps_b, _ = self.map_processor_b_r.get_maps(g_poses)
            self.keep_inputs("map_b_w", maps_b)

        self.prof.tick("keep_global_maps")

        if run_metadata.IS_ROLLOUT:
            pass
            #Presenter().show_image(maps_m.data[0, 0:3], "plan_map_now", torch=True, scale=4, waitkey=1)
            #Presenter().show_image(maps_w.data[0, 0:3], "sm_map_now", torch=True, scale=4, waitkey=1)
        self.prof.tick("viz")

        # Output the final action given the processed map
        action_pred = self.map_to_action(maps_m, sent_embeddings)
        out_action = self.deterministic_action(action_pred[:, 0:3], None, action_pred[:, 3])

        self.keep_inputs("action", out_action)
        self.prof.tick("map_to_action")

        return out_action

    # TODO: The below two methods seem to do the same thing
    def maybe_cuda(self, tensor):
        if self.is_cuda:
            return tensor.cuda()
        else:
            return tensor

    def cuda_var(self, tensor):
        return cuda_var(tensor, self.is_cuda, self.cuda_device)

    # Forward pass for training (with batch optimizations)
    def sup_loss_on_batch(self, batch, eval):
        self.prof.tick("out")

        action_loss_total = Variable(empty_float_tensor([1], self.is_cuda, self.cuda_device))

        if batch is None:
            print("Skipping None Batch")
            return action_loss_total

        images = self.maybe_cuda(batch["images"])

        instructions = self.maybe_cuda(batch["instr"])
        instr_lengths = batch["instr_len"]
        states = self.maybe_cuda(batch["states"])
        actions = self.maybe_cuda(batch["actions"])

        # Auxiliary labels
        lm_pos_fpv = batch["lm_pos_fpv"]
        lm_pos_map = batch["lm_pos_map"]
        lm_indices = batch["lm_indices"]
        goal_pos_map = batch["goal_loc"]

        TEMPLATES = True
        if TEMPLATES:
            lm_mentioned_tplt = batch["lm_mentioned_tplt"]
            side_mentioned_tplt = batch["side_mentioned_tplt"]
        else:
            lm_mentioned = batch["lm_mentioned"]
            lang_lm_mentioned = batch["lang_lm_mentioned"]

        # stops = self.maybe_cuda(batch["stops"])
        masks = self.maybe_cuda(batch["masks"])
        # This is the first-timestep metadata
        metadata = batch["md"]

        seq_len = images.size(1)
        batch_size = images.size(0)
        count = 0
        correct_goal_count = 0
        goal_count = 0

        # Loop thru batch
        for b in range(batch_size):
            seg_idx = -1

            self.reset()

            self.prof.tick("out")
            b_seq_len = len_until_nones(metadata[b])

            # TODO: Generalize this
            # Slice the data according to the sequence length
            b_metadata = metadata[b][:b_seq_len]
            b_images = images[b][:b_seq_len]
            b_instructions = instructions[b][:b_seq_len]
            b_instr_len = instr_lengths[b][:b_seq_len]
            b_states = states[b][:b_seq_len]
            b_actions = actions[b][:b_seq_len]
            b_lm_pos_fpv = lm_pos_fpv[b][:b_seq_len]
            b_lm_pos_map = lm_pos_map[b][:b_seq_len]
            b_lm_indices = lm_indices[b][:b_seq_len]
            b_goal_pos = goal_pos_map[b][:b_seq_len]
            if not TEMPLATES:
                b_lang_lm_mentioned = lang_lm_mentioned[b][:b_seq_len]
                b_lm_mentioned = lm_mentioned[b][:b_seq_len]

            b_lm_pos_map = [self.cuda_var(s.long()) if s is not None else None for s in b_lm_pos_map]
            b_lm_pos_fpv = [self.cuda_var((s / RESNET_FACTOR).long()) if s is not None else None for s in b_lm_pos_fpv]
            b_lm_indices = [self.cuda_var(s) if s is not None else None for s in b_lm_indices]
            b_goal_pos = self.cuda_var(b_goal_pos)
            if not TEMPLATES:
                b_lang_lm_mentioned = self.cuda_var(b_lang_lm_mentioned)
                b_lm_mentioned = [self.cuda_var(s) if s is not None else None for s in b_lm_mentioned]

            # TODO: Figure out how to keep these properly. Perhaps as a whole batch is best
            # TODO: Introduce a key-value store (encapsulate instead of inherit)
            self.keep_inputs("lm_pos_fpv", b_lm_pos_fpv)
            self.keep_inputs("lm_pos_map", b_lm_pos_map)
            self.keep_inputs("lm_indices", b_lm_indices)
            self.keep_inputs("goal_pos_map", b_goal_pos)
            if not TEMPLATES:
                self.keep_inputs("lang_lm_mentioned", b_lang_lm_mentioned)
                self.keep_inputs("lm_mentioned", b_lm_mentioned)

            # TODO: Abstract all of these if-elses in a modular way once we know which ones are necessary
            if TEMPLATES:
                b_lm_mentioned_tplt = lm_mentioned_tplt[b][:b_seq_len]
                b_side_mentioned_tplt = side_mentioned_tplt[b][:b_seq_len]
                b_side_mentioned_tplt = self.cuda_var(b_side_mentioned_tplt)
                b_lm_mentioned_tplt = self.cuda_var(b_lm_mentioned_tplt)
                self.keep_inputs("lm_mentioned_tplt", b_lm_mentioned_tplt)
                self.keep_inputs("side_mentioned_tplt", b_side_mentioned_tplt)

                b_lm_mentioned = b_lm_mentioned_tplt


            b_obs_mask = [True for _ in range(b_seq_len)]
            b_plan_mask = [True for _ in range(b_seq_len)]
            b_plan_mask_t_cpu = torch.Tensor(b_plan_mask) == True
            b_plan_mask_t = self.maybe_cuda(b_plan_mask_t_cpu)
            b_pos_enc = None

            # ----------------------------------------------------------------------------
            # Optional Auxiliary Inputs
            # ----------------------------------------------------------------------------
            if self.input_required("lm_pos_map_select"):
                b_lm_pos_map_select = [lm_pos for i,lm_pos in enumerate(b_lm_pos_map) if b_plan_mask[i]]
                self.keep_inputs("lm_pos_map_select", b_lm_pos_map_select)
            if self.input_required("lm_indices_select"):
                b_lm_indices_select = [lm_idx for i,lm_idx in enumerate(b_lm_indices) if b_plan_mask[i]]
                self.keep_inputs("lm_indices_select", b_lm_indices_select)
            if self.input_required("lm_mentioned_select"):
                b_lm_mentioned_select = [lm_m for i,lm_m in enumerate(b_lm_mentioned) if b_plan_mask[i]]
                self.keep_inputs("lm_mentioned_select", b_lm_mentioned_select)

            # ----------------------------------------------------------------------------

            self.prof.tick("inputs")

            # Use a distinct name for the predictions so the batch-level "actions" tensor isn't overwritten
            pred_actions = self(b_images, b_states, b_instructions, b_instr_len,
                                has_obs=b_obs_mask, plan=b_plan_mask, pos_enc=b_pos_enc)

            action_losses, _ = self.action_loss(b_actions, pred_actions, batchreduce=False)

            self.prof.tick("call")

            action_losses = self.action_loss.batch_reduce_loss(action_losses)
            action_loss = self.action_loss.reduce_loss(action_losses)

            action_loss_total = action_loss_total + action_loss
            count += b_seq_len

            self.prof.tick("loss")

        action_loss_avg = action_loss_total / (count + 1e-9)

        self.prof.tick("out")

        # Doing this at the end (outside of the sequence loop)
        aux_losses = self.calculate_aux_loss(reduce_average=True)
        aux_loss = self.combine_aux_losses(aux_losses, self.aux_weights)

        prefix = self.model_name + ("/eval" if eval else "/train")

        self.writer.add_dict(prefix, get_current_meters(), self.get_iter())
        self.writer.add_dict(prefix, aux_losses, self.get_iter())
        self.writer.add_scalar(prefix + "/action_loss", action_loss_avg.data.cpu()[0], self.get_iter())
        # TODO: Log value here
        self.writer.add_scalar(prefix + "/goal_accuracy", self.goal_acc_meter.get(), self.get_iter())

        self.prof.tick("auxiliaries")

        total_loss = action_loss_avg + aux_loss

        self.inc_iter()

        self.prof.tick("summaries")
        self.prof.loop()
        self.prof.print_stats(1)

        return total_loss

    def get_dataset(self, data=None, envs=None, dataset_name=None, eval=False):
        # TODO: Maybe use eval here
        #if self.fpv:
        data_sources = []
        # If we're running auxiliary objectives, we need to include the data sources for the auxiliary labels
        #if self.use_aux_class_features or self.use_aux_class_on_map or self.use_aux_grounding_features or self.use_aux_grounding_on_map:
        #if self.use_aux_goal_on_map:
        data_sources.append(aup.PROVIDER_LM_POS_DATA)
        data_sources.append(aup.PROVIDER_GOAL_POS)
        #data_sources.append(aup.PROVIDER_LANDMARKS_MENTIONED)
        data_sources.append(aup.PROVIDER_LANG_TEMPLATE)

        #if self.use_rot_noise or self.use_pos_noise:
        #    data_sources.append(aup.PROVIDER_POSE_NOISE)

        return SegmentDataset(data=data, env_list=envs, dataset_name=dataset_name, aux_provider_names=data_sources, segment_level=True)
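
# --------------------------------------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the listing above): a minimal supervised training loop for a model
# like ModelGSFPV, which exposes get_dataset() and sup_loss_on_batch(batch, eval). The DataLoader settings,
# the Adam learning rate and the helper name train_supervised_epoch are assumptions for illustration only.
import torch
from torch.utils.data import DataLoader


def train_supervised_epoch(model, dataset_name="train", lr=1e-3):
    dataset = model.get_dataset(dataset_name=dataset_name, eval=False)
    # Assumption: batch_size=1 avoids needing the dataset's custom collate function, if it has one
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for batch in loader:
        loss = model.sup_loss_on_batch(batch, eval=False)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()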
Example #4
class ModelMisra2017(ModuleWithAuxiliaries):
    def __init__(self, run_name=""):

        super(ModelMisra2017, self).__init__()
        self.model_name = "misra2017"
        self.run_name = run_name
        self.writer = LoggingSummaryWriter(log_dir="runs/" + run_name)

        self.params = get_current_parameters()["Model"]
        self.trajectory_len = get_current_parameters()["Setup"]["trajectory_length"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        # CNN over images - using what is essentially SimpleImage currently
        self.image_module = ImageCnnEmnlp(
            image_emb_size=self.params["image_emb_dim"],
            input_num_channels=3 * 5,  # 3 channels per image, 5 images in history
            image_height=self.params["img_h"],
            image_width=self.params["img_w"])

        # LSTM to embed text
        self.text_module = TextSimpleModule(
            emb_dim=self.params["word_emb_dim"],
            hidden_dim=self.params["emb_size"],
            vocab_size=self.params["vocab_size"])

        # Action module to embed previous action+block
        self.action_module = ActionSimpleModule(
            num_actions=self.params["num_actions"],
            action_emb_size=self.params["action_emb_dim"])

        # Put it all together
        self.final_module = IncrementalMultimodalEmnlp(
            image_module=self.image_module,
            text_module=self.text_module,
            action_module=self.action_module,
            input_embedding_size=self.params["lstm_emb_dim"] +
            self.params["image_emb_dim"] + self.params["action_emb_dim"],
            output_hidden_size=self.params["h1_hidden_dim"],
            blocks_hidden_size=self.params["blocks_hidden_dim"],
            directions_hidden_size=self.params["action_hidden_dim"],
            max_episode_length=self.trajectory_len)

        self.action_loss = ActionLoss()

        self.env_id = None
        self.prev_instruction = None
        self.seq_step = 0
        self.model_state = None
        self.image_emb_seq = None
        self.state_feature = None

    # TODO: Try to hide these in a superclass or something. They take up a lot of space:
    def cuda(self, device=None):
        ModuleWithAuxiliaries.cuda(self, device)
        self.image_module.cuda(device)
        self.text_module.cuda(device)
        self.final_module.cuda(device)
        self.action_module.cuda(device)
        self.action_loss.cuda(device)
        return self

    def get_iter(self):
        return int(self.iter.data[0])

    def inc_iter(self):
        self.iter += 1

    def init_weights(self):
        self.final_module.init_weights()

    def reset(self):
        # TODO: This is error prone. Create a class StatefulModule, iterate submodules and reset all stateful modules
        super(ModelMisra2017, self).reset()
        self.seq_step = 0
        self.model_state = None

    def setEnvContext(self, context):
        print("Set env context to: " + str(context))
        self.env_id = context["env_id"]

    def get_action(self, state, instruction):
        """
        Given a DroneState (from PomdpInterface) and an instruction, produce a 4-element numpy action (x, y, theta, pstop)
        :param state: DroneState object with the raw image from the simulator
        :param instruction: Instruction tokenized according to the corpus
        #TODO: Absorb corpus within model
        :return:
        """
        # TODO: Simplify this
        self.eval()
        images_np_pure = state.image
        state_np = state.state

        #print("Act: " + debug_untokenize_instruction(instruction))

        images_np = standardize_image(images_np_pure)
        image_fpv = Variable(none_padded_seq_to_tensor([images_np]))
        state = Variable(none_padded_seq_to_tensor([state_np]))
        # Add the batch dimension

        first_step = True
        if instruction == self.prev_instruction:
            first_step = False
        self.prev_instruction = instruction

        img_in_t = image_fpv
        img_in_t.volatile = True

        instr_len = [len(instruction)] if instruction is not None else None
        instruction = torch.LongTensor(instruction).unsqueeze(0)
        instruction = cuda_var(instruction, self.is_cuda, self.cuda_device)

        state.volatile = True

        if self.is_cuda:
            img_in_t = img_in_t.cuda(self.cuda_device)
            state = state.cuda(self.cuda_device)

        self.seq_step += 1

        action = self(img_in_t, instruction, instr_len)

        output_action = action.squeeze().data.cpu().numpy()
        stop_prob = output_action[3]
        output_stop = 1 if stop_prob > 0.5 else 0
        output_action[3] = output_stop

        #print("action: ", output_action)

        return output_action

    def deterministic_action(self, action_mean, action_std, stop_prob):
        batch_size = action_mean.size(0)
        action = Variable(
            empty_float_tensor((batch_size, 4), self.is_cuda,
                               self.cuda_device))
        action[:, 0:3] = action_mean[:, 0:3]
        action[:, 3] = stop_prob
        return action

    def sample_action(self, action_mean, action_std, stop_prob):
        action = torch.normal(action_mean, action_std)
        stop = torch.bernoulli(stop_prob)
        return action, stop
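
    # Note (illustrative): sample_action draws a continuous action from a normal distribution around
    # action_mean and a stop decision from a Bernoulli with probability stop_prob, whereas
    # deterministic_action above returns the mean together with the raw stop probability.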

    # This is called before beginning an execution sequence
    def start_sequence(self):
        self.seq_step = 0
        self.reset()
        print("RESETTED!")
        return

    # TODO: Move this somewhere and standardize
    def cam_poses_from_states(self, states):
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]
        pose = Pose(cam_pos, cam_rot)
        return pose

    def instructions_to_dipandrew(self, instructions, instr_lengths):
        out = []
        for i in range(len(instructions)):
            instr_i = instructions[i:i + 1, 0:instr_lengths[i]]
            out.append(instr_i)
        return out
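
    # Example (illustrative): for instruction lengths [4, 7, 5], this returns a list of three tensors
    # with shapes (1, 4), (1, 7) and (1, 5), i.e. each padded instruction row trimmed to its true length.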

    def forward(self, images, instructions, instr_lengths):

        seq_len = len(images)

        instr_dipandrew = self.instructions_to_dipandrew(
            instructions, instr_lengths)

        # Add sequence dimension, since we're treating batches as sequences
        images = images.unsqueeze(0)

        all_actions = []
        for i in range(seq_len):
            time_in = np.asarray([self.seq_step])
            time_in = Variable(
                self.maybe_cuda(torch.from_numpy(time_in).long()))
            action_i, self.model_state = self.final_module(
                images[0:1, i:i + 1], instr_dipandrew[i], time_in,
                self.model_state)

            self.seq_step += 1
            all_actions.append(action_i)

        actions = torch.cat(all_actions, dim=0)
        return actions

    def maybe_cuda(self, tensor):
        if self.is_cuda:
            return tensor.cuda()
        else:
            return tensor

    def cuda_var(self, tensor):
        return cuda_var(tensor, self.is_cuda, self.cuda_device)

    # Forward pass for training (with batch optimizations)
    def sup_loss_on_batch(self, batch, eval):
        self.prof.tick("out")

        action_loss_total = Variable(
            empty_float_tensor([1], self.is_cuda, self.cuda_device))

        if batch is None:
            print("Skipping None Batch")
            return action_loss_total

        images = self.maybe_cuda(batch["images"])
        instructions = self.maybe_cuda(batch["instr"])
        instr_lengths = batch["instr_len"]
        actions = self.maybe_cuda(batch["actions"])

        metadata = batch["md"]

        batch_size = images.size(0)
        count = 0

        # Loop thru batch
        for b in range(batch_size):
            self.reset()
            self.prof.tick("out")
            b_seq_len = len_until_nones(metadata[b])

            # TODO: Generalize this
            # Slice the data according to the sequence length
            b_metadata = metadata[b][:b_seq_len]
            b_images = images[b][:b_seq_len]
            b_instructions = instructions[b][:b_seq_len]
            b_instr_len = instr_lengths[b][:b_seq_len]
            b_actions = actions[b][:b_seq_len]

            # ----------------------------------------------------------------------------

            self.prof.tick("inputs")

            # Use a distinct name for the predictions so the batch-level "actions" tensor isn't overwritten
            pred_actions = self(b_images, b_instructions, b_instr_len)

            action_losses, _ = self.action_loss(b_actions,
                                                pred_actions,
                                                batchreduce=False)

            self.prof.tick("call")
            action_losses = self.action_loss.batch_reduce_loss(action_losses)
            action_loss = self.action_loss.reduce_loss(action_losses)
            action_loss_total = action_loss_total + action_loss
            count += b_seq_len

            self.prof.tick("loss")

        action_loss_avg = action_loss_total / (count + 1e-9)

        self.prof.tick("out")

        prefix = self.model_name + ("/eval" if eval else "/train")

        self.writer.add_dict(prefix, get_current_meters(), self.get_iter())
        self.writer.add_scalar(prefix + "/action_loss",
                               action_loss_avg.data.cpu()[0], self.get_iter())

        total_loss = action_loss_avg

        self.inc_iter()

        self.prof.loop()
        self.prof.print_stats(1)

        return total_loss

    def get_dataset(self, data=None, envs=None, dataset_name=None, eval=False):
        # TODO: Maybe use eval here
        #if self.fpv:
        return SegmentDataset(data=data,
                              env_list=envs,
                              dataset_name=dataset_name,
                              aux_provider_names=[],
                              segment_level=True)
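
# --------------------------------------------------------------------------------------------------------------
# Illustrative rollout sketch (not part of the listing above): how an agent such as ModelMisra2017 might be
# stepped against an environment yielding DroneState objects. The env.reset()/env.step() interface and the
# max_steps default are assumptions for illustration; only start_sequence() and get_action() come from the
# class above.
def rollout_episode(model, env, tokenized_instruction, max_steps=100):
    model.start_sequence()          # reset recurrent state before a new execution sequence
    state = env.reset()             # assumed: returns a DroneState with .image and .state
    for _ in range(max_steps):
        action = model.get_action(state, tokenized_instruction)
        state = env.step(action)    # assumed: applies (x, y, theta, pstop) and returns the next state
        if action[3] > 0.5:         # get_action thresholds the stop probability into {0, 1}
            break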
Example #5
class ModelGSMNBiDomain(nn.Module):
    def __init__(self, run_name="", model_instance_name=""):

        super(ModelGSMNBiDomain, self).__init__()
        self.model_name = "gsmn_bidomain"
        self.run_name = run_name
        self.name = model_instance_name
        if not self.name:
            self.name = ""
        self.writer = LoggingSummaryWriter(
            log_dir=f"runs/{run_name}/{self.name}")

        self.params = get_current_parameters()["Model"]
        self.aux_weights = get_current_parameters()["AuxWeights"]
        self.use_aux = self.params["UseAuxiliaries"]

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        self.tensor_store = KeyTensorStore()
        self.aux_losses = AuxiliaryLosses()

        self.rviz = None
        if self.params.get("rviz"):
            self.rviz = RvizInterface(
                base_name="/gsmn/",
                map_topics=["semantic_map", "grounding_map", "goal_map"],
                markerarray_topics=["instruction"])

        # Path-pred FPV model definition
        # --------------------------------------------------------------------------------------------------------------

        self.img_to_features_w = FPVToGlobalMap(
            source_map_size=self.params["global_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"],
            res_channels=self.params["resnet_channels"],
            map_channels=self.params["feature_channels"],
            img_w=self.params["img_w"],
            img_h=self.params["img_h"],
            cam_h_fov=self.params["cam_h_fov"],
            img_dbg=IMG_DBG)

        self.map_accumulator_w = LeakyIntegratorGlobalMap(
            source_map_size=self.params["global_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])

        # Pre-process the accumulated map to do language grounding if necessary - in the world reference frame
        if self.use_aux["grounding_map"] and not self.use_aux["grounding_features"]:
            self.map_processor_a_w = LangFilterMapProcessor(
                embed_size=self.params["emb_size"],
                in_channels=self.params["feature_channels"],
                out_channels=self.params["relevance_channels"],
                spatial=False,
                cat_out=True)
        else:
            self.map_processor_a_w = IdentityMapProcessor(
                source_map_size=self.params["global_map_size"],
                world_size_px=self.params["world_size_px"],
                world_size_m=self.params["world_size_m"])

        if self.use_aux["goal_map"]:
            self.map_processor_b_r = LangFilterMapProcessor(
                embed_size=self.params["emb_size"],
                in_channels=self.params["relevance_channels"],
                out_channels=self.params["goal_channels"],
                spatial=self.params["spatial_goal_filter"],
                cat_out=self.params["cat_rel_and_goal"])
        else:
            self.map_processor_b_r = IdentityMapProcessor(
                source_map_size=self.params["local_map_size"],
                world_size_px=self.params["world_size_px"],
                world_size_m=self.params["world_size_m"])

        # Common
        # --------------------------------------------------------------------------------------------------------------

        # Sentence Embedding
        self.sentence_embedding = SentenceEmbeddingSimple(
            self.params["word_emb_size"],
            self.params["emb_size"],
            self.params["emb_layers"],
            dropout=0.0)

        self.map_transform_w_to_r = MapTransformerBase(
            source_map_size=self.params["global_map_size"],
            dest_map_size=self.params["local_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])
        self.map_transform_r_to_w = MapTransformerBase(
            source_map_size=self.params["local_map_size"],
            dest_map_size=self.params["global_map_size"],
            world_size_px=self.params["world_size_px"],
            world_size_m=self.params["world_size_m"])

        # Output an action given the global semantic map
        if self.params["map_to_action"] == "downsample2":
            self.map_to_action = EgoMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"],
                other_features_size=self.params["emb_size"])

        elif self.params["map_to_action"] == "cropped":
            self.map_to_action = CroppedMapToActionTriplet(
                map_channels=self.params["map_to_act_channels"],
                map_size=self.params["local_map_size"])

        # Auxiliary Objectives
        # --------------------------------------------------------------------------------------------------------------

        # We add all auxiliaries that are necessary. The first argument is the auxiliary name, followed by parameters,
        # followed by variable number of names of inputs. ModuleWithAuxiliaries will automatically collect these inputs
        # that have been saved with keep_auxiliary_input() during execution
        if self.use_aux["class_features"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary2D("aux_class", self.params["feature_channels"],
                                 self.params["num_landmarks"],
                                 self.params["dropout"], "fpv_features",
                                 "lm_pos_fpv_features", "lm_indices",
                                 "tensor_store"))
        if self.use_aux["grounding_features"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary2D("aux_ground",
                                 self.params["relevance_channels"], 2,
                                 self.params["dropout"], "fpv_features_g",
                                 "lm_pos_fpv_features", "lm_mentioned",
                                 "tensor_store"))
        if self.use_aux["class_map"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary2D("aux_class_map",
                                 self.params["feature_channels"],
                                 self.params["num_landmarks"],
                                 self.params["dropout"], "map_S_W",
                                 "lm_pos_map", "lm_indices", "tensor_store"))
        if self.use_aux["grounding_map"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary2D("aux_grounding_map",
                                 self.params["relevance_channels"], 2,
                                 self.params["dropout"], "map_R_W",
                                 "lm_pos_map", "lm_mentioned", "tensor_store"))
        if self.use_aux["goal_map"]:
            self.aux_losses.add_auxiliary(
                GoalAuxiliary2D("aux_goal_map", self.params["goal_channels"],
                                self.params["global_map_size"], "map_G_W",
                                "goal_pos_map"))
        # RSS model uses templated data for landmark and side prediction
        if self.use_aux["language"] and self.params["templates"]:
            self.aux_losses.add_auxiliary(
                ClassAuxiliary("aux_lang_lm", self.params["emb_size"],
                               self.params["num_landmarks"], 1,
                               "sentence_embed", "lm_mentioned_tplt"))
            self.aux_losses.add_auxiliary(
                ClassAuxiliary("aux_lang_side", self.params["emb_size"],
                               self.params["num_sides"], 1, "sentence_embed",
                               "side_mentioned_tplt"))
        # CoRL model uses alignment-model groundings
        elif self.use_aux["language"]:
            # One output per landmark, 2 classes per output. This is for fine-tuning, so use the embedding that is going to be fine-tuned
            self.aux_losses.add_auxiliary(
                ClassAuxiliary("aux_lang_lm_nl", self.params["emb_size"], 2,
                               self.params["num_landmarks"], "sentence_embed",
                               "lang_lm_mentioned"))
        if self.use_aux["l1_regularization"]:
            self.aux_losses.add_auxiliary(
                FeatureRegularizationAuxiliary2D("aux_regularize_features",
                                                 "l1", "map_S_W"))
            self.aux_losses.add_auxiliary(
                FeatureRegularizationAuxiliary2D("aux_regularize_features",
                                                 "l1", "map_R_W"))

        self.goal_acc_meter = MovingAverageMeter(10)

        self.aux_losses.print_auxiliary_info()

        self.action_loss = ActionLoss()

        self.env_id = None
        self.prev_instruction = None
        self.seq_step = 0

    def cuda(self, device=None):
        CudaModule.cuda(self, device)
        self.aux_losses.cuda(device)
        self.sentence_embedding.cuda(device)
        self.map_accumulator_w.cuda(device)
        self.map_processor_a_w.cuda(device)
        self.map_processor_b_r.cuda(device)
        self.img_to_features_w.cuda(device)
        self.map_to_action.cuda(device)
        self.action_loss.cuda(device)
        self.map_transform_w_to_r.cuda(device)
        self.map_transform_r_to_w.cuda(device)
        return self

    def steal_cross_domain_modules(self, other_self):
        # TODO: Consider whether to share auxiliary losses, and if so, all of them?
        self.aux_losses = other_self.aux_losses
        self.action_loss = other_self.action_loss

        # TODO: Make sure that none of these things are stateful, or that there are resets after every forward pass
        self.sentence_embedding = other_self.sentence_embedding
        self.map_accumulator_w = other_self.map_accumulator_w
        self.map_processor_a_w = other_self.map_processor_a_w
        self.map_processor_b_r = other_self.map_processor_b_r
        self.map_to_action = other_self.map_to_action

        # We'll have a separate one of these for each domain
        #self.img_to_features_w = other_self.img_to_features_w

        # TODO: Check that statefulness is not an issue in sharing modules
        # These have no parameters so no point sharing
        #self.map_transform_w_to_r = other_self.map_transform_w_to_r
        #self.map_transform_r_to_w = other_self.map_transform_r_to_w

    def both_domain_parameters(self, other_self):
        # This function iterates and yields parameters from this module and the other module, but does not yield
        # shared parameters twice.
        # First yield all of the other module's parameters
        for p in other_self.parameters():
            yield p
        # Then yield all the parameters from the this module that are not shared with the other one
        for p in self.img_to_features_w.parameters():
            yield p
        return

    def get_iter(self):
        return int(self.iter.data[0])

    def inc_iter(self):
        self.iter += 1

    def load_img_feature_weights(self):
        if self.params.get("load_feature_net"):
            filename = self.params.get("feature_net_filename")
            weights = load_pytorch_model(None, filename)
            prefix = self.params.get("feature_net_tensor_name")
            if prefix:
                weights = find_state_subdict(weights, prefix)
            # TODO: This breaks OOP conventions
            self.img_to_features_w.img_to_features.load_state_dict(weights)
            print(
                f"Loaded pretrained weights from file {filename} with prefix {prefix}"
            )

    def init_weights(self):
        self.img_to_features_w.init_weights()
        self.load_img_feature_weights()
        self.map_accumulator_w.init_weights()
        self.sentence_embedding.init_weights()
        self.map_to_action.init_weights()
        self.map_processor_a_w.init_weights()
        self.map_processor_b_r.init_weights()

    def reset(self):
        self.tensor_store.reset()
        self.sentence_embedding.reset()
        self.img_to_features_w.reset()
        self.map_accumulator_w.reset()
        self.map_transform_w_to_r.reset()
        self.map_transform_r_to_w.reset()
        self.load_img_feature_weights()
        self.prev_instruction = None

    def set_env_context(self, context):
        print("Set env context to: " + str(context))
        self.env_id = context["env_id"]

    def save_viz(self, images_in, instruction):
        # Save incoming images
        imsave(
            os.path.join(get_viz_dir_for_rollout(),
                         "fpv_" + str(self.seq_step) + ".png"), images_in)
        #self.tensor_store.keep_input("fpv_img", images_in)
        # Save all of these tensors from the tensor store as images
        save_tensors_as_images(self.tensor_store, [
            "images_w", "fpv_img", "fpv_features", "map_F_W", "map_M_W",
            "map_S_W", "map_R_W", "map_R_R", "map_G_R", "map_G_W"
        ], str(self.seq_step))

        # Save action as image
        action = self.tensor_store.get_inputs_batch(
            "action")[-1].data.cpu().squeeze().numpy()
        action_fname = get_viz_dir_for_rollout() + "action_" + str(
            self.seq_step) + ".png"
        Presenter().save_action(action, action_fname, "")

        instruction_fname = get_viz_dir_for_rollout() + "instruction.txt"
        with open(instruction_fname, "w") as fp:
            fp.write(instruction)

    def get_action(self, state, instruction):
        """
        Given a DroneState (from PomdpInterface) and an instruction, produce a 4-element numpy action (x, y, theta, pstop)
        :param state: DroneState object with the raw image from the simulator
        :param instruction: Instruction tokenized according to the corpus
        #TODO: Absorb corpus within model
        :return:
        """
        # TODO: Simplify this
        self.eval()
        images_np_pure = state.image
        state_np = state.state

        #print("Act: " + debug_untokenize_instruction(instruction))

        images_np = standardize_image(images_np_pure)
        image_fpv = Variable(none_padded_seq_to_tensor([images_np]))
        state = Variable(none_padded_seq_to_tensor([state_np]))
        # Add the batch dimension

        first_step = True
        if instruction == self.prev_instruction:
            first_step = False
        self.prev_instruction = instruction
        instruction_str = debug_untokenize_instruction(instruction)

        # TODO: Move this to PomdpInterface (for now it's here because this is already visualizing the maps)
        if first_step:
            if self.rviz is not None:
                self.rviz.publish_instruction_text(
                    "instruction", debug_untokenize_instruction(instruction))

        img_in_t = image_fpv
        img_in_t.volatile = True

        instr_len = [len(instruction)] if instruction is not None else None
        instruction = torch.LongTensor(instruction).unsqueeze(0)
        instruction = cuda_var(instruction, self.is_cuda, self.cuda_device)

        state.volatile = True

        if self.is_cuda:
            if img_in_t is not None:
                img_in_t = img_in_t.cuda(self.cuda_device)
            state = state.cuda(self.cuda_device)

        step_enc = None
        plan_now = None

        self.seq_step += 1

        action = self(img_in_t,
                      state,
                      instruction,
                      instr_len,
                      plan=plan_now,
                      pos_enc=step_enc)

        passive_mode_debug_projections = True
        if passive_mode_debug_projections:
            self.show_landmark_locations(loop=False, states=state)
            self.reset()

        # Run auxiliary objectives for debugging purposes (e.g. to compute classification predictions)
        if self.params.get("run_auxiliaries_at_test_time"):
            _, _ = self.aux_losses.calculate_aux_loss(self.tensor_store,
                                                      reduce_average=True)
            overlaid = self.get_overlaid_classification_results(
                map_not_features=False)

        # Save materials for analysis and presentation
        if self.params["write_figures"]:
            self.save_viz(images_np_pure, instruction_str)

        output_action = action.squeeze().data.cpu().numpy()
        stop_prob = output_action[3]
        output_stop = 1 if stop_prob > self.params["stop_p"] else 0
        output_action[3] = output_stop

        return output_action

    def get_overlaid_classification_results(self, map_not_features=False):
        if map_not_features:
            predictions_name = "aux_class_map_predictions"
        else:
            predictions_name = "aux_class_predictions"
        predictions = self.tensor_store.get_latest_input(predictions_name)
        if predictions is None:
            return None
        predictions = predictions[0].detach()
        # Get the 3 channels corresponding to no landmark, banana and gorilla
        predictions = predictions[[0, 3, 24], :, :]
        images = self.tensor_store.get_latest_input("images")[0].detach()
        overlaid = Presenter().overlaid_image(images,
                                              predictions,
                                              gray_bg=True)
        return overlaid

    def deterministic_action(self, action_mean, action_std, stop_prob):
        batch_size = action_mean.size(0)
        action = Variable(
            empty_float_tensor((batch_size, 4), self.is_cuda,
                               self.cuda_device))
        action[:, 0:3] = action_mean[:, 0:3]
        action[:, 3] = stop_prob
        return action

    # This is called before beginning an execution sequence
    def start_sequence(self):
        self.seq_step = 0
        self.reset()
        print("RESETTED!")
        return

    # TODO: Move this somewhere and standardize
    def cam_poses_from_states(self, states):
        cam_pos = states[:, 9:12]
        cam_rot = states[:, 12:16]

        pos_variance = 0
        rot_variance = 0
        if self.params.get("use_pos_noise"):
            pos_variance = self.params["noisy_pos_variance"]
        if self.params.get("use_rot_noise"):
            rot_variance = self.params["noisy_rot_variance"]

        pose = Pose(cam_pos, cam_rot)
        if self.params.get("use_pos_noise") or self.params.get(
                "use_rot_noise"):
            pose = get_noisy_poses_torch(pose,
                                         pos_variance,
                                         rot_variance,
                                         cuda=self.is_cuda,
                                         cuda_device=self.cuda_device)
        return pose

    def forward(self,
                images,
                states,
                instructions,
                instr_lengths,
                has_obs=None,
                plan=None,
                save_maps_only=False,
                pos_enc=None,
                noisy_poses=None,
                halfway=False):
        """
        :param images: BxCxHxW batch of images (observations)
        :param states: BxK batch of drone states
        :param instructions: BxM LongTensor where M is the maximum length of any instruction
        :param instr_lengths: list of len B of integers, indicating length of each instruction
        :param has_obs: list of booleans of length B indicating whether the given element in the sequence has an observation
        :param save_maps_only: If true, will not compute actions (full model), but return the semantic maps that
            were built along the way in response to the images. This is ugly, but allows code reuse
        :return:
        """
        cam_poses = self.cam_poses_from_states(states)
        g_poses = None  #[None for pose in cam_poses]
        self.prof.tick("out")

        #str_instr = debug_untokenize_instruction(instructions[0].data[:instr_lengths[0]])
        #print("Trn: " + str_instr)

        # Calculate the instruction embedding
        if instructions is not None:
            # TODO: Take batch of instructions and their lengths, return batch of embeddings. Store the last one as internal state
            sent_embeddings = self.sentence_embedding(instructions,
                                                      instr_lengths)
            self.tensor_store.keep_inputs("sentence_embed", sent_embeddings)
        else:
            sent_embeddings = self.sentence_embedding.get()

        self.prof.tick("embed")

        # Extract and project features onto the egocentric frame for each image
        features_w, coverages_w = self.img_to_features_w(images,
                                                         cam_poses,
                                                         sent_embeddings,
                                                         self.tensor_store,
                                                         show="")

        # If we're running the model halfway, return now. This is to compute enough features for the Wasserstein critic, but no more
        if halfway:
            return None

        # Don't backprop into the ResNet if we're freezing these features (TODO: instead set requires_grad to False)
        if self.params.get("freeze_feature_net"):
            features_w = features_w.detach()

        self.prof.tick("img_to_map_frame")
        self.tensor_store.keep_inputs("images", images)
        self.tensor_store.keep_inputs("map_F_w", features_w)
        self.tensor_store.keep_inputs("map_M_w", coverages_w)

        if run_metadata.IS_ROLLOUT:
            Presenter().show_image(features_w.data[0, 0:3],
                                   "F",
                                   torch=True,
                                   scale=8,
                                   waitkey=1)

        # Accumulate the egocentric features in a global map
        maps_s_w = self.map_accumulator_w(features_w,
                                          coverages_w,
                                          add_mask=has_obs,
                                          show="acc" if IMG_DBG else "")
        map_poses_w = g_poses
        self.tensor_store.keep_inputs("map_S_W", maps_s_w)
        self.prof.tick("map_accumulate")

        Presenter().show_image(maps_s_w.data[0],
                               f"{self.name}_S_map_W",
                               torch=True,
                               scale=4,
                               waitkey=1)

        # Do grounding of objects in the map chosen to do so
        maps_r_w, map_poses_r_w = self.map_processor_a_w(maps_s_w,
                                                         sent_embeddings,
                                                         map_poses_w,
                                                         show="")
        self.tensor_store.keep_inputs("map_R_W", maps_r_w)
        Presenter().show_image(maps_r_w.data[0],
                               f"{self.name}_R_map_W",
                               torch=True,
                               scale=4,
                               waitkey=1)
        self.prof.tick("map_proc_gnd")

        # Transform to drone's reference frame
        self.map_transform_w_to_r.set_maps(maps_r_w, map_poses_r_w)
        maps_r_r, map_poses_r_r = self.map_transform_w_to_r.get_maps(cam_poses)
        self.tensor_store.keep_inputs("map_R_R", maps_r_r)
        self.prof.tick("transform_w_to_r")

        # Predict goal location
        maps_g_r, map_poses_g_r = self.map_processor_b_r(
            maps_r_r, sent_embeddings, map_poses_r_r)
        self.tensor_store.keep_inputs("map_G_R", maps_g_r)

        # Transform back to map frame
        self.map_transform_r_to_w.set_maps(maps_g_r, map_poses_g_r)
        maps_g_w, _ = self.map_transform_r_to_w.get_maps(None)
        self.tensor_store.keep_inputs("map_G_W", maps_g_w)
        self.prof.tick("map_proc_b")

        # Show and publish to RVIZ
        Presenter().show_image(maps_g_w.data[0],
                               f"{self.name}_G_map_W",
                               torch=True,
                               scale=8,
                               waitkey=1)
        if self.rviz:
            self.rviz.publish_map(
                "goal_map", maps_g_w[0].data.cpu().numpy().transpose(1, 2, 0),
                self.params["world_size_m"])

        # Output the final action given the processed map
        action_pred = self.map_to_action(maps_g_r, sent_embeddings)
        out_action = self.deterministic_action(action_pred[:, 0:3], None,
                                               action_pred[:, 3])
        self.tensor_store.keep_inputs("action", out_action)
        self.prof.tick("map_to_action")

        return out_action

    # TODO: The below two methods seem to do the same thing
    def maybe_cuda(self, tensor):
        if self.is_cuda:
            return tensor.cuda()
        else:
            return tensor

    def cuda_var(self, tensor):
        return cuda_var(tensor, self.is_cuda, self.cuda_device)

    def unbatch(self, batch):
        # TODO: Carefully consider this line. This is necessary to reset state between batches (e.g. delete all tensors in the tensor store)
        self.reset()
        # Get rid of the batch dimension for everything
        images = self.maybe_cuda(batch["images"])[0]
        seq_len = images.shape[0]
        instructions = self.maybe_cuda(batch["instr"])[0][:seq_len]
        instr_lengths = batch["instr_len"][0]
        states = self.maybe_cuda(batch["states"])[0]
        actions = self.maybe_cuda(batch["actions"])[0]

        # Auxiliary labels
        lm_pos_fpv = batch["lm_pos_fpv"][0]
        lm_pos_map = batch["lm_pos_map"][0]
        lm_indices = batch["lm_indices"][0]
        goal_pos_map = batch["goal_loc"][0]

        # TODO: Get rid of this. We will have lm_mentioned booleans and lm_mentioned_idx integers and that's it.
        TEMPLATES = True
        if TEMPLATES:
            lm_mentioned_tplt = batch["lm_mentioned_tplt"][0]
            side_mentioned_tplt = batch["side_mentioned_tplt"][0]
            side_mentioned_tplt = self.cuda_var(side_mentioned_tplt)
            lm_mentioned_tplt = self.cuda_var(lm_mentioned_tplt)
            lang_lm_mentioned = None
        else:
            lm_mentioned_tplt = None
            side_mentioned_tplt = None
            lang_lm_mentioned = batch["lang_lm_mentioned"][0]
        lm_mentioned = batch["lm_mentioned"][0]
        # This is the first-timestep metadata
        metadata = batch["md"][0]

        lm_pos_map = [
            torch.from_numpy(
                transformations.pos_m_to_px(
                    p.numpy(), self.params["global_map_size"],
                    self.params["world_size_m"], self.params["world_size_px"]))
            if p is not None else None for p in lm_pos_map
        ]

        goal_pos_map = torch.from_numpy(
            transformations.pos_m_to_px(goal_pos_map.numpy(),
                                        self.params["global_map_size"],
                                        self.params["world_size_m"],
                                        self.params["world_size_px"]))

        lm_pos_map = [
            self.cuda_var(s.long()) if s is not None else None
            for s in lm_pos_map
        ]
        lm_pos_fpv_features = [
            self.cuda_var(
                (s /
                 self.img_to_features_w.img_to_features.get_downscale_factor()
                 ).long()) if s is not None else None for s in lm_pos_fpv
        ]
        lm_pos_fpv_img = [
            self.cuda_var(s.long()) if s is not None else None
            for s in lm_pos_fpv
        ]
        lm_indices = [
            self.cuda_var(s) if s is not None else None for s in lm_indices
        ]
        goal_pos_map = self.cuda_var(goal_pos_map)
        if not TEMPLATES:
            lang_lm_mentioned = self.cuda_var(lang_lm_mentioned)
        lm_mentioned = [
            self.cuda_var(s) if s is not None else None for s in lm_mentioned
        ]

        obs_mask = [True for _ in range(seq_len)]
        plan_mask = [True for _ in range(seq_len)]
        pos_enc = None

        # TODO: Figure out how to keep these properly. Perhaps as a whole batch is best
        self.tensor_store.keep_inputs("lm_pos_fpv_img", lm_pos_fpv_img)
        self.tensor_store.keep_inputs("lm_pos_fpv_features",
                                      lm_pos_fpv_features)
        self.tensor_store.keep_inputs("lm_pos_map", lm_pos_map)
        self.tensor_store.keep_inputs("lm_indices", lm_indices)
        self.tensor_store.keep_inputs("goal_pos_map", goal_pos_map)
        if not TEMPLATES:
            self.tensor_store.keep_inputs("lang_lm_mentioned",
                                          lang_lm_mentioned)
        else:
            self.tensor_store.keep_inputs("lm_mentioned_tplt",
                                          lm_mentioned_tplt)
            self.tensor_store.keep_inputs("side_mentioned_tplt",
                                          side_mentioned_tplt)
        self.tensor_store.keep_inputs("lm_mentioned", lm_mentioned)

        # ----------------------------------------------------------------------------
        # Optional Auxiliary Inputs
        # ----------------------------------------------------------------------------
        #if self.aux_losses.input_required("lm_pos_map"):
        self.tensor_store.keep_inputs("lm_pos_map", lm_pos_map)
        #if self.aux_losses.input_required("lm_indices"):
        self.tensor_store.keep_inputs("lm_indices", lm_indices)
        #if self.aux_losses.input_required("lm_mentioned"):
        self.tensor_store.keep_inputs("lm_mentioned", lm_mentioned)

        return images, instructions, instr_lengths, states, actions, \
               lm_pos_fpv_img, lm_pos_fpv_features, lm_pos_map, lm_indices, goal_pos_map, \
               lm_mentioned, lm_mentioned_tplt, side_mentioned_tplt, lang_lm_mentioned, \
               metadata, obs_mask, plan_mask, pos_enc

    def show_landmark_locations(self, loop=True, states=None):
        # Show landmark locations in first-person images
        img_all = self.tensor_store.get("images")
        img_w_all = self.tensor_store.get("images_w")
        import rollout.run_metadata as md
        if md.IS_ROLLOUT:
            # TODO: Discard this and move this to PomdpInterface or something
            # (it's got nothing to do with the model)
            # load landmark positions from configs
            from data_io.env import load_env_config
            from learning.datasets.aux_data_providers import get_landmark_locations_airsim
            from learning.models.semantic_map.pinhole_camera_inv import PinholeCameraProjection
            projector = PinholeCameraProjection(
                map_size_px=self.params["global_map_size"],
                world_size_px=self.params["world_size_px"],
                world_size_m=self.params["world_size_m"],
                img_x=self.params["img_w"],
                img_y=self.params["img_h"],
                cam_fov=self.params["cam_h_fov"],
                #TODO: Handle correctly
                domain="sim",
                use_depth=False)
            conf_json = load_env_config(md.ENV_ID)
            landmark_names, landmark_indices, landmark_pos = get_landmark_locations_airsim(
                conf_json)
            cam_poses = self.cam_poses_from_states(states)
            cam_pos = cam_poses.position[0]
            cam_rot = cam_poses.orientation[0]
            lm_pos_map_all = []
            lm_pos_img_all = []
            for i, landmark_in_world in enumerate(landmark_pos):
                lm_pos_img, landmark_in_cam, status = projector.world_point_to_image(
                    cam_pos, cam_rot, landmark_in_world)
                lm_pos_map = torch.from_numpy(
                    transformations.pos_m_to_px(
                        landmark_in_world[np.newaxis, :],
                        self.params["global_map_size"],
                        self.params["world_size_m"],
                        self.params["world_size_px"]))
                lm_pos_map_all += [lm_pos_map[0]]
                if lm_pos_img is not None:
                    lm_pos_img_all += [lm_pos_img]

            lm_pos_img_all = [lm_pos_img_all]
            lm_pos_map_all = [lm_pos_map_all]

        else:
            lm_pos_img_all = self.tensor_store.get("lm_pos_fpv_img")
            lm_pos_map_all = self.tensor_store.get("lm_pos_map")

        print("Plotting landmark points")

        for i in range(len(img_all)):
            p = Presenter()
            overlay_fpv = p.overlay_pts_on_image(img_all[i][0],
                                                 lm_pos_img_all[i])
            overlay_map = p.overlay_pts_on_image(img_w_all[i][0],
                                                 lm_pos_map_all[i])
            p.show_image(overlay_fpv, "landmarks_on_fpv_img", scale=8)
            p.show_image(overlay_map, "landmarks_on_map", scale=20)

            if not loop:
                break

    def calc_tensor_statistics(self, prefix, tensor):
        stats = {}
        stats[f"{prefix}_mean"] = torch.mean(tensor).item()
        stats[f"{prefix}_l2"] = torch.norm(tensor).item()
        stats[f"{prefix}_stddev"] = torch.std(tensor).item()
        return stats

    def get_activation_statistics(self, keys):
        stats = {}
        from utils.dict_tools import dict_merge
        for key in keys:
            t = self.tensor_store.get_inputs_batch(key)
            t_stats = self.calc_tensor_statistics(key, t)
            stats = dict_merge(stats, t_stats)
        return stats
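
    # Example (illustrative): get_activation_statistics(["map_S_W"]) returns a dict of the form
    # {"map_S_W_mean": ..., "map_S_W_l2": ..., "map_S_W_stddev": ...}, which sup_loss_on_batch below
    # logs under the "<model_name>_activations" prefix.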

    # Forward pass for training (with batch optimizations)
    def sup_loss_on_batch(self, batch, eval, halfway=False):
        self.prof.tick("out")

        action_loss_total = Variable(
            empty_float_tensor([1], self.is_cuda, self.cuda_device))

        if batch is None:
            print("Skipping None Batch")
            return action_loss_total

        images, instructions, instr_lengths, states, action_labels, \
        lm_pos_fpv_img, lm_pos_fpv_features, lm_pos_map, lm_indices, goal_pos_map, \
        lm_mentioned, lm_mentioned_tplt, side_mentioned_tplt, lang_lm_mentioned, \
        metadata, obs_mask, plan_mask, pos_enc = self.unbatch(batch)

        # ----------------------------------------------------------------------------
        self.prof.tick("inputs")

        pred_actions = self(images,
                            states,
                            instructions,
                            instr_lengths,
                            has_obs=obs_mask,
                            plan=plan_mask,
                            pos_enc=pos_enc,
                            halfway=halfway)

        # Debugging landmark locations
        if False:
            self.show_landmark_locations()

        # Don't compute any losses - those will not be used. All we care about are the intermediate activations
        if halfway:
            return None, self.tensor_store

        action_losses, _ = self.action_loss(action_labels,
                                            pred_actions,
                                            batchreduce=False)

        self.prof.tick("call")

        action_losses = self.action_loss.batch_reduce_loss(action_losses)
        action_loss = self.action_loss.reduce_loss(action_losses)

        action_loss_total = action_loss

        self.prof.tick("loss")

        aux_losses, aux_metrics = self.aux_losses.calculate_aux_loss(
            self.tensor_store, reduce_average=True)
        aux_loss = self.aux_losses.combine_losses(aux_losses, self.aux_weights)

        #overlaid = self.get_overlaid_classification_results()
        #Presenter().show_image(overlaid, "classification", scale=2)

        prefix = f"{self.model_name}/{'eval' if eval else 'train'}"
        act_prefix = f"{self.model_name}_activations/{'eval' if eval else 'train'}"

        # Mean, stddev, norm of maps
        act_stats = self.get_activation_statistics(
            ["map_S_W", "map_R_W", "map_G_W"])
        self.writer.add_dict(act_prefix, act_stats, self.get_iter())

        self.writer.add_dict(prefix, get_current_meters(), self.get_iter())
        self.writer.add_dict(prefix, aux_losses, self.get_iter())
        self.writer.add_dict(prefix, aux_metrics, self.get_iter())
        self.writer.add_scalar(prefix + "/action_loss",
                               action_loss_total.data.cpu().item(),
                               self.get_iter())
        # TODO: Log value here
        self.writer.add_scalar(prefix + "/goal_accuracy",
                               self.goal_acc_meter.get(), self.get_iter())

        self.prof.tick("auxiliaries")

        total_loss = action_loss_total + aux_loss

        self.inc_iter()

        self.prof.tick("summaries")
        self.prof.loop()
        self.prof.print_stats(1)

        return total_loss, self.tensor_store

    def get_dataset(self,
                    data=None,
                    envs=None,
                    dataset_names=None,
                    dataset_prefix=None,
                    eval=False):
        # TODO: Maybe use eval here
        #if self.fpv:
        data_sources = []
        # If we're running auxiliary objectives, we need to include the data sources for the auxiliary labels
        #if self.use_aux_class_features or self.use_aux_class_on_map or self.use_aux_grounding_features or self.use_aux_grounding_on_map:
        #if self.use_aux_goal_on_map:
        data_sources.append(aup.PROVIDER_LM_POS_DATA)
        data_sources.append(aup.PROVIDER_GOAL_POS)
        #data_sources.append(aup.PROVIDER_LANDMARKS_MENTIONED)
        data_sources.append(aup.PROVIDER_LANG_TEMPLATE)

        #if self.use_rot_noise or self.use_pos_noise:
        #    data_sources.append(aup.PROVIDER_POSE_NOISE)

        return SegmentDataset(data=data,
                              env_list=envs,
                              dataset_names=dataset_names,
                              dataset_prefix=dataset_prefix,
                              aux_provider_names=data_sources,
                              segment_level=True)
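
# --------------------------------------------------------------------------------------------------------------
# Illustrative bi-domain setup sketch (not part of the listing above): two ModelGSMNBiDomain instances, one per
# domain, share their language and map-processing modules via steal_cross_domain_modules(), while each keeps its
# own image-to-feature network. An optimizer built over both_domain_parameters() then sees the shared modules
# only once, along with both domains' image encoders. The run names, instance names and learning rate are
# placeholders for illustration.
import torch

# Assumption: the global parameter file has already been loaded, so get_current_parameters() is populated
model_sim = ModelGSMNBiDomain(run_name="example_run", model_instance_name="sim")
model_real = ModelGSMNBiDomain(run_name="example_run", model_instance_name="real")
model_real.steal_cross_domain_modules(model_sim)

optimizer = torch.optim.Adam(model_real.both_domain_parameters(model_sim), lr=1e-3)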
Example #6
class TrainerRL:
    def __init__(self, params, save_rollouts_to_dataset="", device=None):
        self.iterations_per_epoch = params.get("iterations_per_epoch", 1)
        self.test_iterations_per_epoch = params.get(
            "test_iterations_per_epoch", 1)
        self.num_workers = params.get("num_workers")
        self.num_rollouts_per_iter = params.get("num_rollouts_per_iter")
        self.model_name = params.get("model") or params.get("rl_model")
        self.init_model_file = params.get("model_file")
        self.num_steps = params.get("trajectory_len")
        self.device = device

        self.summary_every_n = params.get("plot_every_n")

        self.roller = SimpleParallelPolicyRoller(
            num_workers=self.num_workers,
            device=self.device,
            policy_name=self.model_name,
            policy_file=self.init_model_file,
            dataset_save_name=save_rollouts_to_dataset)

        self.rollout_sampler = RolloutSampler(self.roller)

        # This loads its own weights from file, based on the model name
        self.full_model, _ = load_model(self.model_name)
        self.full_model = self.full_model.to(self.device)
        self.actor_critic = self.full_model.stage2_action_generation
        # Train in eval mode to disable dropout
        #self.actor_critic.eval()
        self.full_model.stage1_visitation_prediction.eval()
        self.writer = LoggingSummaryWriter(
            log_dir=f"{get_logging_dir()}/runs/{params['run_name']}/ppo")

        self.global_step = 0
        self.stage1_updates = 0

        clip_param = params.get("clip")
        num_mini_batch = params.get("num_mini_batch")
        value_loss_coef = params.get("value_loss_coef")
        lr = params.get("lr")
        eps = params.get("eps")
        max_grad_norm = params.get("max_grad_norm")
        use_clipped_value_loss = params.get("use_clipped_value_loss")

        self.entropy_coef = params.get("entropy_coef")
        self.entropy_schedule_epochs = params.get("entropy_schedule_epochs",
                                                  [])
        self.entropy_schedule_multipliers = params.get(
            "entropy_schedule_multipliers", [])

        self.minibatch_size = params.get("minibatch_size")

        self.use_gae = params.get("use_gae")
        self.gamma = params.get("gamma")
        self.gae_lambda = params.get("gae_lambda")
        self.intrinsic_reward_only = params.get("intrinsic_reward_only")

        self.prof = SimpleProfiler(torch_sync=PROFILE, print=PROFILE)

        print(
            f"PPO trainable parameters: {get_n_trainable_params(self.actor_critic)}"
        )
        print(
            f"PPO actor-critic all parameters: {get_n_params(self.actor_critic)}"
        )

        self.ppo = PPO(self.actor_critic,
                       clip_param=clip_param,
                       ppo_epoch=1,
                       num_mini_batch=num_mini_batch,
                       value_loss_coef=value_loss_coef,
                       entropy_coef=self.entropy_coef,
                       lr=lr,
                       eps=eps,
                       max_grad_norm=max_grad_norm,
                       use_clipped_value_loss=use_clipped_value_loss)

    def set_start_epoch(self, epoch):
        prints_per_epoch = int(self.iterations_per_epoch /
                               self.summary_every_n)
        self.global_step = epoch * prints_per_epoch
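
    # Example (illustrative): with iterations_per_epoch=100 and plot_every_n=10, calling set_start_epoch(3)
    # resumes summary logging at global_step = 3 * (100 // 10) = 30.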

    def save_rollouts(self, rollouts, dataset_name):
        for rollout in rollouts:
            # This saves just a single segment per environment, as opposed to all segments that the oracle saves. Problem?
            if len(rollout) > 0:
                env_id = rollout[0]["env_id"]
                save_dataset(dataset_name, rollout, env_id=env_id, lock=True)

    def reload_stage1(self, module_state_dict):
        print("Reloading stage 1 model in RL trainer")
        self.full_model.stage1_visitation_prediction.load_state_dict(
            module_state_dict)
        print("Reloading stage 1 model in rollout sampler")
        self.rollout_sampler.update_stage1_on_workers(
            self.full_model.stage1_visitation_prediction)
        print("Done reloading stage1")
        self.stage1_updates += 1

    def train_epoch(self, epoch_num, eval=False, envs="train"):

        rewards = []
        returns = []
        value_losses = []
        action_losses = []
        dist_entropies = []
        value_preds = []
        vels = []
        stopprobs = []

        step_rollout_metrics = {}

        # Scale the entropy coefficient according to the schedule: the multiplier
        # of the latest schedule epoch that has already passed is applied
        # (schedule epochs are assumed to be listed in ascending order)
        if len(self.entropy_schedule_epochs) > 0:
            scaled_entropy_coeff = self.entropy_coef
            for e_multiplier, e_epoch in zip(self.entropy_schedule_multipliers,
                                             self.entropy_schedule_epochs):
                if epoch_num > e_epoch:
                    scaled_entropy_coeff = e_multiplier * self.entropy_coef
                else:
                    break
            self.ppo.set_entropy_coef(scaled_entropy_coeff)
        else:
            scaled_entropy_coeff = self.entropy_coef

        self.prof.tick("out")

        # TODO: Make the 100 a parameter
        iterations = self.test_iterations_per_epoch if eval else self.iterations_per_epoch

        for i in range(iterations):
            policy_state = self.full_model.get_policy_state()
            device = policy_state[next(iter(policy_state))].device
            print("TrainerRL: Sampling N Rollouts")
            rollouts = self.rollout_sampler.sample_n_rollouts(
                self.num_rollouts_per_iter,
                policy_state,
                sample=not eval,
                envs=envs)
            #if save_rollouts_to_dataset is not None:
            #    self.save_rollouts(rollouts, save_rollouts_to_dataset)

            self.prof.tick("sample_rollouts")
            print("TrainerRL: Calculating Rollout Metrics")
            i_rollout_metrics = calc_rollout_metrics(rollouts)
            step_rollout_metrics = dictlist_append(step_rollout_metrics,
                                                   i_rollout_metrics)

            assert len(rollouts) > 0

            # Convert our rollouts to the format used by Ilya Kostrikov
            device = next(self.full_model.parameters()).device
            rollout_storage = RolloutStorage.from_rollouts(
                rollouts,
                device=device,
                intrinsic_reward_only=self.intrinsic_reward_only)
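            # No bootstrap value is supplied (next_value=None); returns are
            # computed from the collected rewards, using GAE if self.use_gae is set.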
            next_value = None

            rollout_storage.compute_returns(next_value, self.use_gae,
                                            self.gamma, self.gae_lambda, False)

            self.prof.tick("compute_storage")

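            # Per-batch statistics: mean step reward, mean return, mean predicted
            # value, mean velocity command (first three action dimensions) and
            # mean stop probability (fourth action dimension).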
            reward = rollout_storage.rewards.mean().detach().cpu().item()
            avg_return = (((rollout_storage.returns[1:] * rollout_storage.masks[:-1]).sum() +
                           rollout_storage.returns[0]) /
                          (rollout_storage.masks[:-1].sum() + 1)).cpu().item()
            avg_value = rollout_storage.value_preds.mean().detach().cpu().item()
            avg_vel = rollout_storage.actions[:, 0, 0:3].detach().cpu().numpy().mean(axis=0)
            avg_stopprob = rollout_storage.actions[:, 0, 3].mean().detach().cpu().item()

            print("TrainerRL: PPO Update")
            if not eval:
                value_loss, action_loss, dist_entropy, avg_ratio = self.ppo.update(
                    rollout_storage, self.global_step, self.minibatch_size)
                print(
                    f"Iter: {i}/{iterations}, Value loss: {value_loss}, Action loss: {action_loss}, Entropy: {dist_entropy}, Reward: {reward}"
                )
            else:
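                # No parameter update is performed in evaluation mode;
                # report zeros for the loss statistics instead.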
                value_loss = 0
                action_loss = 0
                dist_entropy = 0
                avg_ratio = 0

            self.prof.tick("ppo_update")
            print("TrainerRL: PPO Update Done")

            returns.append(avg_return)
            rewards.append(reward)
            value_losses.append(value_loss)
            action_losses.append(action_loss)
            dist_entropies.append(dist_entropy)
            value_preds.append(avg_value)
            vels.append(avg_vel)
            stopprobs.append(avg_stopprob)

            if i % self.summary_every_n == self.summary_every_n - 1:
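                # Every summary_every_n iterations, average the statistics over
                # the last window and write them to TensorBoard.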
                avg_reward = np.mean(np.asarray(rewards[-self.summary_every_n:]))
                avg_return = np.mean(np.asarray(returns[-self.summary_every_n:]))
                avg_vel = np.mean(np.asarray(vels[-self.summary_every_n:]), axis=0)

                metrics = {
                    "value_loss": np.mean(np.asarray(value_losses[-self.summary_every_n:])),
                    "action_loss": np.mean(np.asarray(action_losses[-self.summary_every_n:])),
                    "dist_entropy": np.mean(np.asarray(dist_entropies[-self.summary_every_n:])),
                    "avg_value": np.mean(np.asarray(value_preds[-self.summary_every_n:])),
                    "avg_vel_x": avg_vel[0],
                    "avg_yaw_rate": avg_vel[2],
                    "avg_stopprob": np.mean(np.asarray(stopprobs[-self.summary_every_n:])),
                    "ratio": avg_ratio
                }

                # Average the per-rollout metrics accumulated over this summary window
                step_rollout_metrics = dict_map(step_rollout_metrics,
                                                lambda m: np.asarray(m).mean())

                mode = "eval" if eval else "train"

                self.writer.add_scalar(f"ppo_{mode}/reward", avg_reward,
                                       self.global_step)
                self.writer.add_scalar(f"ppo_{mode}/return", avg_return,
                                       self.global_step)
                self.writer.add_scalar(f"ppo_{mode}/stage1_updates",
                                       self.stage1_updates, self.global_step)
                self.writer.add_dict(f"ppo_{mode}/", metrics, self.global_step)
                self.writer.add_dict(f"ppo_{mode}/", step_rollout_metrics,
                                     self.global_step)
                self.writer.add_scalar(f"ppo_{mode}/scaled_entropy_coeff",
                                       scaled_entropy_coeff, self.global_step)
                step_rollout_metrics = {}

                self.global_step += 1

            self.prof.tick("logging")
            print("TrainerRL: Finished Step")

        # TODO: Remove code duplication (this was easier for now)
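        # Epoch-level metrics: averages over all iterations of this epoch,
        # returned to the caller alongside the average reward.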
        avg_reward = np.mean(np.asarray(rewards))
        avg_vel = np.mean(np.asarray(vels), axis=0, keepdims=False)
        metrics = {
            "value_loss": np.mean(np.asarray(value_losses)),
            "action_loss": np.mean(np.asarray(action_losses)),
            "dist_entropy": np.mean(np.asarray(dist_entropies)),
            "avg_value": np.mean(np.asarray(value_preds)),
            "avg_vel_x": avg_vel[0],
            "avg_yaw_rate": avg_vel[2],
            "avg_stopprob": np.mean(np.asarray(stopprobs))
        }
        #pprint(metrics)

        self.prof.tick("logging")
        self.prof.loop()
        self.prof.print_stats(1)

        return avg_reward, metrics