Example #1
    def __getitem__(self, idx):
        #pdb.set_trace()
        n_his = self.args.n_his
        frame_offset = self.args.frame_offset
        idx_video, idx_frame = self.valid_idx[idx][0], self.valid_idx[idx][1]

        objs = []
        attrs = []
        hws = []
        for i in range(idx_frame - n_his * frame_offset,
                       idx_frame + frame_offset + 1, frame_offset):

            frame = self.metadata[idx_video]['frames'][i]
            #frame_filename = frame['frame_filename']
            frame_filename = os.path.join(
                'video_' + str(idx_video).zfill(5),
                str(frame['frame_index'] + 1) + '.png')
            #pdb.set_trace()
            objects = frame['objects']
            n_objects = len(objects)
            sub_id = idx_video // 1000

            full_img_dir = os.path.join(
                self.data_dir,
                'image_' + str(sub_id * 1000).zfill(5) + '-' + str(
                    (sub_id + 1) * 1000).zfill(5))
            img = self.loader(os.path.join(full_img_dir, frame_filename))
            img = np.array(img)[:, :, ::-1].copy()
            # np.float was removed from NumPy; the builtin float keeps the same float64 dtype
            img = cv2.resize(img, (self.W, self.H),
                             interpolation=cv2.INTER_AREA).astype(float) / 255.
            ### prepare object inputs
            object_inputs = []
            for j in range(n_objects):
                material = objects[j]['material']
                shape = objects[j]['shape']

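                # encode the static attributes (material, shape) only once, at the first history frame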
                if i == idx_frame - n_his * frame_offset:
                    attrs.append(
                        encode_attr(material, shape, self.bbox_size,
                                    self.args.attr_dim))

                mask_raw = decode(objects[j]['mask'])
                mask = cv2.resize(mask_raw, (self.W, self.H),
                                  interpolation=cv2.INTER_NEAREST)
                # cv2.imshow("mask", mask * 255)
                # cv2.waitKey(0)
                #bbox, pos, box_hw = convert_mask_to_bbox_hw(mask_raw, self.H, self.W, self.bbox_size, objects[j]['mask'])
                bbox, pos = convert_mask_to_bbox(mask_raw, self.H, self.W,
                                                 self.bbox_size)
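                # center and scale the object position using half the frame size (H/2, W/2) as mean and std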
                pos_mean = torch.FloatTensor(
                    np.array([self.H / 2., self.W / 2.]))
                pos_mean = pos_mean.unsqueeze(1).unsqueeze(1)
                pos_std = pos_mean
                pos = normalize(pos, pos_mean, pos_std)
                mask_crop = normalize(crop(mask, bbox, self.H, self.W), 0.5,
                                      1).unsqueeze(0)
                img_crop = normalize(crop(img, bbox, self.H, self.W), 0.5,
                                     0.5).permute(2, 0, 1)

                if self.args.box_only_flag:
                    bbx_xyxy, ret, crop_box, crop_box_v2 = decode_mask_to_box(
                        objects[j]['mask'], [self.bbox_size, self.bbox_size],
                        self.H, self.W)
                    ret_mean = torch.FloatTensor(
                        np.array([1 / 2., 1 / 2., 1 / 2., 1 / 2.]))
                    ret_mean = ret_mean.unsqueeze(1).unsqueeze(1)
                    ret_std = ret_mean
                    ret = normalize(ret, ret_mean, ret_std)
                    pos = ret[:2]
                    hw = ret[2:]

                elif self.args.add_hw_state_flag:
                    bbx_xyxy, ret, crop_box, crop_box_v2 = decode_mask_to_box(
                        objects[j]['mask'], [self.bbox_size, self.bbox_size],
                        self.H, self.W)
                    ret_mean = torch.FloatTensor(
                        np.array([1 / 2., 1 / 2., 1 / 2., 1 / 2.]))
                    ret_mean = ret_mean.unsqueeze(1).unsqueeze(1)
                    ret_std = ret_mean
                    ret = normalize(ret, ret_mean, ret_std)
                    hw = ret[2:]

                elif self.args.add_xyhw_state_flag:
                    bbx_xyxy, ret, crop_box, crop_box_v2 = decode_mask_to_box(
                        objects[j]['mask'], [self.bbox_size, self.bbox_size],
                        self.H, self.W)
                    ret_mean = torch.FloatTensor(
                        np.array([1 / 2., 1 / 2., 1 / 2., 1 / 2.]))
                    ret_mean = ret_mean.unsqueeze(1).unsqueeze(1)
                    ret_std = ret_mean
                    ret = normalize(ret, ret_mean, ret_std)
                    pos = ret[:2]
                    hw = ret[2:]

                identifier = get_identifier(objects[j])

                if self.args.box_only_flag:
                    s = torch.cat([pos, hw], 0).unsqueeze(0), identifier
                elif self.args.add_hw_state_flag or self.args.add_xyhw_state_flag:
                    s = torch.cat([mask_crop, pos, img_crop, hw],
                                  0).unsqueeze(0), identifier
                elif self.args.rm_mask_state_flag:
                    s = torch.cat([mask_crop * 0, pos, img_crop],
                                  0).unsqueeze(0), identifier
                else:
                    s = torch.cat([mask_crop, pos, img_crop],
                                  0).unsqueeze(0), identifier
                object_inputs.append(s)

            objs.append(object_inputs)

        attr = torch.cat(attrs, 0).view(n_objects, self.args.attr_dim,
                                        self.bbox_size, self.bbox_size)

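        # gather each object's features across frames, matching objects by identifier,
        # and concatenate them along the channel dimension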
        feats = []
        for x in range(n_objects):
            feats.append(objs[0][x][0])

        for i in range(1, len(objs)):
            for x in range(n_objects):
                for y in range(n_objects):
                    id_x = objs[0][x][1]
                    id_y = objs[i][y][1]
                    if check_same_identifier(id_x, id_y):
                        feats[x] = torch.cat([feats[x], objs[i][y][0]], 1)

        try:
            feats = torch.cat(feats, 0)
        except RuntimeError:
            # concatenation fails when an object could not be matched in every frame;
            # report the offending sample instead of failing later with an obscure error
            print(idx_video, idx_frame)
            raise
        # print("feats shape", feats.size())

        ### prepare relation attributes
        n_relations = n_objects * n_objects
        Ra = torch.FloatTensor(
            np.ones((n_relations, self.args.relation_dim *
                     (self.args.n_his + 2), self.bbox_size, self.bbox_size)) *
            -0.5)

        # change to relative position
        relation_dim = self.args.relation_dim
        state_dim = self.args.state_dim
        if self.args.box_only_flag:
            for i in range(n_objects):
                for j in range(n_objects):
                    idx = i * n_objects + j
                    Ra[idx, 1::relation_dim] = feats[i, 0::state_dim] - feats[
                        j, 0::state_dim]  # x
                    Ra[idx, 2::relation_dim] = feats[i, 1::state_dim] - feats[
                        j, 1::state_dim]  # y
        else:
            for i in range(n_objects):
                for j in range(n_objects):
                    idx = i * n_objects + j
                    Ra[idx, 1::relation_dim] = feats[i, 1::state_dim] - feats[
                        j, 1::state_dim]  # x
                    Ra[idx, 2::relation_dim] = feats[i, 2::state_dim] - feats[
                        j, 2::state_dim]  # y

        # add collision attr
        gt = self.metadata[idx_video]['ground_truth']
        gt_ids = gt['objects']
        gt_collisions = gt['collisions']

        label_rel = torch.FloatTensor(
            np.ones((n_objects * n_objects, 1)) * -0.5)

        if self.args.edge_superv:
            for i in range(idx_frame - n_his * frame_offset,
                           idx_frame + frame_offset + 1, frame_offset):

                for j in range(len(gt_collisions)):
                    frame_id = gt_collisions[j]['frame']
                    if 0 <= frame_id - i < self.args.frame_offset:
                        id_0 = gt_collisions[j]['object'][0]
                        id_1 = gt_collisions[j]['object'][1]
                        for k in range(len(gt_ids)):
                            if id_0 == gt_ids[k]['id']:
                                id_x = get_identifier(gt_ids[k])
                            if id_1 == gt_ids[k]['id']:
                                id_y = get_identifier(gt_ids[k])

                        # id_0 = get_identifier(gt_ids[gt_collisions[j]['object'][0]])
                        # id_1 = get_identifier(gt_ids[gt_collisions[j]['object'][1]])

                        for k in range(n_objects):
                            if check_same_identifier(objs[0][k][1], id_x):
                                x = k
                            if check_same_identifier(objs[0][k][1], id_y):
                                y = k

                        idx_rel_xy = x * n_objects + y
                        idx_rel_yx = y * n_objects + x

                        # print(x, y, n_objects)

                        idx = i - (idx_frame - n_his * frame_offset)
                        idx /= frame_offset
                        Ra[idx_rel_xy, int(idx) * relation_dim] = 0.5
                        Ra[idx_rel_yx, int(idx) * relation_dim] = 0.5

                        if i == idx_frame + frame_offset:
                            label_rel[idx_rel_xy] = 1
                            label_rel[idx_rel_yx] = 1
        '''
        print(feats[0, -state_dim])
        print(feats[0, -state_dim+1])
        print(feats[0, -state_dim+2])
        print(feats[0, -state_dim+3])
        print(feats[0, -state_dim+4])
        '''
        '''
        ### change absolute pos to relative pos
        feats[:, state_dim+1::state_dim] = \
                feats[:, state_dim+1::state_dim] - feats[:, 1:-state_dim:state_dim]   # x
        feats[:, state_dim+2::state_dim] = \
                feats[:, state_dim+2::state_dim] - feats[:, 2:-state_dim:state_dim]   # y
        feats[:, 1] = 0
        feats[:, 2] = 0
        '''

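        # inputs are all frames except the last; the last frame is the prediction target,
        # with its position channels converted to offsets from the previous frame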
        x = feats[:, :-state_dim]
        label_obj = feats[:, -state_dim:]
        if self.args.box_only_flag:
            label_obj[:, 1] -= feats[:, -2 * state_dim + 1]
            label_obj[:, 2] -= feats[:, -2 * state_dim + 2]
            label_obj[:, 0] -= feats[:, -2 * state_dim + 0]
            label_obj[:, 3] -= feats[:, -2 * state_dim + 3]
        else:
            label_obj[:, 1] -= feats[:, -2 * state_dim + 1]
            label_obj[:, 2] -= feats[:, -2 * state_dim + 2]
        rel = prepare_relations(n_objects)
        rel.append(Ra[:, :-relation_dim])
        '''
        print(rel[-1][0, 0])
        print(rel[-1][0, 1])
        print(rel[-1][0, 2])
        print(rel[-1][2, 3])
        print(rel[-1][2, 4])
        print(rel[-1][2, 5])
        '''

        # print("attr shape", attr.size())
        # print("x shape", x.size())
        # print("label_obj shape", label_obj.size())
        # print("label_rel shape", label_rel.size())
        '''
        for i in range(n_objects):
            print(objs[0][i][1])
            print(label_obj[i, 1])

        time.sleep(10)
        '''

        return attr, x, rel, label_obj, label_rel
Example #2
    def __getitem__(self, idx):
        #pdb.set_trace()
        n_his = self.args.n_his
        frame_offset = self.args.frame_offset
        idx_video, idx_frame = self.valid_idx[idx][0], self.valid_idx[idx][1]

        objs = []
        attrs = []
        for i in range(idx_frame - n_his * frame_offset,
                       idx_frame + frame_offset + 1, frame_offset):

            frame = self.metadata[idx_video]['proposals']['frames'][i]
            #frame_filename = frame['frame_filename']
            frame_filename = os.path.join(
                'video_' + str(idx_video).zfill(5),
                str(frame['frame_index'] + 1) + '.png')

            objects = frame['objects']
            n_objects = len(objects)

            vid = idx_video // 1000
            ann_full_dir = os.path.join(
                self.data_dir, 'image_%02d000-%02d000' % (vid, vid + 1))
            img = self.loader(os.path.join(ann_full_dir, frame_filename))
            img = np.array(img)[:, :, ::-1].copy()
            # np.float was removed from NumPy; the builtin float keeps the same float64 dtype
            img = cv2.resize(img, (self.W, self.H),
                             interpolation=cv2.INTER_AREA).astype(float) / 255.

            ### prepare object inputs
            object_inputs = []
            for j in range(n_objects):
                material = objects[j]['material']
                shape = objects[j]['shape']

                if i == idx_frame - n_his * frame_offset:
                    attrs.append(
                        encode_attr(material, shape, self.bbox_size,
                                    self.args.attr_dim))

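                # decode the proposal mask into box coordinates, crop the image around it,
                # and look up the corresponding tube id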
                bbox_xyxy, xyhw_exp, crop_box, crop_box_v2 = decode_mask_to_box(
                    objects[j]['mask'], [self.bbox_size, self.bbox_size],
                    self.H, self.W)
                #img_crop = normalize(crop(img, crop_box, self.H, self.W), 0.5, 0.5).permute(2, 0, 1)
                img_crop = normalize(crop(img, crop_box_v2, self.H, self.W),
                                     0.5, 0.5).permute(2, 0, 1)
                tube_id = utilsTube.get_tube_id_from_bbox(
                    bbox_xyxy, frame['frame_index'],
                    self.metadata[idx_video]['tubes'])
                if tube_id == -1:
                    pdb.set_trace()
                if self.args.box_only_flag:
                    xyhw_norm = (xyhw_exp - 0.5) / 0.5
                    s = torch.cat([xyhw_norm], 0).unsqueeze(0), tube_id
                elif self.args.new_mode == 1:
                    xyhw_norm = (xyhw_exp - 0.5) / 0.5
                    s = torch.cat([xyhw_norm, img_crop],
                                  0).unsqueeze(0), tube_id
                else:
                    s = torch.cat([xyhw_exp, img_crop],
                                  0).unsqueeze(0), tube_id
                object_inputs.append(s)

            objs.append(object_inputs)

        attr = torch.cat(attrs, 0).view(n_objects, self.args.attr_dim,
                                        self.bbox_size, self.bbox_size)

        feats = []
        for x in range(n_objects):
            feats.append(objs[0][x][0])

        for i in range(1, len(objs)):
            for x in range(n_objects):
                for y in range(n_objects):
                    id_x = objs[0][x][1]
                    id_y = objs[i][y][1]
                    if id_x == id_y:
                        feats[x] = torch.cat([feats[x], objs[i][y][0]], 1)

        try:
            feats = torch.cat(feats, 0)
        except RuntimeError:
            # report the offending sample before re-raising
            print(idx_video, idx_frame)
            raise

        #pdb.set_trace()
        ### prepare relation attributes
        n_relations = n_objects * n_objects
        Ra = torch.FloatTensor(
            np.ones((n_relations, self.args.relation_dim *
                     (self.args.n_his + 2), self.bbox_size, self.bbox_size)) *
            -0.5)

        # change to relative position
        relation_dim = self.args.relation_dim
        state_dim = self.args.state_dim
        if self.args.box_only_flag or self.args.new_mode == 1:
            for i in range(n_objects):
                for j in range(n_objects):
                    idx = i * n_objects + j
                    Ra[idx, 1::relation_dim] = feats[i, 0::state_dim] - feats[
                        j, 0::state_dim]  # x
                    Ra[idx, 2::relation_dim] = feats[i, 1::state_dim] - feats[
                        j, 1::state_dim]  # y
        else:
            for i in range(n_objects):
                for j in range(n_objects):
                    idx = i * n_objects + j
                    Ra[idx, 1::relation_dim] = feats[i, 0::state_dim] - feats[
                        j, 0::state_dim]  # x
                    Ra[idx, 2::relation_dim] = feats[i, 1::state_dim] - feats[
                        j, 1::state_dim]  # y
                    Ra[idx, 3::relation_dim] = feats[i, 2::state_dim] - feats[
                        j, 2::state_dim]  # h
                    Ra[idx, 4::relation_dim] = feats[i, 3::state_dim] - feats[
                        j, 3::state_dim]  # w
        label_rel = torch.FloatTensor(
            np.ones((n_objects * n_objects, 1)) * -0.5)
        '''
        ### change absolute pos to relative pos
        feats[:, state_dim+1::state_dim] = \
                feats[:, state_dim+1::state_dim] - feats[:, 1:-state_dim:state_dim]   # x
        feats[:, state_dim+2::state_dim] = \
                feats[:, state_dim+2::state_dim] - feats[:, 2:-state_dim:state_dim]   # y
        feats[:, 1] = 0
        feats[:, 2] = 0
        '''
        #pdb.set_trace()
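        # inputs exclude the last frame; labels are next-frame box offsets (x, y, h, w)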
        x = feats[:, :-state_dim]
        label_obj = feats[:, -state_dim:]
        label_obj[:, 0] -= feats[:, -2 * state_dim + 0]
        label_obj[:, 1] -= feats[:, -2 * state_dim + 1]
        label_obj[:, 2] -= feats[:, -2 * state_dim + 2]
        label_obj[:, 3] -= feats[:, -2 * state_dim + 3]
        rel = prepare_relations(n_objects)
        rel.append(Ra[:, :-relation_dim])
        '''
        print(rel[-1][0, 0])
        print(rel[-1][0, 1])
        print(rel[-1][0, 2])
        print(rel[-1][2, 3])
        print(rel[-1][2, 4])
        print(rel[-1][2, 5])
        '''

        # print("attr shape", attr.size())
        # print("x shape", x.size())
        # print("label_obj shape", label_obj.size())
        # print("label_rel shape", label_rel.size())
        '''
        for i in range(n_objects):
            print(objs[0][i][1])
            print(label_obj[i, 1])

        time.sleep(10)
        '''

        return attr, x, rel, label_obj, label_rel
Example #3
def forward_step(frames, model, objs_gt=None):

    n_frames = len(frames)

    if n_frames < args.n_his + 1:
        return [], [], []

    ##### filter frames to predict
    # st_time = time.time()
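    # keep only objects whose identifier is present in every input frame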
    ids_predict = []
    objs_first = frames[0][0]
    for i in range(len(objs_first)):
        id = objs_first[i][2]

        id_to_predict = True
        for j in range(1, n_frames):
            objs = frames[j][0]

            contain_id = False
            for k in range(len(objs)):
                if check_same_identifier(objs[k][2], id):
                    contain_id = True
                    break
            if not contain_id:
                id_to_predict = False
                break

        if id_to_predict:
            ids_predict.append(id)

    n_objects = len(ids_predict)

    if n_objects == 0:
        return [], [], []

    # print("Time - filter frame", time.time() - st_time)

    ##### prepare inputs
    # st_time = time.time()
    feats_rec = []
    attrs = []
    for i in range(n_frames):
        objs = frames[i][0]

        feat_cur_frame = []
        for j in range(len(ids_predict)):
            for k in range(len(objs)):
                attr, feat, id = objs[k]
                if check_same_identifier(ids_predict[j], id):
                    feat_cur_frame.append(feat.clone())
                    if i == 0:
                        attrs.append(attr.unsqueeze(0).clone())
                    break

        feats_rec.append(torch.cat(feat_cur_frame, 0))

    attrs = torch.cat(attrs, 0)
    feats = torch.cat(feats_rec.copy(), 1)

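    # dense pairwise relation attributes: per frame, channel 0 marks supervised collision edges,
    # channels 1-2 hold relative x/y offsets between the two objects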
    n_relations = n_objects * n_objects
    Ra = torch.FloatTensor(
        np.ones((n_relations, args.relation_dim *
                 (args.n_his + 1), args.bbox_size, args.bbox_size)) * -0.5)

    relation_dim = args.relation_dim
    state_dim = args.state_dim
    for i in range(n_objects):
        for j in range(n_objects):
            idx = i * n_objects + j
            Ra[idx, 1::relation_dim] = feats[i, 1::state_dim] - feats[
                j, 1::state_dim]  # x
            Ra[idx, 2::relation_dim] = feats[i, 2::state_dim] - feats[
                j, 2::state_dim]  # y

    if args.edge_superv:
        for i in range(n_frames):
            rels = frames[i][1]
            for j in range(len(rels)):
                id_0, id_1 = rels[j][0], rels[j][1]
                x, y = -1, -1

                for k in range(n_objects):
                    if check_same_identifier(id_0, ids_predict[k]):
                        x = k
                    if check_same_identifier(id_1, ids_predict[k]):
                        y = k

                # if x == -1 or y == -1:
                # continue

                idx_rel_xy = x * n_objects + y
                idx_rel_yx = y * n_objects + x
                Ra[idx_rel_xy, i * relation_dim] = 0.5
                Ra[idx_rel_yx, i * relation_dim] = 0.5
    '''
    # change absolute pos to relative pos
    feats[:, state_dim+1::state_dim] = \
            feats[:, state_dim+1::state_dim] - feats[:, 1:-state_dim:state_dim]   # x
    feats[:, state_dim+2::state_dim] = \
            feats[:, state_dim+2::state_dim] - feats[:, 2:-state_dim:state_dim]   # y
    feats[:, 1] = 0
    feats[:, 2] = 0
    '''
    rel = prepare_relations(n_objects)
    rel.append(Ra)

    # print("Time - prepare inputs", time.time() - st_time)

    ##### predict
    # st_time = time.time()
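    # build sparse receiver/sender matrices mapping each relation to its endpoint objects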
    node_r_idx, node_s_idx, Ra = rel[3], rel[4], rel[5]
    Rr_idx, Rs_idx, value = rel[0], rel[1], rel[2]

    Rr = torch.sparse.FloatTensor(
        Rr_idx, value, torch.Size([node_r_idx.shape[0],
                                   value.size(0)]))
    Rs = torch.sparse.FloatTensor(
        Rs_idx, value, torch.Size([node_s_idx.shape[0],
                                   value.size(0)]))

    data = [attrs, feats, Rr, Rs, Ra]

    with torch.set_grad_enabled(False):
        for d in range(len(data)):
            if use_gpu:
                data[d] = Variable(data[d].cuda())
            else:
                data[d] = Variable(data[d])

        attr, feats, Rr, Rs, Ra = data
        # print('attr size', attr.size())
        # print('feats size', feats.size())
        # print('Rr size', Rr.size())
        # print('Rs size', Rs.size())
        # print('Ra size', Ra.size())

        # st_time = time.time()
        pred_obj, pred_rel, pred_feat = model(attr,
                                              feats,
                                              Rr,
                                              Rs,
                                              Ra,
                                              node_r_idx,
                                              node_s_idx,
                                              args.pstep,
                                              ret_feat=True)
        # print(time.time() - st_time)

    # print("Time - predict", time.time() - st_time)

    #### transform format
    # st_time = time.time()
    objs_pred = []
    rels_pred = []
    feats_pred = []

    pred_obj = pred_obj.data.cpu()
    pred_rel = pred_rel.data.cpu()
    pred_feat = pred_feat.data.cpu()

    assert pred_obj.shape[0] == pred_feat.shape[0]
    assert pred_feat.shape[1] == args.nf_effect
    '''
    mask = pred_obj[:, 0]
    position = pred_obj[:, 1:3]
    image = pred_obj[:, 3:]
    collision = pred_rel

    print('mask\n', mask)
    print('x\n', position[0])
    print('y\n', position[1])
    print('img\n', image[0])

    time.sleep(10)
    '''

    # print('pred_obj shape', pred_obj.shape)
    # print('pred_rel shape', pred_rel.shape)

    if objs_gt is not None:
        for i in range(n_objects):
            # cnt_id_in_gt = 0
            # id_gt = -1
            for j in range(len(objs_gt)):
                if check_same_identifier(ids_predict[i], objs_gt[j][2]):
                    objs_pred.append(objs_gt[j])
                    feats_pred.append(pred_feat[i])
                    break
                    # id_gt = j
                    # cnt_id_in_gt += 1
            '''
            if cnt_id_in_gt == 0 or cnt_id_in_gt > 1:
                feat = pred_obj[i:i+1]
                feat[0, 1] += feats_rec[-1][i, 1]   # x
                feat[0, 2] += feats_rec[-1][i, 2]   # y
                feat[0, 0, feat[0, 0] >= 0] = 0.5   # mask
                feat[0, 0, feat[0, 0] < 0] = -0.5

                obj = [attrs[i], feat, ids_predict[i]]
                objs_pred.append(obj)
            else:
                objs_pred.append(objs_gt[id_gt])
            '''

    else:
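        # no ground truth available: apply the predicted residuals to the last observed position
        # and binarize the mask channel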
        for i in range(n_objects):
            feat = pred_obj[i:i + 1]

            # print(ids_predict[i])
            # print(feat[0, 1])
            # print(feats_rec[-1][i, 1])

            feat[0, 1] += feats_rec[-1][i, 1]  # x

            # print(feat[0, 1])
            # time.sleep(1)
            feat[0, 2] += feats_rec[-1][i, 2]  # y
            feat[0, 0, feat[0, 0] >= 0] = 0.5  # mask
            feat[0, 0, feat[0, 0] < 0] = -0.5

            obj = [attrs[i], feat, ids_predict[i]]
            objs_pred.append(obj)
            feats_pred.append(pred_feat[i])
    '''
    print(objs_pred[0][1][0, 0])
    print(objs_pred[0][1][0, 1])
    print(objs_pred[0][1][0, 2])
    print(objs_pred[0][1][0, 3])
    print(objs_pred[0][1][0, 4])
    '''

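    # report a predicted collision for a pair when either directed relation score is positive
    # (only used when edge supervision is on)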
    for i in range(n_relations):
        x = i // n_objects
        y = i % n_objects
        if x >= y:
            continue

        idx_0, idx_1 = i, y * n_objects + x
        if pred_rel[idx_0] + pred_rel[idx_1] > 0 and args.edge_superv:
            rels_pred.append([ids_predict[x], ids_predict[y]])

    # print("Time - transform format", time.time() - st_time)

    return objs_pred, rels_pred, feats_pred
def forward_step(frames, model, objs_gt=None, args=None):

    n_frames = len(frames)

    if n_frames < args.n_his + 1:
        return [], [], []

    ##### filter frames to predict
    # st_time = time.time()
    ids_predict = []
    objs_first = frames[0][0]
    for i in range(len(objs_first)):
        id = objs_first[i][2]

        id_to_predict = True
        for j in range(1, n_frames):
            objs = frames[j][0]

            contain_id = False
            for k in range(len(objs)):
                if objs[k][2] == id:
                    contain_id = True
                    break
            if not contain_id:
                id_to_predict = False
                break

        if id_to_predict:
            ids_predict.append(id)

    n_objects = len(ids_predict)

    if n_objects == 0:
        return [], [], []

    # print("Time - filter frame", time.time() - st_time)

    ##### prepare inputs
    # st_time = time.time()
    feats_rec = []
    attrs = []
    for i in range(n_frames):
        objs = frames[i][0]

        feat_cur_frame = []
        for j in range(len(ids_predict)):
            for k in range(len(objs)):
                attr, feat, id = objs[k]
                if ids_predict[j] == id:
                    feat_cur_frame.append(feat.clone())
                    if i == 0:
                        attrs.append(attr.unsqueeze(0).clone())
                    break

        feats_rec.append(torch.cat(feat_cur_frame, 0))

    attrs = torch.cat(attrs, 0)
    feats = torch.cat(feats_rec.copy(), 1)
    n_relations = n_objects * n_objects

    #pdb.set_trace()

    if args.separate_mode == 1:
        # update spatial features
        relation_dim_spatial = args.relation_dim_spatial
        state_dim_spatial = args.state_dim_spatial
        x_step = args.n_his + 1
        feats_view = feats.view(n_objects, x_step, args.state_dim,
                                args.bbox_size, args.bbox_size)
        feats_spatial = feats_view[:, :, :state_dim_spatial].contiguous().view(
            n_objects, x_step * state_dim_spatial, args.bbox_size,
            args.bbox_size)
        feats_spatial = (feats_spatial - 0.5) / 0.5
        #if objs_gt is None:
        #    pdb.set_trace()
        Ra_spatial = torch.FloatTensor(
            np.ones((n_relations, args.relation_dim_spatial *
                     (args.n_his + 1), args.bbox_size, args.bbox_size)) * -0.5)
        for i in range(n_objects):
            for j in range(n_objects):
                idx = i * n_objects + j
                Ra_spatial[idx, 1::relation_dim_spatial] = feats_spatial[
                    i, 0::state_dim_spatial] - feats_spatial[
                        j, 0::state_dim_spatial]  # x
                Ra_spatial[idx, 2::relation_dim_spatial] = feats_spatial[
                    i, 1::state_dim_spatial] - feats_spatial[
                        j, 1::state_dim_spatial]  # y
        rel = prepare_relations(n_objects)
        rel.append(Ra_spatial)

        # st_time = time.time()
        node_r_idx, node_s_idx, Ra_spatial = rel[3], rel[4], rel[5]
        Rr_idx, Rs_idx, value = rel[0], rel[1], rel[2]

        Rr = torch.sparse.FloatTensor(
            Rr_idx, value, torch.Size([node_r_idx.shape[0],
                                       value.size(0)]))
        Rs = torch.sparse.FloatTensor(
            Rs_idx, value, torch.Size([node_s_idx.shape[0],
                                       value.size(0)]))

        data = [attrs, feats_spatial, Rr, Rs, Ra_spatial]

        with torch.set_grad_enabled(False):
            for d in range(len(data)):
                if use_gpu:
                    data[d] = Variable(data[d].cuda())
                else:
                    data[d] = Variable(data[d])
            attr, feats_spatial, Rr, Rs, Ra_spatial = data
            pred_obj_spa, pred_rel_spa, pred_feat_spa = model._spatial_model(
                attr,
                feats_spatial,
                Rr,
                Rs,
                Ra_spatial,
                node_r_idx,
                node_s_idx,
                args.pstep,
                ret_feat=True)
            feat_spa_list = []
            for i in range(n_objects):
                feat_spa = pred_obj_spa[i]
                feat_spa[0] += feats_spatial[i,
                                             state_dim_spatial * args.n_his +
                                             0]  # x
                feat_spa[1] += feats_spatial[i,
                                             state_dim_spatial * args.n_his +
                                             1]  # y
                feat_spa[2] += feats_spatial[i,
                                             state_dim_spatial * args.n_his +
                                             2]  # h
                feat_spa[3] += feats_spatial[i,
                                             state_dim_spatial * args.n_his +
                                             3]  # w
                feat_spa = 0.5 * feat_spa + 0.5
                feat_spa_list.append(feat_spa)

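    # pairwise relation attributes for the full model: channel 0 carries the collision flag under
    # edge supervision; channels 1-2 hold relative x/y (and 3-4 relative h/w when pixel features are used)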
    Ra = torch.FloatTensor(
        np.ones((n_relations, args.relation_dim *
                 (args.n_his + 1), args.bbox_size, args.bbox_size)) * -0.5)

    relation_dim = args.relation_dim
    state_dim = args.state_dim
    for i in range(n_objects):
        for j in range(n_objects):
            idx = i * n_objects + j
            if args.box_only_flag or args.new_mode == 1:
                Ra[idx, 1::relation_dim] = feats[i, 0::state_dim] - feats[
                    j, 0::state_dim]  # x
                Ra[idx, 2::relation_dim] = feats[i, 1::state_dim] - feats[
                    j, 1::state_dim]  # y
            else:
                Ra[idx, 1::relation_dim] = feats[i, 0::state_dim] - feats[
                    j, 0::state_dim]  # x
                Ra[idx, 2::relation_dim] = feats[i, 1::state_dim] - feats[
                    j, 1::state_dim]  # y
                Ra[idx, 3::relation_dim] = feats[i, 2::state_dim] - feats[
                    j, 2::state_dim]  # h
                Ra[idx, 4::relation_dim] = feats[i, 3::state_dim] - feats[
                    j, 3::state_dim]  # w

    if args.edge_superv:
        for i in range(n_frames):
            rels = frames[i][1]
            for j in range(len(rels)):
                id_0, id_1 = rels[j][0], rels[j][1]
                x, y = -1, -1

                for k in range(n_objects):
                    if check_same_identifier(id_0, ids_predict[k]):
                        x = k
                    if check_same_identifier(id_1, ids_predict[k]):
                        y = k

                idx_rel_xy = x * n_objects + y
                idx_rel_yx = y * n_objects + x
                Ra[idx_rel_xy, i * relation_dim] = 0.5
                Ra[idx_rel_yx, i * relation_dim] = 0.5

    rel = prepare_relations(n_objects)
    rel.append(Ra)

    # st_time = time.time()
    node_r_idx, node_s_idx, Ra = rel[3], rel[4], rel[5]
    Rr_idx, Rs_idx, value = rel[0], rel[1], rel[2]

    Rr = torch.sparse.FloatTensor(
        Rr_idx, value, torch.Size([node_r_idx.shape[0],
                                   value.size(0)]))
    Rs = torch.sparse.FloatTensor(
        Rs_idx, value, torch.Size([node_s_idx.shape[0],
                                   value.size(0)]))

    data = [attrs, feats, Rr, Rs, Ra]

    with torch.set_grad_enabled(False):
        for d in range(len(data)):
            if use_gpu:
                data[d] = Variable(data[d].cuda())
            else:
                data[d] = Variable(data[d])

        attr, feats, Rr, Rs, Ra = data

        # st_time = time.time()
        pred_obj, pred_rel, pred_feat = model(attr,
                                              feats,
                                              Rr,
                                              Rs,
                                              Ra,
                                              node_r_idx,
                                              node_s_idx,
                                              args.pstep,
                                              ret_feat=True)
        # print(time.time() - st_time)

    #### transform format
    # st_time = time.time()
    #pdb.set_trace()
    objs_pred = []
    rels_pred = []
    feats_pred = []

    pred_obj = pred_obj.data.cpu()
    pred_rel = pred_rel.data.cpu()
    pred_feat = pred_feat.data.cpu()

    assert pred_obj.shape[0] == pred_feat.shape[0]
    assert pred_feat.shape[1] == args.nf_effect
    '''
    mask = pred_obj[:, 0]
    position = pred_obj[:, 1:3]
    image = pred_obj[:, 3:]
    collision = pred_rel

    print('mask\n', mask)
    print('x\n', position[0])
    print('y\n', position[1])
    print('img\n', image[0])

    time.sleep(10)
    '''

    if objs_gt is not None:
        for i in range(n_objects):
            # cnt_id_in_gt = 0
            # id_gt = -1
            for j in range(len(objs_gt)):
                if ids_predict[i] == objs_gt[j][2]:
                    objs_pred.append(objs_gt[j])
                    feats_pred.append(pred_feat[i])
                    break
                    # id_gt = j
                    # cnt_id_in_gt += 1
            '''
            if cnt_id_in_gt == 0 or cnt_id_in_gt > 1:
                feat = pred_obj[i:i+1]
                feat[0, 1] += feats_rec[-1][i, 1]   # x
                feat[0, 2] += feats_rec[-1][i, 2]   # y
                feat[0, 0, feat[0, 0] >= 0] = 0.5   # mask
                feat[0, 0, feat[0, 0] < 0] = -0.5

                obj = [attrs[i], feat, ids_predict[i]]
                objs_pred.append(obj)
            else:
                objs_pred.append(objs_gt[id_gt])
            '''

    else:
        for i in range(n_objects):
            feat = pred_obj[i:i + 1]

            feat[0, 0] += feats_rec[-1][i, 0]  # x
            feat[0, 1] += feats_rec[-1][i, 1]  # y
            feat[0, 2] += feats_rec[-1][i, 2]  # h
            feat[0, 3] += feats_rec[-1][i, 3]  # w

            if args.separate_mode == 1:
                feat[:, :4] = feat_spa_list[i]
            # masking out object
            if not args.box_only_flag and args.maskout_pixel_inference_flag:
                feat = utilsTube.maskout_pixels_outside_box(
                    feat, args.H, args.W, args.bbox_size)

            obj = [attrs[i], feat, ids_predict[i]]
            objs_pred.append(obj)
            feats_pred.append(pred_feat[i])

    for i in range(n_relations):
        x = i // n_objects
        y = i % n_objects
        if x >= y:
            continue

        idx_0, idx_1 = i, y * n_objects + x
        if pred_rel[idx_0] + pred_rel[idx_1] > 0 and args.edge_superv:
            rels_pred.append([ids_predict[x], ids_predict[y]])

    # print("Time - transform format", time.time() - st_time)

    return objs_pred, rels_pred, feats_pred
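
A minimal rollout sketch (not part of the original code), showing how the second forward_step variant could be called autoregressively. The names n_rollout_steps and predicted are assumptions; frames is expected to hold (objects, relations) pairs in the same format the function consumes.

# hypothetical rollout loop; keeps a sliding window of the last args.n_his + 1 frames
window = frames[-(args.n_his + 1):]
predicted = []
for step in range(n_rollout_steps):
    objs_pred, rels_pred, feats_pred = forward_step(window, model, objs_gt=None, args=args)
    if len(objs_pred) == 0:
        break  # no object is visible in every frame of the window
    predicted.append((objs_pred, rels_pred))
    window = window[1:] + [(objs_pred, rels_pred)]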