예제 #1
0
    def eval_on_batch(self, attr, traj, config):
        """Run one batch through the model and return ``(pred_dict, loss)``.

        Eval mode: the forward pass yields only the whole-trajectory output,
        and the returned loss is the entire-path loss.

        Training mode: the forward pass also yields ``(local_out,
        local_length)``. The returned loss blends the two terms as
        ``(1 - self.alpha) * entire_loss + self.alpha * local_loss``
        (eqn 8 of the paper, per the original note).

        Args:
            attr: batch attributes; ``attr['time']`` is handed to the
                entire-path estimator.
            traj: batch trajectory dict; ``traj['time_gap']`` is read in
                training mode to build the local-path labels.
            config: mapping providing 'time_mean', 'time_std',
                'time_gap_mean' and 'time_gap_std'.
        """
        if not self.training:
            entire_out = self(attr, traj, config)
        else:
            entire_out, (local_out, local_length) = self(attr, traj, config)

        pred_dict, entire_loss = self.entire_estimate.eval_on_batch(
            entire_out,
            attr['time'],
            config['time_mean'],
            config['time_std'],
        )

        # Outside training there is no local-path term.
        if not self.training:
            return pred_dict, entire_loss

        # Local-path label statistics: per-gap stats scaled by (kernel_size - 1).
        span = self.kernel_size - 1
        gap_mean = span * config['time_gap_mean']
        gap_std = span * config['time_gap_std']

        # Ground-truth time for each local path.
        local_label = utils.get_local_seq(
            traj['time_gap'], self.kernel_size, gap_mean, gap_std)
        local_loss = self.local_estimate.eval_on_batch(
            local_out, local_length, local_label, gap_mean, gap_std)

        # Weighted blend of both objectives (eqn 8 of the paper).
        weight = self.alpha
        return pred_dict, (1 - weight) * entire_loss + weight * local_loss
예제 #2
0
    def forward(self, traj, config):
        """Convolve per-point location features and tag each window with its distance.

        Steps:
        1. Longitude, latitude and an embedding of the per-point state are
           concatenated.
        2. The result is projected with ``self.process_coords`` + tanh.
        3. It is convolved with ``self.conv`` + ELU.
        4. The normalized distance gap of each local path is appended as one
           extra feature.

        Args:
            traj: dict of batched tensors; reads 'lngs', 'lats', 'states'
                and 'dist_gap'.
            config: mapping providing 'dist_gap_mean' and 'dist_gap_std'.

        Returns:
            Tensor annotated (by the original author) as
            [batch, max len - kernel_size + 1, num_filter + 1].
        """
        # Add a trailing feature axis to the raw coordinates.
        point_feats = [torch.unsqueeze(traj[key], dim=2) for key in ('lngs', 'lats')]
        # NOTE(original): the state embedding was marked optional / removable.
        point_feats.append(self.state_em(traj['states'].long()))
        points = torch.cat(point_feats, dim=2)

        # Project each point (nominally to 16 dims, per the original note),
        # then swap the last two axes so features come before positions.
        hidden = torch.tanh(self.process_coords(points)).permute(0, 2, 1)

        # Convolve and swap back:
        # [batch, max len - kernel_size + 1, num_filter] per original note.
        windows = F.elu(self.conv(hidden)).permute(0, 2, 1)

        # Normalized distance of each local path, as one extra feature column.
        dist_gap = utils.get_local_seq(
            traj['dist_gap'],
            self.kernel_size,
            config['dist_gap_mean'],
            config['dist_gap_std'],
        )
        return torch.cat((windows, dist_gap.unsqueeze(2)), dim=2)
예제 #3
0
    def forward(self, traj, config):
        """Convolve per-point location/road features and append local distances.

        Steps:
        1. Longitude, latitude and an embedding of the per-point road id are
           concatenated.
        2. The result is projected with ``self.process_coords`` + tanh.
        3. It is convolved with ``self.conv`` + ELU.
        4. The normalized distance gap of each local path is appended as one
           extra feature.

        Args:
            traj: dict of batched tensors; reads 'lngs', 'lats', 'roads'
                and 'dist_gap'.
            config: mapping providing 'dist_gap_mean' and 'dist_gap_std'.

        Returns:
            The convolution output with one extra trailing feature (the local
            distance) concatenated on dim 2.
        """
        lngs = torch.unsqueeze(traj['lngs'], dim=2)
        lats = torch.unsqueeze(traj['lats'], dim=2)
        # Road ids are cast to integer indices before the embedding lookup.
        roads = self.road_em(traj['roads'].long())

        locs = torch.cat((lngs, lats, roads), dim=2)

        # Map each point into a (nominally 16-dim) vector.
        # torch.tanh replaces the deprecated F.tanh; the math is identical.
        locs = torch.tanh(self.process_coords(locs))
        # Swap the last two axes so features precede positions for self.conv.
        locs = locs.permute(0, 2, 1)

        conv_locs = F.elu(self.conv(locs)).permute(0, 2, 1)

        # Normalized distance for each local path (windows of self.kernel_size).
        local_dist = utils.get_local_seq(traj['dist_gap'], self.kernel_size,
                                         config['dist_gap_mean'],
                                         config['dist_gap_std'])
        local_dist = torch.unsqueeze(local_dist, dim=2)

        return torch.cat((conv_locs, local_dist), dim=2)
