示例#1
0
    def forward(self, input, objects, objects_length):
        """Pool per-object and per-pair relation features for every image.

        Args:
            input: feature-map batch. ``input.size(0)`` is the batch size;
                ``input.size(2)`` / ``input.size(3)`` times
                ``self.downsample_rate`` give the full image height / width.
            objects: boxes of all images concatenated along dim 0, four
                coordinates per box (same layout as the full-image box
                ``[0, 0, W, H]`` built below).
            objects_length: ``objects_length[i]`` is how many rows of
                ``objects`` belong to image ``i``.

        Returns:
            A list with one ``[None, object_embeddings, relation_embeddings]``
            entry per image; object embeddings are viewed as ``(n, -1)`` before
            the FC head and relation embeddings are reshaped to ``(n, n, -1)``,
            where ``n`` is that image's box count.
        """
        context_map = self.context_feature_extract(input)
        relation_map = self.relation_feature_extract(input)

        # Full-resolution image size; the same for every image in the batch.
        full_h = input.size(2) * self.downsample_rate
        full_w = input.size(3) * self.downsample_rate

        results = []
        offset = 0
        for img_idx in range(input.size(0)):
            count = objects_length[img_idx].item()
            box = objects[offset:offset + count]
            offset += count
            n = box.size(0)

            with torch.no_grad():
                # Each RoI row is prefixed by the index of the image it is pooled from.
                batch_ind = box.new_full((n, 1), img_idx)

                # One "whole image" box [0, 0, W, H] per object.
                image_box = box.new_zeros((n, 4))
                image_box[:, 2] = full_w
                image_box[:, 3] = full_h

                # Every ordered (subject, object) pair, flattened to n * n rows.
                sub_id, obj_id = jactorch.meshgrid(
                    torch.arange(n, dtype=torch.int64, device=box.device), dim=0)
                sub_id = sub_id.contiguous().view(-1)
                obj_id = obj_id.contiguous().view(-1)
                sub_box, obj_box = jactorch.meshgrid(box, dim=0)
                sub_box = sub_box.contiguous().view(n * n, 4)
                obj_box = obj_box.contiguous().view(n * n, 4)

                union_box = functional.generate_union_box(sub_box, obj_box)
                rel_batch_ind = box.new_full((union_box.size(0), 1), img_idx)

                box_context_imap = functional.generate_intersection_map(
                    box, image_box, self.pool_size)
                sub_union_imap = functional.generate_intersection_map(
                    sub_box, union_box, self.pool_size)
                obj_union_imap = functional.generate_intersection_map(
                    obj_box, union_box, self.pool_size)

            ctx_x, ctx_y = self.context_roi_pool(
                context_map, torch.cat([batch_ind, image_box], dim=-1)).chunk(2, dim=1)
            roi_objects = self.object_roi_pool(input, torch.cat([batch_ind, box], dim=-1))
            obj_feat = self.object_feature_fuse(
                torch.cat([roi_objects, ctx_x, ctx_y * box_context_imap], dim=1))

            rel_x, rel_y, rel_z = self.relation_roi_pool(
                relation_map, torch.cat([rel_batch_ind, union_box], dim=-1)).chunk(3, dim=1)
            rel_feat = self.relation_feature_fuse(torch.cat([
                obj_feat[sub_id], obj_feat[obj_id],
                rel_x, rel_y * sub_union_imap, rel_z * obj_union_imap,
            ], dim=1))

            obj_embed = self._norm(self.object_feature_fc(obj_feat.view(n, -1)))
            rel_embed = self._norm(
                self.relation_feature_fc(rel_feat.view(n * n, -1)).view(n, n, -1))
            results.append([None, obj_embed, rel_embed])

        return results
