示例#1
0
    def eval_loss(self, batch, node_type):
        """Evaluate the ``node_type`` model on one batch and return its loss.

        The batch tensors consumed by the model are moved to ``self.device``.
        The per-node-type model from ``self.node_models_dict`` is then run
        through its ``eval_loss`` method, and the result is returned as a
        NumPy array.

        Args:
            batch: 11-element sequence ``(first_history_index, x_t, y_t,
                x_st_t, y_st_t, neighbors_data_st, neighbors_edge_value,
                robot_traj_st_t, map, _, _)``. The last two entries are
                ignored here. ``robot_traj_st_t`` may be ``None``. ``map``
                is only moved to the device when it is a tensor; any other
                value is passed through unchanged.
            node_type: Key into ``self.node_models_dict`` selecting the model.

        Returns:
            The model's loss as a NumPy array, detached from the autograd
            graph and copied to the CPU. It is named ``nll`` here, presumably
            a negative log-likelihood; confirm in the model implementation.
        """
        (first_history_index, x_t, y_t, x_st_t, y_st_t, neighbors_data_st,
         neighbors_edge_value, robot_traj_st_t, scene_map, _, _) = batch

        x = x_t.to(self.device)
        y = y_t.to(self.device)
        x_st_t = x_st_t.to(self.device)
        y_st_t = y_st_t.to(self.device)
        if robot_traj_st_t is not None:
            robot_traj_st_t = robot_traj_st_t.to(self.device)
        # isinstance (rather than an exact type comparison) so that Tensor
        # subclasses are moved as well; non-tensor values pass through as-is.
        if isinstance(scene_map, torch.Tensor):
            scene_map = scene_map.to(self.device)

        # Run forward pass. first_history_index and the restored neighbor
        # structures are handed to the model without a device move.
        model = self.node_models_dict[node_type]
        nll = model.eval_loss(
            inputs=x,
            inputs_st=x_st_t,
            first_history_indices=first_history_index,
            labels=y,
            labels_st=y_st_t,
            neighbors=restore(neighbors_data_st),
            neighbors_edge_value=restore(neighbors_edge_value),
            robot=robot_traj_st_t,
            map=scene_map,
            prediction_horizon=self.ph)

        # Detach first so the device-to-host copy is not recorded in the
        # autograd graph; the resulting array is identical.
        return nll.detach().cpu().numpy()
示例#2
0
    def train_loss(self, batch, node_type):
        """Run the ``node_type`` model's training loss on one batch.

        The batch tensors consumed by the model are moved to ``self.device``,
        and the per-node-type model from ``self.node_models_dict`` is run
        through its ``train_loss`` method.

        Args:
            batch: 13-element sequence ``(first_history_index, x_t, y_t,
                x_st_t, y_st_t, neighbors_data_st, neighbors_edge_value,
                robot_traj_st_t, map, x_next_t, neighbors_prev_t,
                neighbors_next_t, safe_t)``.
                - ``robot_traj_st_t`` may be ``None``.
                - ``map`` is only moved to the device when it is a tensor.
                - ``restore(neighbors_prev_t)`` and
                  ``restore(neighbors_next_t)`` must yield objects with a
                  ``.to`` method (presumably tensors; confirm with the
                  dataset).
            node_type: Key into ``self.node_models_dict`` selecting the model.

        Returns:
            The ``(loss, loss_task, loss_nce)`` triple returned by the model's
            ``train_loss``, unmodified (not detached), so the caller can
            backpropagate. Presumably ``loss`` is the total objective combining
            the task and NCE terms; confirm in the model.
        """
        (first_history_index, x_t, y_t, x_st_t, y_st_t, neighbors_data_st,
         neighbors_edge_value, robot_traj_st_t, scene_map, x_next_t,
         neighbors_prev_t, neighbors_next_t, safe_t) = batch

        x = x_t.to(self.device)
        y = y_t.to(self.device)
        x_st_t = x_st_t.to(self.device)
        y_st_t = y_st_t.to(self.device)
        if robot_traj_st_t is not None:
            robot_traj_st_t = robot_traj_st_t.to(self.device)
        # isinstance (rather than an exact type comparison) so that Tensor
        # subclasses are moved as well; non-tensor values pass through as-is.
        if isinstance(scene_map, torch.Tensor):
            scene_map = scene_map.to(self.device)

        x_next = x_next_t.to(self.device)
        neighbors_prev = restore(neighbors_prev_t).to(self.device)
        neighbors_next = restore(neighbors_next_t).to(self.device)

        # Run forward pass.
        # NOTE(review): safe_t is passed as-is, without a move to self.device,
        # unlike the other ground-truth inputs. Confirm that the model handles
        # its device placement.
        model = self.node_models_dict[node_type]
        loss, loss_task, loss_nce = model.train_loss(
            inputs=x,
            inputs_st=x_st_t,
            first_history_indices=first_history_index,
            labels=y,
            labels_st=y_st_t,
            neighbors=restore(neighbors_data_st),
            neighbors_edge_value=restore(neighbors_edge_value),
            robot=robot_traj_st_t,
            map=scene_map,
            prediction_horizon=self.ph,
            primary_next=x_next,
            neighbors_prev=neighbors_prev,
            neighbors_next=neighbors_next,
            safety_gt=safe_t)

        return loss, loss_task, loss_nce