예제 #4
0
    def eval_on_batch(self, attr, traj, config):
        """Run one batch through the model and return ``(pred_dict, loss)``.

        Eval mode: the returned loss is the entire-path loss only.

        Training mode: the forward pass also yields ``(local_out,
        local_length)``. The returned loss is
        ``(1 - self.alpha) * entire_loss + self.alpha * local_loss``.

        Args:
            attr: batch attributes; ``attr['time']`` is handed to the
                entire-path estimator.
            traj: batch trajectory dict; ``traj['time_gap']`` is read in
                training mode to build the local-path labels.
            config: mapping providing 'time_mean', 'time_std',
                'time_gap_mean' and 'time_gap_std'.
        """
        if self.training:
            entire_out, (local_out, local_length) = self(attr, traj)
        else:
            entire_out = self(attr, traj)
        pred_dict, entire_loss = self.entire_estimate.eval_on_batch(
            entire_out, attr['time'], config['time_mean'],
            config['time_std'])  # entire_loss is scalar

        if not self.training:
            return pred_dict, entire_loss

        # Local-path label statistics: per-gap stats scaled by (kernel_size - 1).
        mean = (self.kernel_size - 1) * config['time_gap_mean']
        std = (self.kernel_size - 1) * config['time_gap_std']

        local_label = utils.get_local_seq(traj['time_gap'],
                                          self.kernel_size, mean, std)

        # The original looked up the misspelled attribute `loacl_estimate`,
        # which raises AttributeError when the module defines `local_estimate`.
        # Prefer the correct name and fall back to the legacy spelling so
        # either __init__ variant keeps working.
        local_estimate = getattr(self, 'local_estimate', None)
        if local_estimate is None:
            local_estimate = self.loacl_estimate
        local_loss = local_estimate.eval_on_batch(
            local_out, local_length, local_label, mean, std)

        return pred_dict, (
            1 - self.alpha) * entire_loss + self.alpha * local_loss
예제 #5
0
    def forward(self, traj, config=None):
        """Convolve per-point location/state features and append local distances.

        Args:
            traj: dict of batched tensors; reads 'lngs', 'lats', 'states'
                and 'dist_gap'.
            config: optional mapping with numeric 'dist_gap_mean' and
                'dist_gap_std' used to normalize local-path distances.
                New and keyword-compatible: existing ``forward(traj)`` calls
                keep the original behavior (see note in the body).

        Returns:
            Tensor annotated (by the original author) as
            [bs, conv(H), num_filters + 1].
        """
        lngs = torch.unsqueeze(Variable(traj['lngs']),
                               dim=2)  # [bs,seq_length,1]
        lats = torch.unsqueeze(Variable(traj['lats']),
                               dim=2)  # [bs,seq_length,1]
        states = self.state_em(Variable(
            traj['states']).long())  # [bs,seq_length,2]
        locs = torch.cat((lngs, lats, states), dim=2)  # [bs,seq_length,4]

        # NOTE(review): "process_cooeds" looks like a misspelling of
        # "process_coords"; it must match the attribute created in __init__
        # (not visible here), so it is left unchanged.
        # torch.tanh replaces the deprecated F.tanh; the math is identical.
        locs = torch.tanh(self.process_cooeds(locs))  # [bs,seq_length,16]
        locs = locs.permute(0, 2, 1)  # [bs,16,seq_length]

        conv_locs = F.elu(self.conv(locs)).permute(
            0, 2, 1)  # [bs,conv(H),num_filters]

        # The original passed the *key names* 'dist_gap_mean'/'dist_gap_std'
        # as the normalization mean/std. get_local_seq is presumably expected
        # to receive numeric statistics (other callers pass values such as
        # config['dist_gap_mean']), so use the numbers whenever config is given.
        if config is not None:
            dist_mean = config['dist_gap_mean']
            dist_std = config['dist_gap_std']
        else:
            # Legacy path, unchanged for callers that do not pass config.
            # NOTE(review): only correct if utils.get_local_seq resolves these
            # keys itself — confirm, or pass config.
            dist_mean, dist_std = 'dist_gap_mean', 'dist_gap_std'
        local_dist = utils.get_local_seq(Variable(traj['dist_gap']),
                                         self.kernel_size, dist_mean,
                                         dist_std)  # [bs,conv(H)]
        local_dist = torch.unsqueeze(local_dist, dim=2)  # [bs,conv(H),1]

        conv_locs = torch.cat((conv_locs, local_dist),
                              dim=2)  # [bs,conv(H),num_filters+1]

        return conv_locs