示例#2
0
def gen_voronoi(centers, height, width):
    """Assign every pixel of a ``height x width`` grid to its nearest center.

    ``centers[k, 0]`` is compared against the pixel's column (x) and
    ``centers[k, 1]`` against the pixel's row (y).

    Returns:
        ``(dis, assignment)``. ``dis`` has shape ``(height, width, K)`` and holds
        the squared Euclidean distance from each pixel to each of the ``K``
        centers. ``assignment`` has shape ``(height, width)`` and holds the index
        of the closest center (``argmin`` over the centers).
    """
    device = centers.device
    rows = torch.arange(height, device=device)
    cols = torch.arange(width, device=device)

    # Row-major pixel enumeration: pixel p = r * width + c.
    pix_y = rows.unsqueeze(1).expand(height, width).reshape(-1).float()
    pix_x = cols.unsqueeze(0).expand(height, width).reshape(-1).float()

    # (num_pixels, 1) against (K,) broadcasts to (num_pixels, K).
    dy = pix_y.unsqueeze(1) - centers[:, 1]
    dx = pix_x.unsqueeze(1) - centers[:, 0]
    dis = dy ** 2 + dx ** 2

    nearest = dis.argmin(1)
    return dis.view((height, width, -1)), nearest.view((height, width))
    def cross_similarity(self, query, identifier):
        """Compare every pair of query embeddings within one concept space.

        ``query`` is passed through the mapping returned by
        ``self.get_attribute(identifier)``, L2-normalised along the last
        dimension, paired up with ``jactorch.meshgrid(..., dim=-2)`` and scored
        by ``self.similarity2`` (told the inputs are already normalised).
        """
        mapping = self.get_attribute(identifier)
        query = mapping(query)
        # Clamp the norm so an all-zero embedding becomes a zero vector instead
        # of NaN (same epsilon as torch.nn.functional.normalize).
        query = query / query.norm(2, dim=-1, keepdim=True).clamp(min=1e-12)
        q1, q2 = jactorch.meshgrid(query, dim=-2)

        return self.similarity2(q1, q2, identifier, _normalized=True)
示例#4
0
def generate_roi_pool_bins(box, bin_size, c2l=COOR_TO_LEN_CORR):
    """Split every box into a ``bin_size x bin_size`` grid of sub-boxes.

    Args:
        box: tensor whose last dimension holds box coordinates with x at
            indices 0/2 and y at indices 1/3 (read via ``__last``). Any number
            of leading dimensions is accepted; for the common ``(N, 4)`` input
            the result is unchanged from the 2-D-only implementation.
        bin_size: number of bins along each side.
        c2l: correction added to each box extent and subtracted from every
            bin's far edge (presumably coordinate-to-length; see
            ``COOR_TO_LEN_CORR``).

    Returns:
        float32 tensor of shape ``(*box.shape[:-1], bin_size ** 2, 4)`` whose
        last dimension is ``(x1, y1, x2, y2)``; bin order follows
        ``jactorch.meshgrid(y, x, dim=-1)``.
    """
    # TODO(Jiayuan Mao @ 07/20): workaround: line space is not implemented for cuda.
    linspace = torch.linspace(0, 1, bin_size + 1,
                              dtype=box.dtype).to(device=box.device)
    # Prepend singleton axes so the bin edges broadcast against box[..., None].
    for _ in range(box.dim() - 1):
        linspace.unsqueeze_(0)
    x_space = linspace * (__last(box, 2) - __last(box, 0) +
                          c2l).unsqueeze(-1) + __last(box, 0).unsqueeze(-1)
    y_space = linspace * (__last(box, 3) - __last(box, 1) +
                          c2l).unsqueeze(-1) + __last(box, 1).unsqueeze(-1)
    # Slice the bin-edge axis (always last), not axis 1, so batched boxes work.
    x1, x2 = x_space[..., :-1], x_space[..., 1:] - c2l
    y1, y2 = y_space[..., :-1], y_space[..., 1:] - c2l
    y1, x1 = jactorch.meshgrid(y1, x1, dim=-1)
    y2, x2 = jactorch.meshgrid(y2, x2, dim=-1)

    # shape: *leading_dims, bin_size^2, 4
    bins = torch.stack([x1, y1, x2, y2], dim=-1)
    bins = bins.view(*box.shape[:-1], -1, 4)
    return bins.float()
