Beispiel #1
0
def main():
    """Evaluate an existing detection pickle with ``VCOCOeval``.

    Command-line arguments:
        --dataset_name: key looked up in ``DatasetCatalog``
            (default ``vcoco_test``).
        --detection_file: path of the detection pkl handed to
            ``VCOCOeval._do_eval``.

    The catalog entry's ``args`` dict must provide ``vcoco_test_file``,
    ``ann_file`` and ``vcoco_test_ids_file``.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch DRG Detection Inference")
    parser.add_argument(
        "--dataset_name",
        default="vcoco_test",
        help="dataset name, default: vcoco_test",
    )
    parser.add_argument(
        "--detection_file",
        help="The path to the final detection pkl file for test",
        default="../output/VCOCO/detection_merged_human_object_app.pkl",
    )

    args = parser.parse_args()

    data_args = DatasetCatalog.get(args.dataset_name)["args"]

    # The action-index JSON used to be loaded here (through an unclosed file
    # handle) but nothing in this entry point consumed it, so the load is gone.
    vcocoeval = VCOCOeval(data_args['vcoco_test_file'], data_args['ann_file'],
                          data_args['vcoco_test_ids_file'])

    vcocoeval._do_eval(args.detection_file, ovr_thresh=0.5)
Beispiel #2
0
def main():
    """Run ``run_test`` on three detection pickles and evaluate the result.

    Command-line arguments:
        --dataset_name: key looked up in ``DatasetCatalog``
            (default ``vcoco_test``).
        --app_detection, --sp_human_detection, --sp_object_detection:
            paths of the three input detection pickles. All three are
            required; the parser exits with a usage error if any is missing.

    ``run_test`` is given
    ``<ROOT_DIR>/output/detection_vcoco_test_human_object_app.pkl`` as its
    output file, and that same path is then passed to
    ``VCOCOeval._do_eval``.
    """
    parser = argparse.ArgumentParser(description="PyTorch DRG Detection Inference")
    parser.add_argument(
        "--dataset_name",
        default="vcoco_test",
        help="dataset name, default: vcoco_test",
    )
    parser.add_argument(
        "--app_detection",
        help="The path to the app detection pkl for test",
        default=None,
    )
    parser.add_argument(
        "--sp_human_detection",
        help="The path to the sp human detection pkl  for test",
        default=None,
    )
    parser.add_argument(
        "--sp_object_detection",
        help="The path to the sp object detection pkl for test",
        default=None,
    )

    args = parser.parse_args()

    # Without this check a missing path reaches open(None) and fails with an
    # unhelpful TypeError; report it as a normal usage error instead.
    detection_args = (
        ("--app_detection", args.app_detection),
        ("--sp_human_detection", args.sp_human_detection),
        ("--sp_object_detection", args.sp_object_detection),
    )
    missing = [flag for flag, value in detection_args if value is None]
    if missing:
        parser.error("missing required argument(s): " + ", ".join(missing))

    def _load_pickle(path):
        """Unpickle ``path`` (latin1-decoded) and close the file afterwards."""
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(path, "rb") as handle:
            return pickle.load(handle, encoding='latin1')

    ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))

    detect_app_dict = _load_pickle(args.app_detection)
    detect_human_centric_dict = _load_pickle(args.sp_human_detection)
    detect_object_centric_dict = _load_pickle(args.sp_object_detection)

    output_folder = os.path.join(ROOT_DIR, 'output')
    mkdir(output_folder)
    output_file = os.path.join(output_folder, 'detection_vcoco_test_human_object_app.pkl')

    dataset_name = args.dataset_name
    data_args = DatasetCatalog.get(dataset_name)["args"]

    with open(data_args['action_index']) as action_index_file:
        action_dic = json.load(action_index_file)
    # Invert the loaded mapping (value -> key) for run_test.
    action_dic_inv = {y: x for x, y in action_dic.items()}

    # One integer image id per line.
    with open(data_args['vcoco_test_ids_file'], 'r') as vcoco_test_ids:
        test_image_id_list = [int(line.rstrip()) for line in vcoco_test_ids]

    vcocoeval = VCOCOeval(data_args['vcoco_test_file'], data_args['ann_file'], data_args['vcoco_test_ids_file'])

    run_test(
        detect_object_centric_dict,
        detect_human_centric_dict,
        detect_app_dict,
        test_image_id_list=test_image_id_list,
        dataset_name=dataset_name,
        action_dic_inv=action_dic_inv,
        output_file=output_file
    )

    vcocoeval._do_eval(output_file, ovr_thresh=0.5)
Beispiel #3
0
def main():
    """Build the detection model, run ``run_test`` per test dataset, compute HICO mAP.

    Steps, in order:
      1. Parse CLI options (config file, checkpoint selection, thresholds,
         trailing ``opts`` config overrides).
      2. Initialise ``torch.distributed`` when ``WORLD_SIZE`` > 1 and CUDA is
         available.
      3. Merge and freeze ``cfg``, build the model and load the checkpoint.
      4. For every dataset in ``cfg.DATASETS.TEST`` load the test detections,
         word embeddings and threshold dictionary, call ``run_test`` and then
         ``compute_hico_map`` on the produced ``detection_times.pkl``.
    """
    #     apply_prior   prior_mask
    # 0        -             -
    # 1        Y             -
    # 2        -             Y
    # 3        Y             Y
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Inference")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--ckpt",
        help=
        "The path to the checkpoint for test, default is the latest checkpoint.",
        default=None,
    )
    parser.add_argument('--num_iteration',
                        dest='num_iteration',
                        help='Specify which weight to load',
                        default=-1,
                        type=int)
    parser.add_argument('--object_thres',
                        dest='object_thres',
                        help='Object threshold',
                        default=0.4,
                        type=float)  # used to be 0.4 or 0.05
    parser.add_argument('--human_thres',
                        dest='human_thres',
                        help='Human threshold',
                        default=0.6,
                        type=float)
    parser.add_argument('--prior_flag',
                        dest='prior_flag',
                        help='whether use prior_flag',
                        default=1,
                        type=int)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    def _load_pickle(path):
        """Unpickle ``path`` (latin1-decoded) and close the file afterwards."""
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(path, "rb") as handle:
            return pickle.load(handle, encoding='latin1')

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1 and torch.cuda.is_available()

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    # prior_flag is only reported here; it is not forwarded to run_test below.
    print('prior flag: {}'.format(args.prior_flag))

    ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    args.config_file = os.path.join(ROOT_DIR, args.config_file)

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    save_dir = ""
    logger = setup_logger("DRG", save_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = build_detection_model(cfg)
    # The device is chosen from CUDA availability, not from cfg.MODEL.DEVICE.
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    # Initialize mixed-precision if necessary
    use_mixed_precision = cfg.DTYPE == 'float16'
    amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)

    # --num_iteration takes precedence over --ckpt when both are given.
    if args.num_iteration != -1:
        args.ckpt = os.path.join(cfg.OUTPUT_DIR,
                                 'model_%07d.pth' % args.num_iteration)
    ckpt = cfg.MODEL.WEIGHT if args.ckpt is None else args.ckpt
    logger.info("Testing checkpoint {}".format(ckpt))
    _ = checkpointer.load(ckpt, use_latest=args.ckpt is None)

    dataset_names = cfg.DATASETS.TEST
    output_folders = [None] * len(dataset_names)
    # NOTE(review): when cfg.OUTPUT_DIR is empty the entries stay None and the
    # os.path.join calls in the loop below raise TypeError.
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            if args.num_iteration != -1:
                output_folder = os.path.join(cfg.OUTPUT_DIR, "inference_ho",
                                             dataset_name,
                                             "model_%07d" % args.num_iteration)
            else:
                output_folder = os.path.join(cfg.OUTPUT_DIR, "inference_ho",
                                             dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder

    opt = {}
    opt['word_dim'] = 300
    opt['use_thres_dic'] = 1
    for output_folder, dataset_name in zip(output_folders, dataset_names):
        data_args = DatasetCatalog.get(dataset_name)["args"]
        test_detection = _load_pickle(data_args['test_detection_file'])
        word_embeddings = _load_pickle(data_args['word_embedding_file'])
        opt['thres_dic'] = _load_pickle(data_args['threshold_dic'])

        output_file = os.path.join(output_folder, 'detection_times.pkl')
        output_file_human = os.path.join(output_folder, 'detection_human.pkl')
        output_file_object = os.path.join(output_folder,
                                          'detection_object.pkl')
        output_map_folder = os.path.join(output_folder, 'map')

        logger.info("Output will be saved in {}".format(output_file))
        logger.info("Start evaluation on {} dataset.".format(dataset_name))

        run_test(model,
                 dataset_name=dataset_name,
                 test_detection=test_detection,
                 word_embeddings=word_embeddings,
                 output_file=output_file,
                 output_file_human=output_file_human,
                 output_file_object=output_file_object,
                 object_thres=args.object_thres,
                 human_thres=args.human_thres,
                 device=device,
                 cfg=cfg,
                 opt=opt)

        compute_hico_map(output_map_folder, output_file, 'test')
Beispiel #4
0
def main():
    """Build the detection model, run ``run_test`` per test dataset, evaluate with ``VCOCOeval``.

    Steps, in order:
      1. Parse CLI options (config file, checkpoint selection, thresholds,
         prior flag, trailing ``opts`` config overrides).
      2. Initialise ``torch.distributed`` when ``WORLD_SIZE`` > 1 and CUDA is
         available.
      3. Merge and freeze ``cfg``, build the model and load the checkpoint.
      4. For every dataset in ``cfg.DATASETS.TEST`` load the detections, prior
         mask, action index, test image ids and word embeddings, call
         ``run_test``, synchronize, then run ``VCOCOeval._do_eval`` on the
         produced ``detection.pkl``.
    """
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Inference")
    parser.add_argument(
        "--config-file",
        default="configs/e2e_faster_rcnn_R_50_FPN_1x.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--ckpt",
        help=
        "The path to the checkpoint for test, default is the latest checkpoint.",
        default=None,
    )
    parser.add_argument('--num_iteration',
                        dest='num_iteration',
                        help='Specify which weight to load',
                        default=-1,
                        type=int)
    parser.add_argument('--object_thres',
                        dest='object_thres',
                        help='Object threshold',
                        default=0.1,
                        type=float)  # used to be 0.4 or 0.05
    parser.add_argument('--human_thres',
                        dest='human_thres',
                        help='Human threshold',
                        default=0.8,
                        type=float)
    parser.add_argument('--prior_flag',
                        dest='prior_flag',
                        help='whether use prior_flag',
                        default=1,
                        type=int)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    def _load_pickle(path):
        """Unpickle ``path`` (latin1-decoded) and close the file afterwards."""
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(path, "rb") as handle:
            return pickle.load(handle, encoding='latin1')

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1 and torch.cuda.is_available()

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    args.config_file = os.path.join(ROOT_DIR, args.config_file)

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    save_dir = ""
    logger = setup_logger("DRG", save_dir, get_rank())
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = build_detection_model(cfg)
    # The device is chosen from CUDA availability, not from cfg.MODEL.DEVICE.
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    # Initialize mixed-precision if necessary
    use_mixed_precision = cfg.DTYPE == 'float16'
    amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)

    # --num_iteration takes precedence over --ckpt when both are given.
    if args.num_iteration != -1:
        args.ckpt = os.path.join(cfg.OUTPUT_DIR,
                                 'model_%07d.pth' % args.num_iteration)
    ckpt = cfg.MODEL.WEIGHT if args.ckpt is None else args.ckpt
    logger.info("Testing checkpoint {}".format(ckpt))
    _ = checkpointer.load(ckpt, use_latest=args.ckpt is None)

    dataset_names = cfg.DATASETS.TEST
    output_folders = [None] * len(dataset_names)
    # NOTE(review): when cfg.OUTPUT_DIR is empty the entries stay None and the
    # os.path.join calls in the loop below raise TypeError.
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            if args.num_iteration != -1:
                output_folder = os.path.join(cfg.OUTPUT_DIR, "inference_ho",
                                             dataset_name,
                                             "model_%07d" % args.num_iteration)
            else:
                output_folder = os.path.join(cfg.OUTPUT_DIR, "inference_ho",
                                             dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder

    opt = {}
    opt['word_dim'] = 300
    for output_folder, dataset_name in zip(output_folders, dataset_names):
        data_args = DatasetCatalog.get(dataset_name)["args"]
        im_dir = data_args['im_dir']
        test_detection = _load_pickle(data_args['test_detection_file'])
        prior_mask = _load_pickle(data_args['prior_mask'])

        with open(data_args['action_index']) as action_index_file:
            action_dic = json.load(action_index_file)
        # Invert the loaded mapping (value -> key) for run_test.
        action_dic_inv = {y: x for x, y in action_dic.items()}

        # One integer image id per line.
        with open(data_args['vcoco_test_ids_file'], 'r') as vcoco_test_ids:
            test_image_id_list = [
                int(line.rstrip()) for line in vcoco_test_ids
            ]

        vcocoeval = VCOCOeval(data_args['vcoco_test_file'],
                              data_args['ann_file'],
                              data_args['vcoco_test_ids_file'])
        word_embeddings = _load_pickle(data_args['word_embedding_file'])

        output_file = os.path.join(output_folder, 'detection.pkl')
        output_dict_file = os.path.join(
            output_folder, 'detection_app_{}_new.pkl'.format(dataset_name))

        logger.info("Output will be saved in {}".format(output_file))
        logger.info("Start evaluation on {} dataset({} images).".format(
            dataset_name, len(test_image_id_list)))

        run_test(model,
                 dataset_name=dataset_name,
                 im_dir=im_dir,
                 test_detection=test_detection,
                 word_embeddings=word_embeddings,
                 test_image_id_list=test_image_id_list,
                 prior_mask=prior_mask,
                 action_dic_inv=action_dic_inv,
                 output_file=output_file,
                 output_dict_file=output_dict_file,
                 object_thres=args.object_thres,
                 human_thres=args.human_thres,
                 prior_flag=args.prior_flag,
                 device=device,
                 cfg=cfg)

        synchronize()

        vcocoeval._do_eval(output_file, ovr_thresh=0.5)
    def inference(self,
                  colors_pred,
                  add_class_names=None,
                  save_path=None,
                  save_independently=None,
                  show_ground_truth=True):
        """
        Run prediction on every image of the first test dataset and plot it.

        For each image listed in the dataset's annotation file the ground-truth
        polygons (optional) and the predicted polygons are drawn over the
        image, the figure is optionally saved, and then shown.

        Args:
            colors_pred: forwarded unchanged to ``self.overlay_mask``.
            add_class_names: accepted for interface compatibility; unused here.
            save_path: string prefix to which each image's ``file_name`` is
                appended (plain concatenation, so include a trailing
                separator for a directory). ``None`` skips saving.
            save_independently: accepted for interface compatibility; unused
                here.
            show_ground_truth: draw the ground-truth polygons when truthy.

        Returns:
            dict mapping each image ``file_name`` to the raw result of
            ``self.compute_prediction`` for that image.
        """

        # load the config
        paths_catalog = import_file("maskrcnn_benchmark.config.paths_catalog",
                                    cfg.PATHS_CATALOG, True)
        DatasetCatalog = paths_catalog.DatasetCatalog
        test_datasets = DatasetCatalog.get(cfg.DATASETS.TEST[0])
        img_dir = test_datasets['args']['root']
        anno_file = test_datasets['args']['ann_file']
        with open(anno_file) as anno_handle:
            data = json.load(anno_handle)
        coco = COCO(anno_file)

        # Loop-invariant lookup tables, built once instead of once per image.
        json_category_id_to_contiguous_id = {
            cat_id: idx + 1
            for idx, cat_id in enumerate(coco.getCatIds())
        }
        color_rgb = [[255, 101, 80], [255, 55, 55], [255, 255, 61],
                     [255, 128, 0]]
        colors = {
            idx: [channel / 255 for channel in rgb]
            for idx, rgb in enumerate(color_rgb)
        }

        # Index annotations by image id once, so each image does not have to
        # rescan the whole annotation list. Order within an image is kept.
        annos_by_image = {}
        for obj in data['annotations']:
            annos_by_image.setdefault(obj['image_id'], []).append(obj)

        predis = []
        filenames = []

        # iterate through data
        for image in data['images']:

            with Image.open(img_dir + '/' + image['file_name']) as pil_img:
                # Keep only the first three channels.
                img = np.array(pil_img)[:, :, [0, 1, 2]]
                img_size = pil_img.size
            filenames.append(image['file_name'])

            # get ground truth boxes or masks
            anno = annos_by_image.get(image['id'], [])
            class_ids = [
                json_category_id_to_contiguous_id[obj['category_id']]
                for obj in anno
            ]
            classes = torch.tensor(class_ids)
            boxes = torch.as_tensor([obj['bbox']
                                     for obj in anno]).reshape(-1, 4)
            target = BoxList(boxes, img_size, mode='xywh').convert('xyxy')
            target.add_field('labels', classes)
            masks = [obj["segmentation"] for obj in anno]
            # Use the PIL image size here: the ndarray's ``.size`` is its
            # element count, not the size tuple that BoxList receives above.
            masks = SegmentationMask(masks, img_size)
            target.add_field("masks", masks)
            target = target.clip_to_image(remove_empty=True)

            # One edge colour per annotation, looked up by contiguous class
            # id. The table only has keys 0..3, so larger ids raise KeyError.
            color = [colors[class_id] for class_id in class_ids]

            # these are the ground truth polygons
            polygons = []
            polys = vars(target)['extra_fields']['masks']
            for polygon in polys:
                try:
                    tenso = vars(polygon)['polygons'][0]
                except KeyError:
                    continue

                # Flat coordinate vector -> (num_points, 2) array.
                poly1 = tenso.numpy()
                poly = poly1.reshape((int(len(poly1) / 2), 2))
                polygons.append(Polygon(poly))

            # compute predictions
            predictions = self.compute_prediction(img)
            predis.append(predictions)
            top_predictions = self.select_top_predictions(predictions)

            polygons_predicted, colors_prediction = self.overlay_mask(
                img, top_predictions, colors_pred, inference=True)

            fig = plt.figure()
            ax = fig.add_subplot(1, 1, 1)

            ax.imshow(Image.fromarray(img))
            ax.axis('off')

            # this is for ground truth
            if show_ground_truth:
                p = PatchCollection(polygons,
                                    facecolor='none',
                                    linewidths=0,
                                    alpha=0.4)
                ax.add_collection(p)
                p = PatchCollection(polygons,
                                    facecolor='none',
                                    edgecolors=color,
                                    linewidths=2)
                ax.add_collection(p)

            # this is for prediction
            ppd = PatchCollection(polygons_predicted,
                                  facecolor='none',
                                  linewidths=0,
                                  alpha=0.4)
            ax.add_collection(ppd)
            ppd = PatchCollection(polygons_predicted,
                                  facecolor='none',
                                  edgecolors=colors_prediction,
                                  linewidths=2)
            ax.add_collection(ppd)

            # The default save_path=None used to crash on ``None + str``.
            if save_path is not None:
                plt.savefig(save_path + image['file_name'],
                            dpi=200,
                            bbox_inches='tight',
                            pad_inches=0)

            plt.show()
            # Release the figure so they do not pile up across the dataset.
            plt.close(fig)

        return dict(zip(filenames, predis))
Beispiel #6
0
def train(cfg, local_rank, distributed, logger):
    if is_main_process():
        wandb.init(project='scene-graph',
                   entity='sgg-speaker-listener',
                   config=cfg.LISTENER)
    debug_print(logger, 'prepare training')

    model = build_detection_model(cfg)
    listener = build_listener(cfg)

    speaker_listener = SpeakerListener(model,
                                       listener,
                                       cfg,
                                       is_joint=cfg.LISTENER.JOINT)
    if is_main_process():
        wandb.watch(listener)

    debug_print(logger, 'end model construction')

    # modules that should be always set in eval mode
    # their eval() method should be called after model.train() is called
    eval_modules = (
        model.rpn,
        model.backbone,
        model.roi_heads.box,
    )

    fix_eval_modules(eval_modules)

    # NOTE, we slow down the LR of the layers start with the names in slow_heads
    if cfg.MODEL.ROI_RELATION_HEAD.PREDICTOR == "IMPPredictor":
        slow_heads = [
            "roi_heads.relation.box_feature_extractor",
            "roi_heads.relation.union_feature_extractor.feature_extractor",
        ]
    else:
        slow_heads = []

    # load pretrain layers to new layers
    load_mapping = {
        "roi_heads.relation.box_feature_extractor":
        "roi_heads.box.feature_extractor",
        "roi_heads.relation.union_feature_extractor.feature_extractor":
        "roi_heads.box.feature_extractor"
    }

    if cfg.MODEL.ATTRIBUTE_ON:
        load_mapping[
            "roi_heads.relation.att_feature_extractor"] = "roi_heads.attribute.feature_extractor"
        load_mapping[
            "roi_heads.relation.union_feature_extractor.att_feature_extractor"] = "roi_heads.attribute.feature_extractor"

    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    listener.to(device)

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    num_batch = cfg.SOLVER.IMS_PER_BATCH

    optimizer = make_optimizer(cfg,
                               model,
                               logger,
                               slow_heads=slow_heads,
                               slow_ratio=10.0,
                               rl_factor=float(num_batch))
    listener_optimizer = make_listener_optimizer(cfg, listener)
    scheduler = make_lr_scheduler(cfg, optimizer, logger)
    listener_scheduler = None
    debug_print(logger, 'end optimizer and schedule')

    if cfg.LISTENER.JOINT:
        speaker_listener_optimizer = make_speaker_listener_optimizer(
            cfg, speaker_listener.speaker, speaker_listener.listener)

    # Initialize mixed-precision training
    use_mixed_precision = cfg.DTYPE == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'

    if cfg.LISTENER.JOINT:
        speaker_listener, speaker_listener_optimizer = amp.initialize(
            speaker_listener, speaker_listener_optimizer, opt_level='O0')
    else:
        speaker_listener, listener_optimizer = amp.initialize(
            speaker_listener, listener_optimizer, opt_level='O0')

    #listener, listener_optimizer = amp.initialize(listener, listener_optimizer, opt_level='O0')
    #[model, listener], [optimizer, listener_optimizer] = amp.initialize([model, listener], [optimizer, listener_optimizer], opt_level='O1', loss_scale=1)
    #model = amp.initialize(model, opt_level='O1')

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

        listener = torch.nn.parallel.DistributedDataParallel(
            listener,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

    debug_print(logger, 'end distributed')
    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    listener_dir = cfg.LISTENER_DIR
    save_to_disk = get_rank() == 0

    speaker_checkpointer = DetectronCheckpointer(cfg,
                                                 model,
                                                 optimizer,
                                                 scheduler,
                                                 output_dir,
                                                 save_to_disk,
                                                 custom_scheduler=True)

    listener_checkpointer = Checkpointer(listener,
                                         optimizer=listener_optimizer,
                                         save_dir=listener_dir,
                                         save_to_disk=save_to_disk,
                                         custom_scheduler=False)

    speaker_listener.add_listener_checkpointer(listener_checkpointer)
    speaker_listener.add_speaker_checkpointer(speaker_checkpointer)

    speaker_listener.load_listener()
    speaker_listener.load_speaker(load_mapping=load_mapping)
    debug_print(logger, 'end load checkpointer')
    train_data_loader = make_data_loader(cfg,
                                         mode='train',
                                         is_distributed=distributed,
                                         start_iter=arguments["iteration"],
                                         ret_images=True)
    val_data_loaders = make_data_loader(cfg,
                                        mode='val',
                                        is_distributed=distributed,
                                        ret_images=True)

    debug_print(logger, 'end dataloader')
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if cfg.SOLVER.PRE_VAL:
        logger.info("Validate before training")
        #output =  run_val(cfg, model, listener, val_data_loaders, distributed, logger)
        #print('OUTPUT: ', output)
        #(sg_loss, img_loss, sg_acc, img_acc) = output

    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(train_data_loader)
    start_iter = arguments["iteration"]
    start_training_time = time.time()
    end = time.time()

    print_first_grad = True

    listener_loss_func = torch.nn.MarginRankingLoss(margin=1, reduction='none')
    mistake_saver = None
    if is_main_process():
        ds_catalog = DatasetCatalog()
        dict_file_path = os.path.join(
            ds_catalog.DATA_DIR,
            ds_catalog.DATASETS['VG_stanford_filtered_with_attribute']
            ['dict_file'])
        ind_to_classes, ind_to_predicates = load_vg_info(dict_file_path)
        ind_to_classes = {k: v for k, v in enumerate(ind_to_classes)}
        ind_to_predicates = {k: v for k, v in enumerate(ind_to_predicates)}
        print('ind to classes:', ind_to_classes, '/n ind to predicates:',
              ind_to_predicates)
        mistake_saver = MistakeSaver(
            '/Scene-Graph-Benchmark.pytorch/filenames_masked', ind_to_classes,
            ind_to_predicates)

    #is_printed = False
    while True:
        try:
            listener_iteration = 0
            for iteration, (images, targets,
                            image_ids) in enumerate(train_data_loader,
                                                    start_iter):

                if cfg.LISTENER.JOINT:
                    speaker_listener_optimizer.zero_grad()
                else:
                    listener_optimizer.zero_grad()

                #print(f'ITERATION NUMBER: {iteration}')
                if any(len(target) < 1 for target in targets):
                    logger.error(
                        f"Iteration={iteration + 1} || Image Ids used for training {_} || targets Length={[len(target) for target in targets]}"
                    )
                if len(images) <= 1:
                    continue

                data_time = time.time() - end
                iteration = iteration + 1
                listener_iteration += 1
                arguments["iteration"] = iteration
                model.train()
                fix_eval_modules(eval_modules)
                images_list = deepcopy(images)
                images_list = to_image_list(
                    images_list, cfg.DATALOADER.SIZE_DIVISIBILITY).to(device)

                for i in range(len(images)):
                    images[i] = images[i].unsqueeze(0)
                    images[i] = F.interpolate(images[i],
                                              size=(224, 224),
                                              mode='bilinear',
                                              align_corners=False)
                    images[i] = images[i].squeeze()

                images = torch.stack(images).to(device)
                #images.requires_grad_()

                targets = [target.to(device) for target in targets]

                speaker_loss_dict = {}
                if not cfg.LISTENER.JOINT:
                    score_matrix = speaker_listener(images_list, targets,
                                                    images)
                else:
                    score_matrix, _, speaker_loss_dict = speaker_listener(
                        images_list, targets, images)

                speaker_summed_losses = sum(
                    loss for loss in speaker_loss_dict.values())

                # reduce losses over all GPUs for logging purposes
                if not not cfg.LISTENER.JOINT:
                    speaker_loss_dict_reduced = reduce_loss_dict(
                        speaker_loss_dict)
                    speaker_losses_reduced = sum(
                        loss for loss in speaker_loss_dict_reduced.values())
                    speaker_losses_reduced /= num_gpus

                    if is_main_process():
                        wandb.log(
                            {"Train Speaker Loss": speaker_losses_reduced},
                            listener_iteration)

                listener_loss = 0
                gap_reward = 0
                avg_acc = 0
                num_correct = 0

                score_matrix = score_matrix.to(device)
                # fill loss matrix
                loss_matrix = torch.zeros((2, images.size(0), images.size(0)),
                                          device=device)
                # sg centered scores
                for true_index in range(loss_matrix.size(1)):
                    row_score = score_matrix[true_index]
                    (true_scores, predicted_scores,
                     binary) = format_scores(row_score, true_index, device)
                    loss_vec = listener_loss_func(true_scores,
                                                  predicted_scores, binary)
                    loss_matrix[0][true_index] = loss_vec
                # image centered scores
                transposted_score_matrix = score_matrix.t()
                for true_index in range(loss_matrix.size(1)):
                    row_score = transposted_score_matrix[true_index]
                    (true_scores, predicted_scores,
                     binary) = format_scores(row_score, true_index, device)
                    loss_vec = listener_loss_func(true_scores,
                                                  predicted_scores, binary)
                    loss_matrix[1][true_index] = loss_vec

                print('iteration:', listener_iteration)
                sg_acc = 0
                img_acc = 0
                # calculate accuracy
                for i in range(loss_matrix.size(1)):
                    temp_sg_acc = 0
                    temp_img_acc = 0
                    for j in range(loss_matrix.size(2)):
                        if loss_matrix[0][i][i] > loss_matrix[0][i][j]:
                            temp_sg_acc += 1
                        else:
                            if cfg.LISTENER.HTML:
                                if is_main_process(
                                ) and listener_iteration >= 600 and listener_iteration % 25 == 0 and i != j:
                                    detached_sg_i = (sgs[i][0].detach(),
                                                     sgs[i][1],
                                                     sgs[i][2].detach())
                                    detached_sg_j = (sgs[j][0].detach(),
                                                     sgs[j][1],
                                                     sgs[j][2].detach())
                                    mistake_saver.add_mistake(
                                        (image_ids[i], image_ids[j]),
                                        (detached_sg_i, detached_sg_j),
                                        listener_iteration, 'SG')
                        if loss_matrix[1][i][i] > loss_matrix[1][j][i]:
                            temp_img_acc += 1
                        else:
                            if cfg.LISTENER.HTML:
                                if is_main_process(
                                ) and listener_iteration >= 600 and listener_iteration % 25 == 0 and i != j:
                                    detached_sg_i = (sgs[i][0].detach(),
                                                     sgs[i][1],
                                                     sgs[i][2].detach())
                                    detached_sg_j = (sgs[j][0].detach(),
                                                     sgs[j][1],
                                                     sgs[j][2].detach())
                                    mistake_saver.add_mistake(
                                        (image_ids[i], image_ids[j]),
                                        (detached_sg_i, detached_sg_j),
                                        listener_iteration, 'IMG')

                    temp_sg_acc = temp_sg_acc * 100 / (loss_matrix.size(1) - 1)
                    temp_img_acc = temp_img_acc * 100 / (loss_matrix.size(1) -
                                                         1)
                    sg_acc += temp_sg_acc
                    img_acc += temp_img_acc
                if cfg.LISTENER.HTML:
                    if is_main_process(
                    ) and listener_iteration % 100 == 0 and listener_iteration >= 600:
                        mistake_saver.toHtml('/www')

                sg_acc /= loss_matrix.size(1)
                img_acc /= loss_matrix.size(1)

                avg_sg_acc = torch.tensor([sg_acc]).to(device)
                avg_img_acc = torch.tensor([img_acc]).to(device)
                # reduce acc over all gpus
                avg_acc = {'sg_acc': avg_sg_acc, 'img_acc': avg_img_acc}
                avg_acc_reduced = reduce_loss_dict(avg_acc)

                sg_acc = sum(acc for acc in avg_acc_reduced['sg_acc'])
                img_acc = sum(acc for acc in avg_acc_reduced['img_acc'])

                # log accuracy to wandb
                if is_main_process():
                    wandb.log({
                        "Train SG Accuracy": sg_acc.item(),
                        "Train IMG Accuracy": img_acc.item()
                    })

                sg_loss = 0
                img_loss = 0

                for i in range(loss_matrix.size(0)):
                    for j in range(loss_matrix.size(1)):
                        loss_matrix[i][j][j] = 0.

                for i in range(loss_matrix.size(1)):
                    sg_loss += torch.max(loss_matrix[0][i])
                    img_loss += torch.max(loss_matrix[1][:][i])

                sg_loss = sg_loss / loss_matrix.size(1)
                img_loss = img_loss / loss_matrix.size(1)
                sg_loss = sg_loss.to(device)
                img_loss = img_loss.to(device)

                loss_dict = {'sg_loss': sg_loss, 'img_loss': img_loss}

                losses = sum(loss for loss in loss_dict.values())

                # reduce losses over all GPUs for logging purposes
                loss_dict_reduced = reduce_loss_dict(loss_dict)
                sg_loss_reduced = loss_dict_reduced['sg_loss']
                img_loss_reduced = loss_dict_reduced['img_loss']
                if is_main_process():
                    wandb.log({"Train SG Loss": sg_loss_reduced})
                    wandb.log({"Train IMG Loss": img_loss_reduced})

                losses_reduced = sum(loss
                                     for loss in loss_dict_reduced.values())
                meters.update(loss=losses_reduced, **loss_dict_reduced)

                losses = losses + speaker_summed_losses * cfg.LISTENER.LOSS_COEF
                # Note: If mixed precision is not used, this ends up doing nothing
                # Otherwise apply loss scaling for mixed-precision recipe
                #losses.backward()
                if not cfg.LISTENER.JOINT:
                    with amp.scale_loss(losses,
                                        listener_optimizer) as scaled_losses:
                        scaled_losses.backward()
                else:
                    with amp.scale_loss(
                            losses,
                            speaker_listener_optimizer) as scaled_losses:
                        scaled_losses.backward()

                verbose = (iteration % cfg.SOLVER.PRINT_GRAD_FREQ
                           ) == 0 or print_first_grad  # print grad or not
                print_first_grad = False
                #clip_grad_value([(n, p) for n, p in listener.named_parameters() if p.requires_grad], cfg.LISTENER.CLIP_VALUE, logger=logger, verbose=True, clip=True)
                if not cfg.LISTENER.JOINT:
                    listener_optimizer.step()
                else:
                    speaker_listener_optimizer.step()

                batch_time = time.time() - end
                end = time.time()
                meters.update(time=batch_time, data=data_time)

                eta_seconds = meters.time.global_avg * (max_iter - iteration)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                if cfg.LISTENER.JOINT:
                    if iteration % 200 == 0 or iteration == max_iter:
                        logger.info(
                            meters.delimiter.join([
                                "eta: {eta}",
                                "iter: {iter}",
                                "{meters}",
                                "lr: {lr:.6f}",
                                "max mem: {memory:.0f}",
                            ]).format(
                                eta=eta_string,
                                iter=iteration,
                                meters=str(meters),
                                lr=speaker_listener_optimizer.param_groups[-1]
                                ["lr"],
                                memory=torch.cuda.max_memory_allocated() /
                                1024.0 / 1024.0,
                            ))
                else:
                    if iteration % 200 == 0 or iteration == max_iter:
                        logger.info(
                            meters.delimiter.join([
                                "eta: {eta}",
                                "iter: {iter}",
                                "{meters}",
                                "lr: {lr:.6f}",
                                "max mem: {memory:.0f}",
                            ]).format(
                                eta=eta_string,
                                iter=iteration,
                                meters=str(meters),
                                lr=listener_optimizer.param_groups[-1]["lr"],
                                memory=torch.cuda.max_memory_allocated() /
                                1024.0 / 1024.0,
                            ))

                if iteration % checkpoint_period == 0:
                    """
                    print('Model before save')
                    print('****************************')
                    print(listener.gnn.conv1.node_model.node_mlp_1[0].weight)
                    print('****************************')
                    """
                    if not cfg.LISTENER.JOINT:
                        listener_checkpointer.save(
                            "model_{:07d}".format(listener_iteration),
                            amp=amp.state_dict())
                    else:
                        speaker_checkpointer.save(
                            "model_speaker{:07d}".format(iteration))
                        listener_checkpointer.save(
                            "model_listenr{:07d}".format(listener_iteration),
                            amp=amp.state_dict())
                if iteration == max_iter:
                    if not cfg.LISTENER.JOINT:
                        listener_checkpointer.save(
                            "model_{:07d}".format(listener_iteration),
                            amp=amp.state_dict())
                    else:
                        speaker_checkpointer.save(
                            "model_{:07d}".format(iteration))
                        listener_checkpointer.save(
                            "model_{:07d}".format(listener_iteration),
                            amp=amp.state_dict())

                val_result = None  # used for scheduler updating
                if cfg.SOLVER.TO_VAL and iteration % cfg.SOLVER.VAL_PERIOD == 0:
                    logger.info("Start validating")
                    val_result = run_val(cfg, model, listener,
                                         val_data_loaders, distributed, logger)
                    (sg_loss, img_loss, sg_acc, img_acc,
                     speaker_val) = val_result

                    if is_main_process():
                        wandb.log({
                            "Validation SG Accuracy": sg_acc,
                            "Validation IMG Accuracy": img_acc,
                            "Validation SG Loss": sg_loss,
                            "Validation IMG Loss": img_loss,
                            "Validation Speaker": speaker_val,
                        })

                    #logger.info("Validation Result: %.4f" % val_result)
        except Exception as err:
            raise (err)
            print('Dataset finished, creating new')
            train_data_loader = make_data_loader(
                cfg,
                mode='train',
                is_distributed=distributed,
                start_iter=arguments["iteration"],
                ret_images=True)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
    return listener
def main():
    """Run a detection model on randomly sampled images and show them in Visdom.

    Command-line arguments come from ``parse_args()``; the fields read here
    are ``config_file``, ``weights_file``, ``dataset``, ``num_images`` and
    ``confidence_thresh``.  For each sampled image the detections whose score
    exceeds ``confidence_thresh`` are drawn (box plus ``"label: score"``
    caption) on a copy of the image, which is then sent to a Visdom server.

    Images that OpenCV cannot read are reported and skipped.  Label indices
    that fall outside the category list are rendered with a placeholder name
    so that boxes, scores and labels always stay aligned.
    """
    args = parse_args()

    vis = visdom.Visdom()

    # Load configuration file; masks are disabled and the test batch size is
    # forced to 1 because images are processed one at a time below.
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(['MODEL.MASK_ON', False])
    cfg.merge_from_list(['TEST.IMS_PER_BATCH', 1])

    # Build model.
    model = build_detection_model(cfg).cuda()
    model.eval()

    # Load weights from checkpoint.
    checkpointer = Checkpointer(model)
    checkpointer.load(args.weights_file)

    # Build pre-processing transform.
    transforms = T.Compose([
        T.ToPILImage(),
        T.Resize(cfg.INPUT.MIN_SIZE_TEST),
        T.ToTensor(),
        # ToTensor scales to [0, 1]; scale back to [0, 255] before normalizing.
        T.Lambda(lambda x: x * 255),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
    ])

    # Load dataset.
    attrs = DatasetCatalog.get(args.dataset)['args']
    img_dir = attrs['root']
    ann_file = attrs['ann_file']
    with open(ann_file, 'r') as f:
        dataset = json.load(f)

    images = dataset['images']
    # Index 0 is reserved for the background class, so real categories are
    # shifted by one.
    categories = [category['name'] for category in dataset['categories']]
    categories.insert(0, '__background')
    num_categories = len(categories)

    template = '{}: {:.2f}'

    # NOTE: np.random.choice samples with replacement by default, so the same
    # image may be visualized more than once.
    for img in np.random.choice(images, args.num_images):
        # Load image in OpenCV format.
        image_file = os.path.join(img_dir, img['file_name'])
        image = cv2.imread(image_file)
        print(image_file)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            print('Could not read image, skipping: {}'.format(image_file))
            continue

        # Apply pre-processing to image.
        image_tensor = transforms(image)
        image_list = to_image_list(image_tensor,
                                   cfg.DATALOADER.SIZE_DIVISIBILITY)
        image_list = image_list.to('cuda')

        # Compute predictions.
        with torch.no_grad():
            predictions = model(image_list)
        predictions = predictions[0]
        predictions = predictions.to('cpu')

        # Reshape prediction into the original image size.
        height, width = image.shape[:2]
        predictions = predictions.resize((width, height))

        # Select top predictions.
        keep = torch.nonzero(
            predictions.get_field("scores") > args.confidence_thresh).squeeze(
                1)
        predictions = predictions[keep]

        scores = predictions.get_field('scores').tolist()
        label_ids = predictions.get_field('labels').tolist()
        # Map every label id to a name.  Unknown ids get a placeholder rather
        # than being dropped: dropping them would shift the remaining labels
        # against their boxes and scores in the zip() below.
        labels = [
            categories[i] if 0 <= i < num_categories else
            'unknown({})'.format(i) for i in label_ids
        ]
        boxes = predictions.bbox

        # Compose result image.
        result = image.copy()
        for box, score, label in zip(boxes, scores, labels):
            box = box.to(torch.int64)
            top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
            cv2.rectangle(result, tuple(top_left), tuple(bottom_right),
                          (0, 255, 0), 1)
            s = template.format(label, score)
            cv2.putText(result, s, tuple(top_left), cv2.FONT_HERSHEY_SIMPLEX,
                        .5, (255, 255, 255), 1)

        # Visualize image in Visdom: reverse the channel order (OpenCV loads
        # BGR) and move channels first, as vis.image is given a CxHxW array.
        vis.image(result[:, :, ::-1].transpose((2, 0, 1)),
                  opts=dict(title=image_file))
Beispiel #8
0
def main():
    #     apply_prior   prior_mask
    # 0        -             -
    # 1        Y             -
    # 2        -             Y
    # 3        Y             Y
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Inference")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--ckpt",
        help=
        "The path to the checkpoint for test, default is the latest checkpoint.",
        default=None,
    )
    parser.add_argument(
        "--dataset_name",
        help="vcoco_test or vcoco_val_test",
        default=None,
    )
    parser.add_argument('--num_iteration',
                        dest='num_iteration',
                        help='Specify which weight to load',
                        default=-1,
                        type=int)
    parser.add_argument('--object_thres',
                        dest='object_thres',
                        help='Object threshold',
                        default=0.1,
                        type=float)  # used to be 0.4 or 0.05
    parser.add_argument('--human_thres',
                        dest='human_thres',
                        help='Human threshold',
                        default=0.8,
                        type=float)
    parser.add_argument('--prior_flag',
                        dest='prior_flag',
                        help='whether use prior_flag',
                        default=1,
                        type=int)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    # Manually hard-coded value (overrides whatever --config-file was given)
    args.config_file = "my_configs/VCOCO_app_only.yaml"

    ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    args.config_file = os.path.join(ROOT_DIR, args.config_file)

    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    model = build_detection_model(cfg)
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    # # Initialize mixed-precision if necessary
    use_mixed_precision = cfg.DTYPE == 'float16'
    amp_handle = amp.init(enabled=use_mixed_precision, verbose=cfg.AMP_VERBOSE)

    output_dir = cfg.OUTPUT_DIR
    checkpointer = DetectronCheckpointer(cfg, model, save_dir=output_dir)

    if args.num_iteration != -1:
        args.ckpt = os.path.join(cfg.OUTPUT_DIR,
                                 'model_%07d.pth' % args.num_iteration)
    ckpt = cfg.MODEL.WEIGHT if args.ckpt is None else args.ckpt
    ckpt = "DRG/" + ckpt
    _ = checkpointer.load(ckpt, use_latest=args.ckpt is None)

    # iou_types = ("bbox",)
    output_folders = [None] * len(cfg.DATASETS.TEST)
    dataset_names = cfg.DATASETS.TEST
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            if args.num_iteration != -1:
                output_folder = os.path.join(cfg.OUTPUT_DIR, "inference_ho",
                                             dataset_name,
                                             "model_%07d" % args.num_iteration)
            else:
                output_folder = os.path.join(cfg.OUTPUT_DIR, "inference_ho",
                                             dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder

    opt = {}
    # opt['word_dim'] = 300
    for output_folder, dataset_name in zip(output_folders, dataset_names):
        data = DatasetCatalog.get(dataset_name)
        data_args = data["args"]
        im_dir = data_args['im_dir']
        test_detection = pickle.load(open(data_args['test_detection_file'],
                                          "rb"),
                                     encoding='latin1')
        prior_mask = pickle.load(open(data_args['prior_mask'], "rb"),
                                 encoding='latin1')
        action_dic = json.load(open(data_args['action_index']))
        action_dic_inv = {y: x for x, y in action_dic.items()}
        vcoco_test_ids = open(data_args['vcoco_test_ids_file'], 'r')
        test_image_id_list = [int(line.rstrip()) for line in vcoco_test_ids]
        vcocoeval = VCOCOeval(data_args['vcoco_test_file'],
                              data_args['ann_file'],
                              data_args['vcoco_test_ids_file'])
        word_embeddings = pickle.load(open(data_args['word_embedding_file'],
                                           "rb"),
                                      encoding='latin1')
        output_file = os.path.join(output_folder, 'detection.pkl')
        output_dict_file = os.path.join(
            output_folder, 'detection_human_{}_new.pkl'.format(dataset_name))

        print(sorted(test_image_id_list))