Example #1
def parse_data(data_const, args):
    # focus only on HOI samples: drop actions that take no object
    action_class_num = len(vcoco_metadata.action_classes) - len(vcoco_metadata.action_no_obj)
    # no_action_index = vcoco_metadata.action_index['none']
    no_role_index = vcoco_metadata.role_index['none']
    # Load COCO annotations for V-COCO images
    coco = vu.load_coco()
    for subset in ["vcoco_train", "vcoco_test", "vcoco_val"]:
        # create file object to save the parsed data
        if not args.vis_result:
            print('{} data will be saved to {}/vcoco_data.hdf5'.format(subset.split("_")[1], subset))
            hdf5_file = os.path.join(data_const.proc_dir, subset, 'vcoco_data.hdf5')
            save_data = h5py.File(hdf5_file, 'w')
        # detection-evaluation counters: defined unconditionally because the
        # ground-truth/detection counting below runs in both modes
        eval_det_file = os.path.join(data_const.proc_dir, subset, 'eval_det_result.json')
        gt_record = {n: 0 for n in vcoco_metadata.action_class_with_object}
        det_record = gt_record.copy()

        # load selected data
        selected_det_data = h5py.File(os.path.join(data_const.proc_dir, subset, "selected_coco_cls_dets.hdf5"), 'r')

        # Load the V-COCO annotations for this subset
        vcoco_all = vu.load_vcoco(subset)
        for x in vcoco_all:
            x = vu.attach_gt_boxes(x, coco)
            # record groundtruths
            if x['action_name'] in vcoco_metadata.action_class_with_object:
                if len(x['role_name']) == 2:
                    gt_record[x['action_name']] = sum(x['label'][:,0])
                else:
                    for i in range(x['label'].shape[0]):
                        if x['label'][i,0] == 1:
                            role_bbox = x['role_bbox'][i, :] * 1.
                            role_bbox = role_bbox.reshape((-1, 4))
                            for i_role in range(1, len(x['role_name'])):
                                if x['role_name'][i_role]=='instr' and (not np.isnan(role_bbox[i_role, :][0])):
                                    gt_record[x['action_name']+'_with'] +=1
                                    continue
                                if x['role_name'][i_role]=='obj' and (not np.isnan(role_bbox[i_role, :][0])):
                                    gt_record[x['action_name']] +=1                               
        # print(gt_record)
        image_ids = vcoco_all[0]['image_id'][:,0].astype(int).tolist()
        # all_results = list()
        unique_image_ids = list()
        for i_image, image_id in enumerate(image_ids):
            img_name = coco.loadImgs(ids=image_id)[0]['coco_url'].split('.org')[1][1:]
            # get image size
            img_gt = Image.open(os.path.join(data_const.original_image_dir, img_name)).convert('RGB')
            img_size = img_gt.size
            # load corresponding selected data for image_id 
            det_boxes = selected_det_data[str(image_id)]['boxes_scores_rpn_ids'][:,:4]
            det_scores = selected_det_data[str(image_id)]['boxes_scores_rpn_ids'][:,4]
            det_classes = selected_det_data[str(image_id)]['boxes_scores_rpn_ids'][:,-1].astype(int)
            det_features = selected_det_data[str(image_id)]['features']
            # calculate the number of nodes
            human_num = len(np.where(det_classes==1)[0])
            node_num = len(det_classes)
            obj_num = node_num - human_num
            labeled_edge_num = human_num * (node_num-1) 
            # labeled_edge_num = human_num * obj_num      # test: just consider h-o
            if image_id not in unique_image_ids:
                unique_image_ids.append(image_id)
                # construct empty edge labels
                edge_labels = np.zeros((labeled_edge_num, action_class_num))
                edge_roles = np.zeros((labeled_edge_num, 3))
                # edge_labels[:, no_action_index]=1    
                edge_roles[:, no_role_index] = 1
            else:
                if not args.vis_result:
                    edge_labels = save_data[str(image_id)]['edge_labels']
                    edge_roles = save_data[str(image_id)]['edge_roles']
                else:
                    continue
            # import ipdb; ipdb.set_trace()
            # Ground truth labels
            for x in vcoco_all:
                if x['label'][i_image,0] == 1:
                    if x['action_name'] in vcoco_metadata.action_no_obj:
                        continue
                    # role_bbox contains the (agent, object/instr) boxes
                    # if i_image == 16:
                    #     import ipdb; ipdb.set_trace()
                    role_bbox = x['role_bbox'][i_image, :] * 1.
                    role_bbox = role_bbox.reshape((-1, 4))
                    # match human box
                    bbox = role_bbox[0, :]
                    human_index = get_node_index(bbox, det_boxes, range(human_num))
                    if human_index == -1:
                        warnings.warn('human detection missing')
                        # print(img_name)
                        continue
                    assert human_index < human_num
                    # match object box
                    for i_role in range(1, len(x['role_name'])):
                        action_name = x['action_name']
                        if x['role_name'][i_role] == 'instr' and x['action_name'] in ('cut', 'eat', 'hit'):
                            action_index = vcoco_metadata.action_with_obj_index[x['action_name']+'_with']
                            action_name +='_with'
                            # import ipdb; ipdb.set_trace()
                            # print('testing')
                        else:
                            action_index = vcoco_metadata.action_with_obj_index[x['action_name']]
                        bbox = role_bbox[i_role, :]
                        if np.isnan(bbox[0]):
                            continue
                        if args.vis_result:
                            img_gt = vis_img_vcoco(img_gt, [role_bbox[0,:], role_bbox[i_role,:]], 1, raw_action=action_index, data_gt=True)
                        obj_index = get_node_index(bbox, det_boxes, range(node_num))    # !Note: Take the human into account
                        # obj_index = get_node_index(bbox, det_boxes, range(human_num, node_num))  # test
                        if obj_index == -1:
                            warnings.warn('object detection missing')
                            # print(img_name)
                            continue
                        if obj_index == human_index:
                            warnings.warn('human detection is the same to object detection')
                            # print(img_name)
                            continue
                        # match labels
                        # if human_index == 0:
                        #     edge_index = obj_index - 1
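                        # flattened edge index: each human owns (node_num-1) consecutive
                        # slots, one per other node, skipping the human's own position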
                        if human_index > obj_index:
                            edge_index = human_index * (node_num-1) + obj_index
                        else:
                            edge_index = human_index * (node_num-1) + obj_index - 1
                            # edge_index = human_index * obj_num + obj_index - human_num  #test
                        det_record[action_name] +=1
                        edge_labels[edge_index, action_index] = 1
                        # edge_labels[edge_index, no_action_index] = 0
                        edge_roles[edge_index, vcoco_metadata.role_index[x['role_name'][i_role]]] = 1
                        edge_roles[edge_index, no_role_index] = 0
                        
            # visualizing result instead of saving result
            if args.vis_result:
                # ipdb.set_trace()
                image_res = Image.open(os.path.join(data_const.original_image_dir, img_name)).convert('RGB')
                result = vis_img_vcoco(image_res, det_boxes, det_classes, det_scores, edge_labels, score_thresh=0.4)
                plt.figure(figsize=(100,100))
                plt.suptitle(img_name)
                plt.subplot(1,2,1)
                plt.imshow(np.array(img_gt))
                plt.title('all_ground_truth'+str(i_image))
                plt.subplot(1,2,2)
                plt.imshow(np.array(result))
                plt.title('selected_ground_truth')
                # plt.axis('off')
                plt.ion()
                plt.pause(1)
                plt.close()
            # save processed data
            else:
                if str(image_id) not in save_data.keys():
                    # import ipdb; ipdb.set_trace()
                    save_data.create_group(str(image_id))
                    save_data[str(image_id)].create_dataset('img_name', data=np.frombuffer(img_name.encode('ascii'), dtype=np.uint8).astype('float64'))
                    save_data[str(image_id)].create_dataset('img_size', data=img_size)
                    save_data[str(image_id)].create_dataset('boxes', data=det_boxes)
                    save_data[str(image_id)].create_dataset('classes', data=det_classes)
                    save_data[str(image_id)].create_dataset('scores', data=det_scores)
                    save_data[str(image_id)].create_dataset('feature', data=det_features)
                    save_data[str(image_id)].create_dataset('node_num', data=node_num)
                    save_data[str(image_id)].create_dataset('edge_labels', data=edge_labels)
                    save_data[str(image_id)].create_dataset('edge_roles', data=edge_roles)
                else:
                    save_data[str(image_id)]['edge_labels'][:] = edge_labels
                    save_data[str(image_id)]['edge_roles'][:] = edge_roles  
        if not args.vis_result:
            save_data.close()
            print("Finished parsing data!")
        # eval object detection: recall of ground-truth pairs among selected detections
        eval_single = {n: det_record[n] / gt_record[n] if gt_record[n] else 0.0
                       for n in vcoco_metadata.action_class_with_object}
        eval_all = sum(det_record.values()) / sum(gt_record.values())
        eval_det_result = {
            'gt': gt_record,
            'det': det_record,
            'eval_single': eval_single,
            'eval_all': eval_all
        }
        io.dump_json_object(eval_det_result, eval_det_file)
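
For reference, here is a minimal standalone sketch of the human-object edge indexing used in parse_data (the helper name is illustrative, not from the source): each human node owns node_num-1 consecutive edge slots, one per other detection node, with the human's own slot skipped.

def edge_index_sketch(human_index, obj_index, node_num):
    # mirrors the edge_index computation in parse_data above
    assert obj_index != human_index
    if human_index > obj_index:
        return human_index * (node_num - 1) + obj_index
    return human_index * (node_num - 1) + obj_index - 1

# every (human, other-node) pair maps to a unique, contiguous slot
node_num, human_num = 5, 2
slots = [edge_index_sketch(h, o, node_num)
         for h in range(human_num) for o in range(node_num) if o != h]
assert sorted(slots) == list(range(human_num * (node_num - 1)))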
Example #2
def epoch_train(model, dataloader, dataset, criterion, optimizer, scheduler,
                device, data_const):
    print('epoch training...')

    # set visualization and create folder to save checkpoints
    writer = SummaryWriter(log_dir=args.log_dir + '/' + args.exp_ver + '/' +
                           'epoch_train')
    io.mkdir_if_not_exists(os.path.join(args.save_dir, args.exp_ver,
                                        'epoch_train'),
                           recursive=True)

    for epoch in range(args.start_epoch, args.epoch):
        # each epoch has a training and validation step
        epoch_loss = 0
        for phase in ['train', 'val']:
            start_time = time.time()
            running_loss = 0
            # all_edge = 0
            idx = 0

            # reset the dataset's class-level sample counter; it is reported via
            # VcocoDataset.displaycount() after the training phase
            VcocoDataset.data_sample_count = 0
            for data in tqdm(dataloader[phase]):
                train_data = data
                img_name = train_data['img_name']
                det_boxes = train_data['det_boxes']
                roi_labels = train_data['roi_labels']
                roi_scores = train_data['roi_scores']
                node_num = train_data['node_num']
                edge_labels = train_data['edge_labels']
                edge_num = train_data['edge_num']
                features = train_data['features']
                spatial_feat = train_data['spatial_feat']
                word2vec = train_data['word2vec']
                features, spatial_feat, word2vec, edge_labels = features.to(
                    device), spatial_feat.to(device), word2vec.to(
                        device), edge_labels.to(device)
                # if idx == 10: break    # debug leftover: truncates each phase to 10 batches
                if phase == 'train':
                    model.train()
                    model.zero_grad()
                    outputs = model(node_num, features, spatial_feat, word2vec,
                                    roi_labels)
                    # import ipdb; ipdb.set_trace()
                    loss = criterion(outputs, edge_labels.float())
                    loss.backward()
                    optimizer.step()

                else:
                    model.eval()
                    # turn off the gradients for validation, save memory and computations
                    with torch.no_grad():
                        outputs = model(node_num,
                                        features,
                                        spatial_feat,
                                        word2vec,
                                        roi_labels,
                                        validation=True)
                        loss = criterion(outputs, edge_labels.float())
                    # visualize one result roughly every 1000 samples during validation
                    if idx == 0 or idx % round(
                            1000 / args.batch_size) == round(
                                1000 / args.batch_size) - 1:
                        # ipdb.set_trace()
                        image = Image.open(
                            os.path.join(
                                data_const.original_image_dir,
                                img_name[0][:].astype(
                                    np.uint8).tobytes().decode(
                                        'ascii'))).convert('RGB')
                        image_temp = image.copy()
                        raw_outputs = nn.Sigmoid()(outputs[0:int(edge_num[0])])
                        raw_outputs = raw_outputs.cpu().detach().numpy()
                        # class_img = vis_img(image, det_boxes, roi_labels, roi_scores)
                        class_img = vis_img_vcoco(
                            image,
                            det_boxes[0],
                            roi_labels[0],
                            roi_scores[0],
                            edge_labels[0:int(edge_num[0])].cpu().numpy(),
                            score_thresh=0.7)
                        action_img = vis_img_vcoco(image_temp,
                                                   det_boxes[0],
                                                   roi_labels[0],
                                                   roi_scores[0],
                                                   raw_outputs,
                                                   score_thresh=0.5)
                        writer.add_image(
                            'gt_detection',
                            np.array(class_img).transpose(2, 0, 1))
                        writer.add_image(
                            'action_detection',
                            np.array(action_img).transpose(2, 0, 1))
                        writer.add_text(
                            'img_name', img_name[0][:].astype(
                                np.uint8).tobytes().decode('ascii'), epoch)

                idx += 1
                # accumulate loss of each batch
                running_loss += loss.item() * edge_labels.shape[0]
                # all_edge += edge_labels.shape[0]
            # calculate the average loss of each epoch
            epoch_loss = running_loss / len(dataset[phase])
            # epoch_loss = running_loss / all_edge
            # import ipdb; ipdb.set_trace()
            # log train/val losses and visualize them in the same graph
            if phase == 'train':
                train_loss = epoch_loss
                VcocoDataset.displaycount()
            else:
                writer.add_scalars('trainval_loss_epoch', {
                    'train': train_loss,
                    'val': epoch_loss
                }, epoch)
            # print data
            if (epoch % args.print_every) == 0:
                end_time = time.time()
                print("[{}] Epoch: {}/{} Loss: {} Execution time: {}".format(\
                        phase, epoch+1, args.epoch, epoch_loss, (end_time-start_time)))

        # scheduler.step()
        # save model
        if epoch_loss < 0.0405 or (epoch % args.save_every == (
                args.save_every - 1) and epoch >= (500 - 1)):
            checkpoint = {
                'lr': args.lr,
                'b_s': args.batch_size,
                'bias': args.bias,
                'bn': args.bn,
                'dropout': args.drop_prob,
                'layers': args.layers,
                'feat_type': args.feat_type,
                'multi_head': args.multi_attn,
                'diff_edge': args.diff_edge,
                'state_dict': model.state_dict()
            }
            save_name = "checkpoint_" + str(epoch + 1) + '_epoch.pth'
            torch.save(
                checkpoint,
                os.path.join(args.save_dir, args.exp_ver, 'epoch_train',
                             save_name))

    writer.close()
    print('Finished training!')
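
The image names logged above round-trip through HDF5 as numeric arrays: parse_data stores each name as a float64 array, and epoch_train decodes it back to a string before opening the image. A minimal standalone sketch of that round trip:

import numpy as np

img_name = 'val2014/COCO_val2014_000000150361.jpg'
# encode: ascii bytes -> uint8 -> float64, as stored in vcoco_data.hdf5
encoded = np.frombuffer(img_name.encode('ascii'), dtype=np.uint8).astype('float64')
# decode: float64 -> uint8 -> bytes -> str, as done before Image.open / add_text
decoded = encoded.astype(np.uint8).tobytes().decode('ascii')
assert decoded == img_name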
Example #3
def main(args):
    # Load checkpoint and set up model
    try:
        # use GPU if available else revert to CPU
        device = torch.device(
            'cuda:0' if torch.cuda.is_available() and args.gpu else 'cpu')
        print("Testing on", device)

        # set up the model and initialize it with the loaded checkpoint
        if args.dataset == 'hico':
            # load checkpoint
            checkpoint = torch.load(args.main_pretrained_hico,
                                    map_location=device)
            print('vsgats Checkpoint loaded!')
            pg_checkpoint = torch.load(args.pretrained_hico,
                                       map_location=device)
            data_const = HicoConstants(feat_type=checkpoint['feat_type'])
            vs_gats = vsgat_hico(feat_type=checkpoint['feat_type'],
                                 bias=checkpoint['bias'],
                                 bn=checkpoint['bn'],
                                 dropout=checkpoint['dropout'],
                                 multi_attn=checkpoint['multi_head'],
                                 layer=checkpoint['layers'],
                                 diff_edge=checkpoint['diff_edge'])
        elif args.dataset == 'vcoco':
            # load checkpoint
            checkpoint = torch.load(args.main_pretrained_vcoco,
                                    map_location=device)
            print('vsgats Checkpoint loaded!')
            pg_checkpoint = torch.load(args.pretrained_vcoco,
                                       map_location=device)
            data_const = VcocoConstants()
            vs_gats = vsgat_vcoco(feat_type=checkpoint['feat_type'],
                                  bias=checkpoint['bias'],
                                  bn=checkpoint['bn'],
                                  dropout=checkpoint['dropout'],
                                  multi_attn=checkpoint['multi_head'],
                                  layer=checkpoint['layers'],
                                  diff_edge=checkpoint['diff_edge'])
        vs_gats.load_state_dict(checkpoint['state_dict'])
        vs_gats.to(device)
        vs_gats.eval()

        print(pg_checkpoint['o_c_l'], pg_checkpoint['b_l'],
              pg_checkpoint['attn'], pg_checkpoint['lr'],
              pg_checkpoint['dropout'])
        # pgception = PGception(action_num=24, classifier_mod='cat', o_c_l=[64,64,128,128], last_h_c=256, bias=pg_checkpoint['bias'], drop=pg_checkpoint['dropout'], bn=pg_checkpoint['bn'])
        pgception = PGception(action_num=pg_checkpoint['a_n'],
                              layers=1,
                              classifier_mod=pg_checkpoint['classifier_mod'],
                              o_c_l=pg_checkpoint['o_c_l'],
                              last_h_c=pg_checkpoint['last_h_c'],
                              bias=pg_checkpoint['bias'],
                              drop=pg_checkpoint['dropout'],
                              bn=pg_checkpoint['bn'],
                              agg_first=pg_checkpoint['agg_first'],
                              attn=pg_checkpoint['attn'],
                              b_l=pg_checkpoint['b_l'])
        # pgception = PGception(action_num=pg_checkpoint['a_n'], drop=pg_checkpoint['dropout'])
        pgception.load_state_dict(pg_checkpoint['state_dict'])
        pgception.to(device)
        pgception.eval()
        print('Constructed model successfully!')
    except Exception as e:
        print('Failed to load checkpoint or construct model!', e)
        sys.exit(1)

    # prepare the data
    if args.dataset == 'hico':
        original_imgs_dir = os.path.join(data_const.infer_dir,
                                         'original_imgs/hico')
        # original_imgs_dir = './datasets/hico/images/test2015'
        save_path = os.path.join(data_const.infer_dir, 'processed_imgs/hico')
        test_dataset = HicoDataset(data_const=data_const, subset='test')
        os.makedirs(original_imgs_dir, exist_ok=True)    # make sure the input dir exists before listing it
        dataloader = sorted(os.listdir(original_imgs_dir))
        # dataloader = ['HICO_test2015_00000128.jpg']
    else:
        original_imgs_dir = os.path.join(data_const.infer_dir,
                                         'original_imgs/vcoco')
        # original_imgs_dir = './datasets/vcoco/coco/images/val2014'
        save_path = os.path.join(data_const.infer_dir, 'processed_imgs/vcoco')
        test_dataset = VcocoDataset(data_const=data_const,
                                    subset='vcoco_test',
                                    pg_only=False)
        # dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, collate_fn=vcoco_collate_fn)
        os.makedirs(original_imgs_dir, exist_ok=True)    # make sure the input dir exists before listing it
        dataloader = sorted(os.listdir(original_imgs_dir))
        # dataloader = ['COCO_val2014_000000150361.jpg']    # debug: run a single image

    if not os.path.exists(save_path):
        os.makedirs(save_path)
        print('result images will be saved to {}'.format(save_path))

    # ipdb.set_trace()
    for data in tqdm(dataloader):
        # load corresponding data
        # print("Testing on image named {}".format(img))
        if args.dataset == 'hico':
            img = data
            global_id = data.split('.')[0]
            test_data = test_dataset.sample_date(global_id)
            test_data = collate_fn([test_data])
            det_boxes = test_data['det_boxes'][0]
            roi_scores = test_data['roi_scores'][0]
            roi_labels = test_data['roi_labels'][0]
            keypoints = test_data['keypoints'][0]
            edge_labels = test_data['edge_labels']
            node_num = test_data['node_num']
            features = test_data['features']
            spatial_feat = test_data['spatial_feat']
            word2vec = test_data['word2vec']
            pose_normalized = test_data["pose_to_human"]
            pose_to_obj_offset = test_data["pose_to_obj_offset"]
        else:
            # global_id = data['global_id'][0]
            img = data
            global_id = str(int((data.split('.')[0].split('_')[-1])))
            test_data = test_dataset.sample_date(global_id)
            test_data = vcoco_collate_fn([test_data])
            # img = data['img_name'][0][:].astype(np.uint8).tostring().decode('ascii').split("/")[-1]
            # test_data = data
            det_boxes = test_data['det_boxes'][0]
            roi_scores = test_data['roi_scores'][0]
            roi_labels = test_data['roi_labels'][0]
            edge_labels = test_data['edge_labels']
            node_num = test_data['node_num']
            features = test_data['features']
            spatial_feat = test_data['spatial_feat']
            word2vec = test_data['word2vec']
            pose_normalized = test_data["pose_to_human"]
            pose_to_obj_offset = test_data["pose_to_obj_offset"]

        # inference
        pose_to_obj_offset, pose_normalized, features, spatial_feat, word2vec = pose_to_obj_offset.to(
            device), pose_normalized.to(device), features.to(
                device), spatial_feat.to(device), word2vec.to(device)
        outputs, attn, attn_lang = vs_gats(
            node_num, features, spatial_feat, word2vec,
            [roi_labels])  # !NOTE: it is important to set [roi_labels]
        pg_outputs = pgception(pose_normalized, pose_to_obj_offset)
        # action_score = nn.Sigmoid()(outputs+pg_outputs)
        # action_score = action_score.cpu().detach().numpy()
        det_outputs = nn.Sigmoid()(outputs + pg_outputs)
        det_outputs = det_outputs.cpu().detach().numpy()

        # show result
        # import ipdb; ipdb.set_trace()
        if args.dataset == 'hico':
            image = Image.open(
                os.path.join('datasets/hico/images/test2015',
                             img)).convert('RGB')
            image_temp = image.copy()
            gt_img = vis_img(image,
                             det_boxes,
                             roi_labels,
                             roi_scores,
                             edge_labels.cpu().numpy(),
                             score_thresh=0.5)
            det_img = vis_img(image_temp,
                              det_boxes,
                              roi_labels,
                              roi_scores,
                              det_outputs,
                              score_thresh=0.5)
        if args.dataset == 'vcoco':
            image = Image.open(
                os.path.join(data_const.original_image_dir, 'val2014',
                             img)).convert('RGB')
            image_temp = image.copy()
            gt_img = vis_img_vcoco(image,
                                   det_boxes,
                                   roi_labels,
                                   roi_scores,
                                   edge_labels.cpu().numpy(),
                                   score_thresh=0.1)
            det_img = vis_img_vcoco(image_temp,
                                    det_boxes,
                                    roi_labels,
                                    roi_scores,
                                    det_outputs,
                                    score_thresh=0.5)

        # det_img.save('/home/birl/ml_dl_projects/bigjun/hoi/VS_GATs/inference_imgs/original_imgs'+'/'+img)
        det_img.save(save_path + '/' + img.split("/")[-1])