示例#5
0
    def forward(self, input, image, masks=None):
        """Build object and relation embeddings from slot masks instead of boxes.

        Args:
            input: backbone features, ``[batch_size, feature_dim, h_f, w_f]``.
            image: raw images; resized and fed to ``self.monet_mask_extract``.
            masks: optional externally supplied masks. When given they are
                resized, used instead of the predicted masks, and stored on
                ``self.monet_mask_extract.m``.

        Returns:
            One ``[None, object_embeddings, relation_embeddings]`` entry per
            image, shaped ``(slot_num, -1)`` and ``(slot_num, slot_num, -1)``.
        """
        object_features = input  # [batch_size, feature_dim, h_f, w_f]
        context_features = self.context_feature_extract(
            input)  # [batch_size, feature_dim, h_f, w_f]
        relation_features = self.relation_feature_extract(
            input)  # [batch_size, feature_dim // 2 * 3, h_f, w_f]

        masks_monet = self.monet_mask_extract(
            self.image_resize(image))  # [batch_size, slot_num, h_m, w_m]
        if masks is None:
            masks = masks_monet
        else:
            masks = self.true_mask_resize(
                masks.view(input.shape[0] * self.slot_num, 1, self.h_raw_raw,
                           self.w_raw_raw))
            masks = masks.view(input.shape[0], self.slot_num, self.h_m,
                               self.w_m)
            self.monet_mask_extract.m = masks

        if self.loss_type == 'separate':
            # Stop gradients from this head flowing back into the masks.
            masks = masks.detach()
        # Resize masks to the feature-map resolution.
        masks = self.mask_resize(masks.view(-1, 1, self.h_m, self.w_m)).view(
            input.shape[0], -1, self.h_f, self.w_f)

        # Flattened indices of every ordered (subject, object) slot pair.
        sub_id, obj_id = jactorch.meshgrid(
            torch.arange(self.slot_num, dtype=torch.long, device=input.device),
            dim=0)
        sub_id, obj_id = sub_id.contiguous().view(-1), obj_id.contiguous().view(-1)

        # Broadcast each slot mask over the channel axis.
        masked_object_features = object_features.unsqueeze(
            1) * masks.unsqueeze(2)  # [batch_size, slot_num, feature_dim, h_f, w_f]
        masked_context_features = context_features.unsqueeze(
            1) * masks.unsqueeze(2)
        masked_relation_features = relation_features.unsqueeze(1) * (
            masks[:, sub_id] + masks[:, obj_id]).unsqueeze(2)

        x_context, y_context = masked_context_features.chunk(2, dim=2)
        combined_object_features = torch.cat(
            [masked_object_features, x_context, y_context * masks.unsqueeze(2)],
            dim=2)
        combined_object_features = combined_object_features.view(
            -1, self.feature_dim * 2, self.h_f, self.w_f)
        combined_object_features = self.object_feature_fuse(
            combined_object_features)
        combined_object_features = combined_object_features.view(
            input.shape[0], self.slot_num, self.output_dims[1], self.h_f,
            self.w_f)

        x_relation, y_relation, z_relation = masked_relation_features.chunk(
            3, dim=2)
        combined_relation_features = torch.cat([
            combined_object_features[:, sub_id],
            combined_object_features[:, obj_id],
            x_relation,
            y_relation * masks[:, sub_id].unsqueeze(2),
            z_relation * masks[:, obj_id].unsqueeze(2),
        ], dim=2)
        combined_relation_features = combined_relation_features.view(
            -1, self.feature_dim // 2 * 3 + self.output_dims[1] * 2, self.h_f,
            self.w_f)
        combined_relation_features = self.relation_feature_fuse(
            combined_relation_features)

        batch_size, num_slots = masks.shape[0], masks.shape[1]

        combined_object_features = combined_object_features.view(
            batch_size * num_slots, -1)
        combined_object_features = self._norm(
            self.object_feature_fc(combined_object_features))
        combined_object_features = combined_object_features.view(
            batch_size, num_slots, -1)

        # Relation features are projected by the relation head, like the
        # box-based forward paths do. NOTE(review): checkpoints trained while
        # this line used object_feature_fc need retraining.
        combined_relation_features = combined_relation_features.view(
            batch_size * num_slots ** 2, -1)
        combined_relation_features = self._norm(
            self.relation_feature_fc(combined_relation_features))
        combined_relation_features = combined_relation_features.view(
            batch_size, num_slots, num_slots, -1)

        return [
            [None, combined_object_features[i], combined_relation_features[i]]
            for i in range(input.shape[0])
        ]
