def __init__(self, data, length_pred, skip_pred, device):

        assert length_pred >= 1  # TODO: multiple

        num_human = data[0].shape[1]
        state_dim = data[0].shape[2]

        self.transform = MultiAgentTransform(num_human)

        obsv = []
        target = []
        index = []

        for i, episode in enumerate(data):

            # keep only frames in which every agent is moving; the
            # near-stationary starting and ending frames are unpredictable
            speed = episode[:, :, -2:].norm(dim=2)
            valid = episode[(speed > 1e-4).all(dim=1)]

            length_valid = valid.shape[0]

            # skip episodes too short to provide a full prediction window
            if length_valid > length_pred * skip_pred:

                human_state = self.transform.transform_frame(
                    valid)[:length_valid - length_pred * skip_pred]

                # collect the future (x, y) positions at each prediction step;
                # index into `valid` (not the raw `episode`), since the slice
                # bounds are computed w.r.t. the filtered sequence
                upcome = []
                for k in range(length_pred):
                    propagate = valid[(k + 1) * skip_pred:length_valid -
                                      (length_pred - k - 1) * skip_pred, :, :2]
                    upcome.append(propagate)
                upcome = torch.cat(upcome, dim=2)

                obsv.append(
                    human_state.view(
                        (length_valid - length_pred * skip_pred) * num_human,
                        -1))
                target.append(
                    upcome.view(
                        (length_valid - length_pred * skip_pred) * num_human,
                        -1))
                # global agent ids, offset per episode (arange over num_human
                # rather than a hard-coded 5)
                index.append(
                    torch.arange(num_human).repeat(length_valid -
                                                   length_pred * skip_pred) +
                    num_human * i)

        self.obsv = torch.cat(obsv).to(device)
        self.target = torch.cat(target).to(device)
        self.index = torch.cat(index).to(device)
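
    # Usage sketch: each row of self.obsv is one agent's transformed
    # observation, the matching row of self.target holds its flattened future
    # (x, y) positions, and self.index groups samples by episode and agent.
    # The class header is not shown in this excerpt, so the Dataset-style
    # accessors below are an assumed completion.
    def __len__(self):
        return self.obsv.shape[0]

    def __getitem__(self, idx):
        return self.obsv[idx], self.target[idx], self.index[idx]
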
    def __init__(self,
                 num_human,
                 embedding_dim=64,
                 hidden_dim=64,
                 local_dim=32):
        super().__init__()
        self.num_human = num_human
        self.transform = MultiAgentTransform(num_human)

        self.robot_encoder = nn.Sequential(nn.Linear(4, local_dim),
                                           nn.ReLU(inplace=True),
                                           nn.Linear(local_dim, local_dim),
                                           nn.ReLU(inplace=True))

        self.human_encoder = nn.Sequential(
            nn.Linear(4 * self.num_human, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True))

        self.human_head = nn.Sequential(nn.Linear(hidden_dim, local_dim),
                                        nn.ReLU(inplace=True))

        self.joint_embedding = nn.Sequential(
            nn.Linear(local_dim * 2, embedding_dim), nn.ReLU(inplace=True))

        self.pairwise = nn.Sequential(nn.Linear(embedding_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim))

        self.attention = nn.Sequential(nn.Linear(embedding_dim, hidden_dim),
                                       nn.ReLU(inplace=True),
                                       nn.Linear(hidden_dim, 1))

        self.task_encoder = nn.Sequential(nn.Linear(4, hidden_dim),
                                          nn.ReLU(inplace=True),
                                          nn.Linear(hidden_dim, hidden_dim),
                                          nn.ReLU(inplace=True))

        self.joint_encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(inplace=True))

        self.planner = nn.Linear(hidden_dim, 2)
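
        # Shape sketch (inferred from ExtendedNetwork.forward below, which
        # uses the same layers): for a batch of size B, the robot embedding is
        # (B, local_dim), the per-human embeddings are (B, num_human,
        # local_dim), the attention-pooled crowd feature is (B, hidden_dim),
        # and the planner head emits a (B, 2) action.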
    def __init__(self,
                 num_human,
                 embedding_dim=64,
                 hidden_dim=64,
                 local_dim=32,
                 forecast_hidden_dim=32,
                 forecast_emb_dim=16,
                 max_obs=5):
        super().__init__()
        self.num_human = num_human
        self.transform = MultiAgentTransform(num_human)
        self.max_obs = max_obs

        self.robot_encoder = nn.Sequential(nn.Linear(4, local_dim),
                                           nn.ReLU(inplace=True),
                                           nn.Linear(local_dim, local_dim),
                                           nn.ReLU(inplace=True))

        self.human_encoder = nn.Sequential(
            nn.Linear(4 * self.num_human, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True))

        self.human_head = nn.Sequential(nn.Linear(hidden_dim, local_dim),
                                        nn.ReLU(inplace=True))

        self.joint_embedding = nn.Sequential(
            nn.Linear(local_dim * 2, embedding_dim), nn.ReLU(inplace=True))

        self.pairwise = nn.Sequential(nn.Linear(embedding_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim))

        self.attention = nn.Sequential(nn.Linear(embedding_dim, hidden_dim),
                                       nn.ReLU(inplace=True),
                                       nn.Linear(hidden_dim, 1))

        self.task_encoder = nn.Sequential(nn.Linear(4, hidden_dim),
                                          nn.ReLU(inplace=True),
                                          nn.Linear(hidden_dim, hidden_dim),
                                          nn.ReLU(inplace=True))

        self.joint_encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(inplace=True))

        self.trajectory_fext = TrajFeatureExtractor(self.transform,
                                                    self.human_encoder,
                                                    self.human_head,
                                                    n_heads=4,
                                                    max_obs=self.max_obs)
        self.planner = nn.Linear(hidden_dim, 2)
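
        # Note: unlike the planner variant above, this one routes human
        # observations through TrajFeatureExtractor (4 attention heads and up
        # to max_obs past frames per human, per the constructor arguments);
        # its forward pass is not shown in this excerpt.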
    def __init__(self, data, max_pred, device):

        assert max_pred >= 1

        num_human = data[0].shape[1]
        state_dim = data[0].shape[2]

        self.transform = MultiAgentTransform(num_human)

        obsv = []
        target = []

        for episode in data:

            # keep only frames in which every agent is moving; the
            # near-stationary starting and ending frames are unpredictable
            speed = episode[:, :, -2:].norm(dim=2)
            valid = episode[(speed > 1e-4).all(dim=1)]

            length_valid = valid.shape[0]

            # skip episodes too short to provide max_pred future frames
            if length_valid <= max_pred:
                continue

            human_state = self.transform.transform_frame(
                valid)[:length_valid - max_pred]

            # gather the next max_pred frames for every observation frame
            frames = torch.empty(
                (length_valid - max_pred, num_human, max_pred, state_dim))
            for t in range(max_pred):
                frames[:, :, t, :] = valid[t + 1:length_valid - max_pred +
                                           t + 1]

            obsv.append(
                human_state.view((length_valid - max_pred) * num_human,
                                 num_human * state_dim))
            target.append(
                frames.view((length_valid - max_pred) * num_human, max_pred,
                            state_dim))

        self.obsv = torch.cat(obsv).to(device)
        self.target = torch.cat(target).to(device)
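
        # Shape sketch: self.obsv is (N, num_human * state_dim) and
        # self.target is (N, max_pred, state_dim), with N summing
        # (length_valid - max_pred) * num_human over episodes; an assumed
        # __getitem__ would return (self.obsv[idx], self.target[idx]).
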
class ExtendedNetwork(nn.Module):
    def __init__(self,
                 num_human,
                 embedding_dim=64,
                 hidden_dim=64,
                 local_dim=32):
        super().__init__()
        self.num_human = num_human
        self.transform = MultiAgentTransform(num_human)

        # encodes the robot's local state (px, py, vx, vy)
        self.robot_encoder = nn.Sequential(nn.Linear(4, local_dim),
                                           nn.ReLU(inplace=True),
                                           nn.Linear(local_dim, local_dim),
                                           nn.ReLU(inplace=True))

        # encodes the per-human feature of width 4 * num_human produced by
        # MultiAgentTransform.transform_frame
        self.human_encoder = nn.Sequential(
            nn.Linear(4 * self.num_human, hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True))

        self.human_head = nn.Sequential(nn.Linear(hidden_dim, local_dim),
                                        nn.ReLU(inplace=True))

        # fuses the robot embedding with each human embedding
        self.joint_embedding = nn.Sequential(
            nn.Linear(local_dim * 2, embedding_dim), nn.ReLU(inplace=True))

        # per-human pairwise features, pooled by the attention scores below
        self.pairwise = nn.Sequential(nn.Linear(embedding_dim, hidden_dim),
                                      nn.ReLU(inplace=True),
                                      nn.Linear(hidden_dim, hidden_dim))

        self.attention = nn.Sequential(nn.Linear(embedding_dim, hidden_dim),
                                       nn.ReLU(inplace=True),
                                       nn.Linear(hidden_dim, 1))

        # encodes the goal-centric robot state (goal - position, velocity)
        self.task_encoder = nn.Sequential(nn.Linear(4, hidden_dim),
                                          nn.ReLU(inplace=True),
                                          nn.Linear(hidden_dim, hidden_dim),
                                          nn.ReLU(inplace=True))

        self.joint_encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(inplace=True))

        # maps the joint feature to a 2-D action
        self.planner = nn.Linear(hidden_dim, 2)

    def forward(self, robot_state, crowd_obsv, aux_task=''):

        # promote unbatched inputs to a batch of one
        if len(robot_state.shape) < 2:
            robot_state = robot_state.unsqueeze(0)
            crowd_obsv = crowd_obsv.unsqueeze(0)

        # preprocessing
        emb_robot = self.robot_encoder(robot_state[:, :4])

        human_state = self.transform.transform_frame(crowd_obsv)
        feat_human = self.human_encoder(human_state)
        emb_human = self.human_head(feat_human)

        emb_concat = torch.cat(
            [emb_robot.unsqueeze(1).repeat(1, self.num_human, 1), emb_human],
            dim=2)

        # embedding
        emb_pairwise = self.joint_embedding(emb_concat)

        # pairwise
        feat_pairwise = self.pairwise(emb_pairwise)

        # attention
        logit_pairwise = self.attention(emb_pairwise)
        score_pairwise = nn.functional.softmax(logit_pairwise, dim=1)

        # crowd
        feat_crowd = torch.sum(feat_pairwise * score_pairwise, dim=1)

        # planning: re-express the robot state as (goal - position, velocity)
        reparam_robot_state = torch.cat(
            [robot_state[:, -2:] - robot_state[:, :2], robot_state[:, 2:4]],
            dim=1)
        feat_task = self.task_encoder(reparam_robot_state)

        feat_joint = self.joint_encoder(
            torch.cat([feat_task, feat_crowd], dim=1))
        action = self.planner(feat_joint)
        # for the contrastive auxiliary task, expose the joint feature
        if aux_task == 'contrastive':
            return action, feat_joint
        return action, emb_human
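

# Usage sketch, assuming MultiAgentTransform(num_human).transform_frame maps
# a (B, num_human, 4) crowd tensor to (B, num_human, 4 * num_human) and that
# robot_state is (B, 6) laid out as [px, py, vx, vy, gx, gy]; both shapes are
# inferred from the encoder input sizes and the slicing in forward().
if __name__ == '__main__':
    net = ExtendedNetwork(num_human=5)
    robot_state = torch.rand(8, 6)    # batch of 8 robot states
    crowd_obsv = torch.rand(8, 5, 4)  # 5 humans with 4-D states each
    action, emb_human = net(robot_state, crowd_obsv)
    print(action.shape)     # torch.Size([8, 2])
    print(emb_human.shape)  # torch.Size([8, 5, 32]) with local_dim=32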