Example #1
    def dump_result(self, path):
        # Only the main (rank-0) process writes the predictions, one per line.
        if dist.is_main_process():
            with open(path, "w") as f:
                for p in self.predictions:
                    f.write(p)
                    f.write("\n")
Example #2
    def summarize(self):
        if dist.is_main_process():
            imgid2pred: Dict[str, List] = defaultdict(list)
            for p in self.predictions:
                imgid2pred[p["original_id"]].append(p)

            for img_id, pred in imgid2pred.items():
                im_pred_dict = {p["task_id"]: p for p in pred}
                self.evaluator.eval_single_img(
                    img_id,
                    im_pred_dict,
                    pred_mask_tag="masks" if self.eval_mask else None,
                    pred_boxes_tag="boxes",
                    verbose=False,
                )

            mask_box = ["box"]
            if self.eval_mask:
                mask_box.append("mask")
            results = self.evaluator.analyze_stats(mask_box,
                                                   exp_name_in_summary=None,
                                                   save_result_to_path=None)

            results = results["all"]["pred_box_acc"]
            return {f"Precision@{k}": v for k, v in results.items()}
        return None
Example #3
    def summarize(self):
        if not dist.is_main_process():
            return

        if self.fixed_ap:
            self._summarize_fixed()
        else:
            self._summarize_standard()
Example #4
    def __init__(self, topk=10000, fixed_ap=True, out_path="lvis_eval"):
        self.results = []
        self.by_cat = {}
        self.topk = topk
        self.fixed_ap = fixed_ap
        self.out_path = out_path
        # Only the main process creates the output directory.
        if dist.is_main_process():
            if not os.path.exists(self.out_path):
                os.mkdir(self.out_path)
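A note on the directory setup in this example: os.mkdir raises if the directory already exists (for instance when resuming a run) or if a parent directory is missing. A more defensive variant of the same rank-0 pattern is sketched below; it is an illustrative rewrite, not this project's code.

    import os

    def ensure_output_dir(path, is_main_process=True):
        # Create the directory (and any missing parents) only on the main process;
        # exist_ok=True makes repeated runs idempotent instead of raising FileExistsError.
        if is_main_process:
            os.makedirs(path, exist_ok=True)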
Example #5
    def __init__(self,
                 split,
                 ann_folder,
                 output_dir="phrasecut_eval",
                 eval_mask=False):
        subset = PhraseCutSubsets(ann_folder)
        loader = RefVGLoader(ann_folder, subset, split=split)
        if dist.is_main_process():
            if not os.path.exists(output_dir):
                os.mkdir(output_dir)
        self.output_dir = output_dir
        self.evaluator = Evaluator(loader, summary_path=output_dir)
        self.eval_mask = eval_mask
        self.predictions = []
Example #6
    def summarize(self):
        if dist.is_main_process():
            dataset2score = {
                "refcoco": {k: 0.0 for k in self.k},
                "refcoco+": {k: 0.0 for k in self.k},
                "refcocog": {k: 0.0 for k in self.k},
            }
            dataset2count = {"refcoco": 0.0, "refcoco+": 0.0, "refcocog": 0.0}
            for image_id in self.img_ids:
                ann_ids = self.refexp_gt.getAnnIds(imgIds=image_id)
                assert len(ann_ids) == 1
                img_info = self.refexp_gt.loadImgs(image_id)[0]

                target = self.refexp_gt.loadAnns(ann_ids[0])
                prediction = self.predictions[image_id]
                assert prediction is not None
                sorted_scores_boxes = sorted(
                    zip(prediction["scores"].tolist(), prediction["boxes"].tolist()), reverse=True
                )
                sorted_scores, sorted_boxes = zip(*sorted_scores_boxes)
                sorted_boxes = torch.cat([torch.as_tensor(x).view(1, 4) for x in sorted_boxes])
                target_bbox = target[0]["bbox"]
                # Convert the COCO-style [x, y, w, h] ground truth to [x1, y1, x2, y2].
                converted_bbox = [
                    target_bbox[0],
                    target_bbox[1],
                    target_bbox[2] + target_bbox[0],
                    target_bbox[3] + target_bbox[1],
                ]
                giou = generalized_box_iou(sorted_boxes, torch.as_tensor(converted_bbox).view(-1, 4))
                for k in self.k:
                    if max(giou[:k]) >= self.thresh_iou:
                        dataset2score[img_info["dataset_name"]][k] += 1.0
                dataset2count[img_info["dataset_name"]] += 1.0

            for key, value in dataset2score.items():
                for k in self.k:
                    try:
                        value[k] /= dataset2count[key]
                    except ZeroDivisionError:
                        # No samples for this dataset split; leave the score at 0.
                        pass
            results = {}
            for key, value in dataset2score.items():
                results[key] = sorted([v for k, v in value.items()])
                print(f" Dataset: {key} - Precision @ 1, 5, 10: {results[key]} \n")

            return results
        return None
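The core of this example is converting the COCO-style [x, y, w, h] ground-truth box to [x1, y1, x2, y2] and checking whether any of the top-k highest-scoring predictions overlaps it. A self-contained sketch of that Precision@k check follows; it uses torchvision's box_iou in place of the generalized IoU helper used by the snippet above, and the function name and threshold are illustrative assumptions.

    import torch
    from torchvision.ops import box_iou

    def precision_at_k(pred_boxes, pred_scores, gt_xywh, ks=(1, 5, 10), thresh=0.5):
        # pred_boxes: (N, 4) in [x1, y1, x2, y2]; pred_scores: (N,); gt_xywh: COCO [x, y, w, h].
        order = pred_scores.argsort(descending=True)       # rank predictions by score
        sorted_boxes = pred_boxes[order]
        x, y, w, h = gt_xywh
        gt = torch.tensor([[x, y, x + w, y + h]], dtype=sorted_boxes.dtype)
        iou = box_iou(sorted_boxes, gt).squeeze(1)          # (N,) IoU with the single ground truth
        # A sample counts as correct at k if any of the first k predictions clears the threshold.
        return {k: bool((iou[:k] >= thresh).any()) for k in ks}

The sort-then-threshold step mirrors the zip/sort logic in the snippet above; accumulating these booleans over a dataset and dividing by the sample count gives the Precision@1/5/10 numbers it prints.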
Example #7
    def summarize(self):
        if dist.is_main_process():
            self.results = self.evaluator.evaluate(self.predictions)
            table = PrettyTable()
            all_cat = sorted(list(self.results.values())[0].keys())
            table.field_names = ["Recall@k"] + all_cat

            score = {}
            for k, v in self.results.items():
                cur_results = [v[cat] for cat in all_cat]
                header = "Upper_bound" if k == -1 else f"Recall@{k}"

                for cat in all_cat:
                    score[f"{header}_{cat}"] = v[cat]
                table.add_row([header] + cur_results)

            print(table)

            return score

        return None
Example #8
    def dump_result(self, path):

        if dist.is_main_process():
            with open(path, "w") as f:
                json.dump(self.predictions, f, indent=4, sort_keys=True)
Example #9
def main(args):
    # Init distributed mode
    dist.init_distributed_mode(args)

    # Update dataset specific configs
    if args.dataset_config is not None:
        # https://stackoverflow.com/a/16878364
        d = vars(args)
        with open(args.dataset_config, "r") as f:
            cfg = json.load(f)
        d.update(cfg)

    print("git:\n  {}\n".format(utils.get_sha()))

    # Segmentation related
    if args.mask_model != "none":
        args.masks = True
    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"

    print(args)

    device = torch.device(args.device)
    output_dir = Path(args.output_dir)

    # fix the seed for reproducibility
    seed = args.seed + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.set_deterministic(True)  # deprecated in newer PyTorch: torch.use_deterministic_algorithms(True)

    # Build the model
    model, criterion, contrastive_criterion, qa_criterion, weight_dict = build_model(
        args)
    model.to(device)

    assert (
        criterion is not None or qa_criterion is not None
    ), "Error: should train either detection or question answering (or both)"

    # Get a copy of the model for exponential moving averaged version of the model
    model_ema = deepcopy(model) if args.ema else None
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print("number of params:", n_parameters)

    # Set up optimizers
    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and "text_encoder" not in n
                and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            args.lr_backbone,
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if "text_encoder" in n and p.requires_grad
            ],
            "lr":
            args.text_encoder_lr,
        },
    ]
    if args.optimizer == "sgd":
        optimizer = torch.optim.SGD(param_dicts,
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.optimizer in ["adam", "adamw"]:
        optimizer = torch.optim.AdamW(param_dicts,
                                      lr=args.lr,
                                      weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Unsupported optimizer {args.optimizer}")

    # Train dataset
    if len(args.combine_datasets) == 0 and not args.eval:
        raise RuntimeError("Please provide at least one training dataset")

    dataset_train, sampler_train, data_loader_train = None, None, None
    if not args.eval:
        dataset_train = ConcatDataset([
            build_dataset(name, image_set="train", args=args)
            for name in args.combine_datasets
        ])

        # To handle very big datasets, we chunk the training set into smaller parts.
        if args.epoch_chunks > 0:
            print(
                f"Splitting the training set into {args.epoch_chunks} chunks of size approximately "
                f"{len(dataset_train) // args.epoch_chunks}")
            chunks = torch.chunk(torch.arange(len(dataset_train)),
                                 args.epoch_chunks)
            datasets = [
                torch.utils.data.Subset(dataset_train, chunk.tolist())
                for chunk in chunks
            ]
            if args.distributed:
                samplers_train = [DistributedSampler(ds) for ds in datasets]
            else:
                samplers_train = [
                    torch.utils.data.RandomSampler(ds) for ds in datasets
                ]

            batch_samplers_train = [
                torch.utils.data.BatchSampler(sampler_train,
                                              args.batch_size,
                                              drop_last=True)
                for sampler_train in samplers_train
            ]
            assert len(batch_samplers_train) == len(datasets)
            data_loaders_train = [
                DataLoader(
                    ds,
                    batch_sampler=batch_sampler_train,
                    collate_fn=partial(utils.collate_fn, False),
                    num_workers=args.num_workers,
                ) for ds, batch_sampler_train in zip(datasets,
                                                     batch_samplers_train)
            ]
        else:
            if args.distributed:
                sampler_train = DistributedSampler(dataset_train)
            else:
                sampler_train = torch.utils.data.RandomSampler(dataset_train)

            batch_sampler_train = torch.utils.data.BatchSampler(
                sampler_train, args.batch_size, drop_last=True)
            data_loader_train = DataLoader(
                dataset_train,
                batch_sampler=batch_sampler_train,
                collate_fn=partial(utils.collate_fn, False),
                num_workers=args.num_workers,
            )

    # Val dataset
    if len(args.combine_datasets_val) == 0:
        raise RuntimeError("Please provide at leas one validation dataset")

    Val_all = namedtuple(typename="val_data",
                         field_names=[
                             "dataset_name", "dataloader", "base_ds",
                             "evaluator_list"
                         ])

    val_tuples = []
    for dset_name in args.combine_datasets_val:
        dset = build_dataset(dset_name, image_set="val", args=args)
        sampler = (DistributedSampler(dset, shuffle=False) if args.distributed
                   else torch.utils.data.SequentialSampler(dset))
        dataloader = DataLoader(
            dset,
            args.batch_size,
            sampler=sampler,
            drop_last=False,
            collate_fn=partial(utils.collate_fn, False),
            num_workers=args.num_workers,
        )
        base_ds = get_coco_api_from_dataset(dset)
        val_tuples.append(
            Val_all(dataset_name=dset_name,
                    dataloader=dataloader,
                    base_ds=base_ds,
                    evaluator_list=None))

    if args.frozen_weights is not None:
        if args.resume.startswith("https"):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location="cpu",
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location="cpu")
        if "model_ema" in checkpoint and checkpoint["model_ema"] is not None:
            model_without_ddp.detr.load_state_dict(checkpoint["model_ema"],
                                                   strict=False)
        else:
            model_without_ddp.detr.load_state_dict(checkpoint["model"],
                                                   strict=False)

        if args.ema:
            model_ema = deepcopy(model_without_ddp)

    # Used to load weights from another model and start training from scratch. Especially useful
    # when loading into a model with different functionality.
    if args.load:
        print("loading from", args.load)
        checkpoint = torch.load(args.load, map_location="cpu")
        if "model_ema" in checkpoint:
            model_without_ddp.load_state_dict(checkpoint["model_ema"],
                                              strict=False)
        else:
            model_without_ddp.load_state_dict(checkpoint["model"],
                                              strict=False)

        if args.ema:
            model_ema = deepcopy(model_without_ddp)

    # Used for resuming training from the checkpoint of a model. Used when training times out or is pre-empted.
    if args.resume:
        if args.resume.startswith("https"):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume,
                                                            map_location="cpu",
                                                            check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.eval and "optimizer" in checkpoint and "epoch" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
            args.start_epoch = checkpoint["epoch"] + 1
        if args.ema:
            if "model_ema" not in checkpoint:
                print(
                    "WARNING: ema model not found in checkpoint, resetting to current model"
                )
                model_ema = deepcopy(model_without_ddp)
            else:
                model_ema.load_state_dict(checkpoint["model_ema"])

    def build_evaluator_list(base_ds, dataset_name):
        """Helper function to build the list of evaluators for a given dataset"""
        evaluator_list = []
        if args.no_detection:
            return evaluator_list
        iou_types = ["bbox"]
        if args.masks:
            iou_types.append("segm")

        evaluator_list.append(
            CocoEvaluator(base_ds, tuple(iou_types), useCats=False))
        if "refexp" in dataset_name:
            evaluator_list.append(RefExpEvaluator(base_ds, ("bbox")))
        if "clevrref" in dataset_name:
            evaluator_list.append(ClevrRefEvaluator(base_ds, ("bbox")))
        if "flickr" in dataset_name:
            evaluator_list.append(
                FlickrEvaluator(
                    args.flickr_dataset_path,
                    subset="test" if args.test else "val",
                    merge_boxes=args.GT_type == "merged",
                ))
        if "phrasecut" in dataset_name:
            evaluator_list.append(
                PhrasecutEvaluator(
                    "test" if args.test else "miniv",
                    ann_folder=args.phrasecut_orig_ann_path,
                    output_dir=os.path.join(output_dir, "phrasecut_eval"),
                    eval_mask=args.masks,
                ))
        return evaluator_list

    # Runs only evaluation, by default on the validation set unless --test is passed.
    if args.eval:
        test_stats = {}
        test_model = model_ema if model_ema is not None else model
        for i, item in enumerate(val_tuples):
            evaluator_list = build_evaluator_list(item.base_ds,
                                                  item.dataset_name)
            postprocessors = build_postprocessors(args, item.dataset_name)
            item = item._replace(evaluator_list=evaluator_list)
            print(f"Evaluating {item.dataset_name}")
            curr_test_stats = evaluate(
                model=test_model,
                criterion=criterion,
                contrastive_criterion=contrastive_criterion,
                qa_criterion=qa_criterion,
                postprocessors=postprocessors,
                weight_dict=weight_dict,
                data_loader=item.dataloader,
                evaluator_list=item.evaluator_list,
                device=device,
                args=args,
            )
            test_stats.update({
                item.dataset_name + "_" + k: v
                for k, v in curr_test_stats.items()
            })

        log_stats = {
            **{f"test_{k}": v
               for k, v in test_stats.items()},
            "n_parameters": n_parameters,
        }
        print(log_stats)
        return

    # Runs training and evaluates after every --eval_skip epochs
    print("Start training")
    start_time = time.time()
    best_metric = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.epoch_chunks > 0:
            sampler_train = samplers_train[epoch % len(samplers_train)]
            data_loader_train = data_loaders_train[epoch %
                                                   len(data_loaders_train)]
            print(
                f"Starting epoch {epoch // len(data_loaders_train)}, sub_epoch {epoch % len(data_loaders_train)}"
            )
        else:
            print(f"Starting epoch {epoch}")
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model=model,
            criterion=criterion,
            contrastive_criterion=contrastive_criterion,
            qa_criterion=qa_criterion,
            data_loader=data_loader_train,
            weight_dict=weight_dict,
            optimizer=optimizer,
            device=device,
            epoch=epoch,
            args=args,
            max_norm=args.clip_max_norm,
            model_ema=model_ema,
        )
        if args.output_dir:
            checkpoint_paths = [output_dir / "checkpoint.pth"]
            # extra checkpoint before LR drop and every 2 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 2 == 0:
                checkpoint_paths.append(output_dir /
                                        f"checkpoint{epoch:04}.pth")
            for checkpoint_path in checkpoint_paths:
                dist.save_on_master(
                    {
                        "model": model_without_ddp.state_dict(),
                        "model_ema":
                        model_ema.state_dict() if args.ema else None,
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch,
                        "args": args,
                    },
                    checkpoint_path,
                )

        if epoch % args.eval_skip == 0:
            test_stats = {}
            test_model = model_ema if model_ema is not None else model
            for i, item in enumerate(val_tuples):
                evaluator_list = build_evaluator_list(item.base_ds,
                                                      item.dataset_name)
                item = item._replace(evaluator_list=evaluator_list)
                postprocessors = build_postprocessors(args, item.dataset_name)
                print(f"Evaluating {item.dataset_name}")
                curr_test_stats = evaluate(
                    model=test_model,
                    criterion=criterion,
                    contrastive_criterion=contrastive_criterion,
                    qa_criterion=qa_criterion,
                    postprocessors=postprocessors,
                    weight_dict=weight_dict,
                    data_loader=item.dataloader,
                    evaluator_list=item.evaluator_list,
                    device=device,
                    args=args,
                )
                test_stats.update({
                    item.dataset_name + "_" + k: v
                    for k, v in curr_test_stats.items()
                })
        else:
            test_stats = {}

        log_stats = {
            **{f"train_{k}": v
               for k, v in train_stats.items()},
            **{f"test_{k}": v
               for k, v in test_stats.items()},
            "epoch": epoch,
            "n_parameters": n_parameters,
        }

        if args.output_dir and dist.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

        if epoch % args.eval_skip == 0:
            if args.do_qa:
                metric = test_stats["gqa_accuracy_answer_total_unscaled"]
            else:
                metric = np.mean([
                    v[1] for k, v in test_stats.items()
                    if "coco_eval_bbox" in k
                ])

            if args.output_dir and metric > best_metric:
                best_metric = metric
                checkpoint_paths = [output_dir / "BEST_checkpoint.pth"]
                # save a new best checkpoint whenever the validation metric improves
                for checkpoint_path in checkpoint_paths:
                    dist.save_on_master(
                        {
                            "model": model_without_ddp.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "epoch": epoch,
                            "args": args,
                        },
                        checkpoint_path,
                    )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))