示例#6
0
    def forward(self, input, feed_dict, mode=0, tar_obj_id=-1):
        """Extract region, collision and box-sequence features for each frame.

        ``mode`` selects where the per-frame boxes are read from:
        0 -> ``feed_dict['tube_info']`` (normal),
        1 or 3 -> ``feed_dict['predictions']`` (future),
        2 -> ``feed_dict['counterfacts'][tar_obj_id]`` (counterfact).

        Batch entry ``i`` of ``input`` is treated as frame ``i``. The per-frame
        ``[None, object_features, relation_features, tube_ids]`` lists are
        merged by ``self.merge_tube_obj_ftr``, whose result is returned.

        Raises:
            ValueError: if ``mode`` is not 0, 1, 2 or 3.
        """
        object_features = input
        context_features = self.context_feature_extract(input)
        relation_features = self.relation_feature_extract(input)

        def parse_boxes_for_frm(feed_dict, frm_idx, mode=0, tar_obj_id=-1):
            """Return ``(boxes, tube_ids)`` for every tube present in frame ``frm_idx``.

            Boxes are stacked along dim 0 on ``input.device`` so they can be
            concatenated with the RoI batch indices and pooled from feature
            maps living on that device (a hard-coded ``.cuda()`` broke CPU runs
            and could pick the wrong GPU).
            """
            if mode == 0:
                tubes = feed_dict['tube_info']
            elif mode == 1 or mode == 3:
                tubes = feed_dict['predictions']
            elif mode == 2:
                tubes = feed_dict['counterfacts'][tar_obj_id]
            else:
                raise ValueError('Unsupported mode: {}'.format(mode))

            frm_id = tubes['frm_list'][frm_idx]
            boxes_list = []
            tube_id_list = []
            for tube_id, tube_info in tubes.items():
                # Only integer keys are tubes; others (e.g. 'frm_list') are metadata.
                if not isinstance(tube_id, int):
                    continue
                assert len(tube_info['frm_name']) == len(tube_info['boxes'])
                if frm_id not in tube_info['frm_name']:
                    continue
                box = tube_info['boxes'][tube_info['frm_name'].index(frm_id)]
                if mode == 0:
                    # Mode-0 boxes are squeezed on dim 0 (no-op unless it has size 1).
                    box = box.squeeze(0)
                boxes_list.append(torch.as_tensor(box, device=input.device))
                tube_id_list.append(tube_id)
            # NOTE(review): torch.stack raises if no tube is present in this frame.
            return torch.stack(boxes_list, 0), tube_id_list

        # Full-resolution image size; identical for every frame.
        image_h = input.size(2) * self.downsample_rate
        image_w = input.size(3) * self.downsample_rate

        outputs = list()
        for i in range(input.size(0)):
            boxes, tube_id_list = parse_boxes_for_frm(feed_dict, i, mode,
                                                      tar_obj_id)
            num_boxes = boxes.size(0)

            with torch.no_grad():
                # Prefix each RoI with the index of the frame it is pooled from.
                batch_ind = i + torch.zeros(
                    num_boxes, 1, dtype=boxes.dtype, device=boxes.device)

                # generate a "full-image" bounding box [0, 0, image_w, image_h]
                image_box = boxes.new_zeros((num_boxes, 4))
                image_box[:, 2] = image_w
                image_box[:, 3] = image_h

                # meshgrid to obtain the subject and object bounding boxes
                sub_id, obj_id = jactorch.meshgrid(
                    torch.arange(num_boxes, dtype=torch.int64,
                                 device=boxes.device),
                    dim=0)
                sub_id = sub_id.contiguous().view(-1)
                obj_id = obj_id.contiguous().view(-1)
                sub_box, obj_box = jactorch.meshgrid(boxes, dim=0)
                sub_box = sub_box.contiguous().view(num_boxes ** 2, 4)
                obj_box = obj_box.contiguous().view(num_boxes ** 2, 4)

                # union box
                union_box = functional.generate_union_box(sub_box, obj_box)
                rel_batch_ind = i + torch.zeros(
                    union_box.size(0), 1, dtype=boxes.dtype,
                    device=boxes.device)

                # intersection maps
                box_context_imap = functional.generate_intersection_map(
                    boxes, image_box, self.pool_size)
                sub_union_imap = functional.generate_intersection_map(
                    sub_box, union_box, self.pool_size)
                obj_union_imap = functional.generate_intersection_map(
                    obj_box, union_box, self.pool_size)

            this_context_features = self.context_roi_pool(
                context_features, torch.cat([batch_ind, image_box], dim=-1))
            x, y = this_context_features.chunk(2, dim=1)
            this_object_features = self.object_feature_fuse(torch.cat([
                self.object_roi_pool(object_features,
                                     torch.cat([batch_ind, boxes], dim=-1)),
                x,
                y * box_context_imap,
            ], dim=1))

            this_relation_features = self.relation_roi_pool(
                relation_features, torch.cat([rel_batch_ind, union_box], dim=-1))
            x, y, z = this_relation_features.chunk(3, dim=1)
            this_relation_features = self.relation_feature_fuse(torch.cat([
                this_object_features[sub_id], this_object_features[obj_id],
                x, y * sub_union_imap, z * obj_union_imap,
            ], dim=1))

            if DEBUG:
                # NOTE(review): unlike the normal path, tube_id_list is not
                # appended here; confirm merge_tube_obj_ftr tolerates that.
                outputs.append(
                    [None, this_object_features, this_relation_features])
            else:
                outputs.append([
                    None,
                    self._norm(
                        self.object_feature_fc(
                            this_object_features.view(num_boxes, -1))),
                    self._norm(
                        self.relation_feature_fc(
                            this_relation_features.view(
                                num_boxes * num_boxes, -1)).view(
                                    num_boxes, num_boxes, -1)),
                    tube_id_list,
                ])

        return self.merge_tube_obj_ftr(outputs, feed_dict, mode, tar_obj_id)
    def forward(self, input, objects, objects_length):
        """Compute object and object-pair representations for every image.

        The configuration is selected by ``self.object_supervision`` and
        ``self.concatenative_pair_representation``:

        * supervised + concatenative: RoI-pooled object features from the
          boxes; pair representations from
          ``self.objects_to_pair_representations``.
        * supervised + non-concatenative: RoI-pooled object features plus
          union-box relation features.
        * unsupervised + concatenative: the first ``objects_length[i]`` rows of
          ``self.query`` attend over the coordinate-augmented feature map;
          ``objects`` is not used.

        Args:
            input: feature-map batch.
            objects: boxes of all images concatenated along dim 0.
            objects_length: number of objects per image.

        Returns:
            One ``[None, object_representations, pair_representations]`` entry
            per image (raw fused features instead when ``DEBUG`` is set).

        Raises:
            NotImplementedError: if neither object supervision nor
                concatenative pair representations are enabled.
        """
        object_features = input

        def full_image_box(box):
            """Return one ``[0, 0, W, H]`` full-image box per row of ``box``."""
            image_box = box.new_zeros((box.size(0), 4))
            image_box[:, 2] = input.size(3) * self.downsample_rate
            image_box[:, 3] = input.size(2) * self.downsample_rate
            return image_box

        if self.object_supervision and self.concatenative_pair_representation:
            context_features = self.context_feature_extract(input)
            outputs = list()
            objects_index = 0
            for i in range(input.size(0)):
                # box holds the object boundaries for image i.
                num_objects = objects_length[i].item()
                box = objects[objects_index:objects_index + num_objects]
                objects_index += num_objects

                with torch.no_grad():
                    batch_ind = i + torch.zeros(
                        box.size(0), 1, dtype=box.dtype, device=box.device)
                    image_box = full_image_box(box)
                    # intersection maps
                    box_context_imap = functional.generate_intersection_map(
                        box, image_box, self.pool_size)

                this_context_features = self.context_roi_pool(
                    context_features, torch.cat([batch_ind, image_box], dim=-1))
                x, y = this_context_features.chunk(2, dim=1)
                this_object_features = self.object_feature_fuse(torch.cat([
                    self.object_roi_pool(object_features,
                                         torch.cat([batch_ind, box], dim=-1)),
                    x,
                    y * box_context_imap,
                ], dim=1))

                object_representations = self._norm(
                    self.object_feature_fc(
                        this_object_features.view(box.size(0), -1)))
                object_pair_representations = self.objects_to_pair_representations(
                    object_representations)

                if DEBUG:
                    outputs.append([None, this_object_features, None])
                else:
                    outputs.append([
                        None, object_representations,
                        object_pair_representations
                    ])

        elif self.object_supervision and not self.concatenative_pair_representation:
            context_features = self.context_feature_extract(input)
            relation_features = self.relation_feature_extract(input)

            outputs = list()
            objects_index = 0
            for i in range(input.size(0)):
                num_objects = objects_length[i].item()
                box = objects[objects_index:objects_index + num_objects]
                objects_index += num_objects
                n = box.size(0)

                with torch.no_grad():
                    batch_ind = i + torch.zeros(
                        n, 1, dtype=box.dtype, device=box.device)
                    image_box = full_image_box(box)

                    # meshgrid to obtain the subject and object bounding boxes
                    sub_id, obj_id = jactorch.meshgrid(
                        torch.arange(n, dtype=torch.int64, device=box.device),
                        dim=0)
                    sub_id = sub_id.contiguous().view(-1)
                    obj_id = obj_id.contiguous().view(-1)
                    sub_box, obj_box = jactorch.meshgrid(box, dim=0)
                    sub_box = sub_box.contiguous().view(n ** 2, 4)
                    obj_box = obj_box.contiguous().view(n ** 2, 4)

                    # union box
                    union_box = functional.generate_union_box(sub_box, obj_box)
                    rel_batch_ind = i + torch.zeros(
                        union_box.size(0), 1, dtype=box.dtype,
                        device=box.device)

                    # intersection maps
                    box_context_imap = functional.generate_intersection_map(
                        box, image_box, self.pool_size)
                    sub_union_imap = functional.generate_intersection_map(
                        sub_box, union_box, self.pool_size)
                    obj_union_imap = functional.generate_intersection_map(
                        obj_box, union_box, self.pool_size)

                this_context_features = self.context_roi_pool(
                    context_features, torch.cat([batch_ind, image_box], dim=-1))
                x, y = this_context_features.chunk(2, dim=1)
                this_object_features = self.object_feature_fuse(torch.cat([
                    self.object_roi_pool(object_features,
                                         torch.cat([batch_ind, box], dim=-1)),
                    x,
                    y * box_context_imap,
                ], dim=1))

                this_relation_features = self.relation_roi_pool(
                    relation_features,
                    torch.cat([rel_batch_ind, union_box], dim=-1))
                x, y, z = this_relation_features.chunk(3, dim=1)
                this_relation_features = self.relation_feature_fuse(torch.cat([
                    this_object_features[sub_id],
                    this_object_features[obj_id],
                    x, y * sub_union_imap, z * obj_union_imap,
                ], dim=1))

                # (A leftover debug print that re-ran the relation head every
                # pass was removed here.)
                if DEBUG:
                    outputs.append(
                        [None, this_object_features, this_relation_features])
                else:
                    outputs.append([
                        None,
                        self._norm(
                            self.object_feature_fc(
                                this_object_features.view(n, -1))),
                        self._norm(
                            self.relation_feature_fc(
                                this_relation_features.view(n * n, -1)).view(
                                    n, n, -1)),
                    ])

        elif not self.object_supervision and self.concatenative_pair_representation:
            outputs = list()
            # object_features has shape batch_size x 256 x 16 x 24
            obj_coord_map = coord_map(
                (object_features.size(2), object_features.size(3)),
                self.query.device)

            for i in range(input.size(0)):
                # Append the coordinate channels to image i's features and add
                # a leading batch axis for object_coord_fuse.
                scene_object_coords = torch.cat(
                    (object_features[i], obj_coord_map), dim=0).unsqueeze(0)
                fused_object_coords = self.object_coord_fuse(
                    scene_object_coords).squeeze(0)  # dim=256 x Z x Y

                num_objects = objects_length[i].item()
                relevant_queries = self.query[
                    0:num_objects, :]  # num_objects x feature_dim

                attention_map = self.temperature * torch.einsum(
                    "ij,jkl -> ikl", relevant_queries,
                    fused_object_coords)  # dim=num_objects x Z x Y
                # Normalise each query's map over all spatial positions.
                attention_map = torch.softmax(
                    attention_map.view(num_objects, -1),
                    dim=1).view_as(attention_map)
                object_values = torch.einsum(
                    "ijk,ljk -> il", attention_map,
                    fused_object_coords)  # dim=num_objects x 256

                object_representations = self._norm(
                    self.object_features_layer(object_values))
                object_pair_representations = self.objects_to_pair_representations(
                    object_representations)

                outputs.append([
                    None, object_representations, object_pair_representations
                ])

        else:
            raise NotImplementedError(
                'forward requires object_supervision or '
                'concatenative_pair_representation to be enabled')

        return outputs