def test(model, dataloader, **varargs):
    """Run the network in prediction-only mode and save every image's output.

    The model is put in eval mode and the batch sampler is set to epoch 0.
    For each image of each batch, the raw predictions are fused by
    ``varargs["make_panoptic"]``; the raw tuple, the fused result and an
    info dict are then passed to ``varargs["save_function"]``.

    Args:
        model: Network called as
            ``model(img=..., do_loss=False, do_prediction=True)``. It must
            return three values; the second is a dict holding the
            ``"sem_pred"``, ``"bbx_pred"``, ``"cls_pred"``, ``"obj_pred"``
            and ``"msk_pred"`` per-image sequences.
        dataloader: Iterable of batches providing ``"img"``, ``"size"``,
            ``"rel_path"`` and ``"abs_path"``; must support ``len()`` when
            logging is enabled.
        **varargs: Must contain ``"make_panoptic"``, ``"num_stuff"``,
            ``"save_function"``, ``"device"`` and ``"summary"``;
            ``"log_interval"`` is read only when ``"summary"`` is not None.
    """
    model.eval()
    dataloader.batch_sampler.set_epoch(0)

    data_time_meter = AverageMeter(())
    batch_time_meter = AverageMeter(())

    make_panoptic = varargs["make_panoptic"]
    num_stuff = varargs["num_stuff"]
    save_function = varargs["save_function"]

    # Prediction entries consumed per image, in the order expected by
    # make_panoptic and by the saved raw tuple
    pred_keys = ("sem_pred", "bbx_pred", "cls_pred", "obj_pred", "msk_pred")

    load_start = time.time()
    for step, batch in enumerate(dataloader, start=1):
        with torch.no_grad():
            # Move the images to the target device
            img = batch["img"].cuda(device=varargs["device"],
                                    non_blocking=True)
            data_time_meter.update(torch.tensor(time.time() - load_start))

            # Forward pass, timed separately from data loading
            forward_start = time.time()
            _, pred, _ = model(img=img, do_loss=False, do_prediction=True)
            batch_time_meter.update(
                torch.tensor(time.time() - forward_start))

            # zip yields one (sem, bbx, cls, obj, msk) tuple per image
            for idx, raw_pred in enumerate(
                    zip(*[pred[key] for key in pred_keys])):
                img_info = {
                    "batch_size": batch["img"][idx].shape[-2:],
                    "original_size": batch["size"][idx],
                    "rel_path": batch["rel_path"][idx],
                    "abs_path": batch["abs_path"][idx]
                }

                # Fuse the raw outputs, then persist both representations
                panoptic_pred = make_panoptic(*raw_pred, num_stuff)
                save_function(raw_pred, panoptic_pred, img_info)

            # Periodic progress report
            should_log = (varargs["summary"] is not None
                          and step % varargs["log_interval"] == 0)
            if should_log:
                timings = OrderedDict()
                timings["data_time"] = data_time_meter
                timings["batch_time"] = batch_time_meter
                logging.iteration(None, "val", 0, 1, 1, step,
                                  len(dataloader), timings)

            load_start = time.time()
def validate(model, dataloader, loss_weights, **varargs):
    """Compute the validation loss and COCO detection / mask AP.

    Args:
        model: Network called as
            ``model(**batch, do_loss=True, do_prediction=True)``; returns a
            dict of losses and a dict of predictions.
        dataloader: Iterable of batches providing ``"idx"``, ``"img"``,
            ``"size"`` and every key in ``NETWORK_INPUTS``; its dataset
            must expose ``num_stuff``.
        loss_weights: Weights paired (via ``zip``) with the reduced loss
            values to form the total loss.
        **varargs: Must contain ``"epoch"``, ``"device"``, ``"summary"``,
            ``"coco_gt"`` and ``"log_dir"``; ``"log_interval"``,
            ``"global_step"`` and ``"num_epochs"`` are read when
            ``"summary"`` is not None.

    Returns:
        The second value returned by ``coco_ap.summarize_mp`` (mask mAP).
    """
    model.eval()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])

    num_stuff = dataloader.dataset.num_stuff

    loss_meter = AverageMeter(())
    data_time_meter = AverageMeter(())
    batch_time_meter = AverageMeter(())

    # Per-prediction records and the ids of the images they belong to
    coco_struct, img_list = [], []

    load_start = time.time()
    for step, batch in enumerate(dataloader, start=1):
        with torch.no_grad():
            # Keep the metadata before `batch` is replaced by its
            # device-side network inputs
            image_ids = batch["idx"]
            padded_sizes = [image.shape[-2:] for image in batch["img"]]
            original_sizes = batch["size"]

            batch = {
                name: batch[name].cuda(device=varargs["device"],
                                       non_blocking=True)
                for name in NETWORK_INPUTS
            }
            data_time_meter.update(torch.tensor(time.time() - load_start))

            forward_start = time.time()

            # Forward pass, then reduce the losses across workers
            losses, pred = model(**batch, do_loss=True, do_prediction=True)
            losses = OrderedDict(
                (name, value.mean()) for name, value in losses.items())
            losses = all_reduce_losses(losses)
            total_loss = sum(
                weight * value
                for weight, value in zip(loss_weights, losses.values()))

            loss_meter.update(total_loss.cpu())
            batch_time_meter.update(
                torch.tensor(time.time() - forward_start))

            del total_loss, losses

            # Collect COCO AP records; images without detections are skipped
            for idx, (bbx_pred, cls_pred, obj_pred, msk_pred) in enumerate(
                    zip(pred["bbx_pred"], pred["cls_pred"],
                        pred["obj_pred"], pred["msk_pred"])):
                if bbx_pred is not None:
                    coco_struct.extend(
                        coco_ap.process_prediction(
                            bbx_pred, cls_pred + num_stuff, obj_pred,
                            msk_pred, padded_sizes[idx], image_ids[idx],
                            original_sizes[idx]))
                    img_list.append(image_ids[idx])

            del pred, batch

            # Periodic progress report
            should_log = (varargs["summary"] is not None
                          and step % varargs["log_interval"] == 0)
            if should_log:
                progress = OrderedDict()
                progress["loss"] = loss_meter
                progress["data_time"] = data_time_meter
                progress["batch_time"] = batch_time_meter
                logging.iteration(None, "val", varargs["global_step"],
                                  varargs["epoch"] + 1,
                                  varargs["num_epochs"], step,
                                  len(dataloader), progress)

            load_start = time.time()

    # Turn the accumulated records into detection and mask AP
    det_map, msk_map = coco_ap.summarize_mp(coco_struct, varargs["coco_gt"],
                                            img_list, varargs["log_dir"],
                                            True)

    log_info("Validation done")
    if varargs["summary"] is not None:
        results = OrderedDict()
        results["loss"] = loss_meter.mean.item()
        results["det_map"] = det_map
        results["msk_map"] = msk_map
        results["data_time"] = data_time_meter.mean.item()
        results["batch_time"] = batch_time_meter.mean.item()
        logging.iteration(varargs["summary"], "val", varargs["global_step"],
                          varargs["epoch"] + 1, varargs["num_epochs"],
                          len(dataloader), len(dataloader), results)

    return msk_map
# Example #3
# 0
def test(model, dataloader, **varargs):
    """Run prediction-only inference and save each image's output.

    Best-effort variant: a batch that raises is reported on stdout and
    skipped after a short pause, so a single failing batch does not abort
    the whole run. The CUDA caching allocator is flushed only at points
    where references have actually been dropped.

    Args:
        model: Network called as
            ``model(img=..., do_loss=False, do_prediction=True)``. It must
            return three values; the second is a dict holding the
            ``"sem_pred"``, ``"bbx_pred"``, ``"cls_pred"``, ``"obj_pred"``
            and ``"msk_pred"`` per-image sequences.
        dataloader: Iterable of batches providing ``"idx"``, ``"img"``,
            ``"size"``, ``"rel_path"`` and ``"abs_path"``.
        **varargs: Must contain ``"make_panoptic"``, ``"num_stuff"``,
            ``"save_function"``, ``"device"`` and ``"summary"``;
            ``"log_interval"`` is read only when ``"summary"`` is not None.
    """
    model.eval()
    torch.cuda.empty_cache()
    dataloader.batch_sampler.set_epoch(0)

    data_time_meter = AverageMeter(())
    batch_time_meter = AverageMeter(())

    make_panoptic = varargs["make_panoptic"]
    num_stuff = varargs["num_stuff"]
    save_function = varargs["save_function"]

    data_time = time.time()

    # tqdm wraps the dataloader itself rather than the enumerate generator,
    # so it can read the length and display a total / ETA.
    for it, batch in enumerate(tqdm(dataloader)):
        # Return cached blocks left over from the previous batch (including
        # a batch that failed half-way) before allocating again.
        torch.cuda.empty_cache()
        try:
            with torch.no_grad():
                print("Processing: ", batch['idx'])

                # Extract data
                img = batch["img"].cuda(device=varargs["device"],
                                        non_blocking=True)

                data_time_meter.update(torch.tensor(time.time() - data_time))

                batch_time = time.time()

                # Run network; only the prediction dict is used afterwards
                unused_first, pred, unused_last = model(img=img,
                                                        do_loss=False,
                                                        do_prediction=True)
                del img, unused_first, unused_last
                torch.cuda.empty_cache()

                # Update meters
                batch_time_meter.update(torch.tensor(time.time() - batch_time))

                for i, (sem_pred, bbx_pred, cls_pred, obj_pred,
                        msk_pred) in enumerate(
                            zip(pred["sem_pred"], pred["bbx_pred"],
                                pred["cls_pred"], pred["obj_pred"],
                                pred["msk_pred"])):
                    img_info = {
                        "batch_size": batch["img"][i].shape[-2:],
                        "original_size": batch["size"][i],
                        "rel_path": batch["rel_path"][i],
                        "abs_path": batch["abs_path"][i]
                    }

                    # Compute panoptic output
                    panoptic_pred = make_panoptic(sem_pred, bbx_pred, cls_pred,
                                                  obj_pred, msk_pred,
                                                  num_stuff)

                    # Save prediction
                    raw_pred = (sem_pred, bbx_pred, cls_pred, obj_pred,
                                msk_pred)
                    save_function(raw_pred, panoptic_pred, img_info)

                    # Drop every per-image reference, including the tuple
                    # and the fused result that would otherwise keep the
                    # tensors alive, then flush once.
                    del raw_pred, panoptic_pred
                    del sem_pred, bbx_pred, cls_pred, obj_pred, msk_pred
                    torch.cuda.empty_cache()

                # The prediction dict is the last owner of the raw outputs
                del pred
                torch.cuda.empty_cache()

                # Log batch
                if varargs["summary"] is not None and (
                        it + 1) % varargs["log_interval"] == 0:
                    logging.iteration(
                        None, "val", 0, 1, 1, it + 1, len(dataloader),
                        OrderedDict([("data_time", data_time_meter),
                                     ("batch_time", batch_time_meter)]))

                data_time = time.time()

        except Exception as e:
            # Deliberate best-effort: report the failure, pause briefly and
            # carry on with the next batch.
            print("Error in tqdm(enumerate(dataloader))", e)
            time.sleep(2)
            # Restart the data timer so the failed batch and the pause are
            # not counted as loading time of the next batch.
            data_time = time.time()
def train(model, optimizer, scheduler, dataloader, meters, **varargs):
    """Train the model for one epoch and return the updated global step.

    Args:
        model: Network called as
            ``model(**batch, do_loss=True, do_prediction=False)``; the first
            of its two return values is a dict of loss tensors.
        optimizer: Optimizer; gradients are zeroed and a step is taken for
            every batch.
        scheduler: Learning-rate scheduler; stepped with the global step
            before each forward pass when ``varargs["batch_update"]`` is
            truthy, and queried through ``get_lr()`` for logging.
        dataloader: Iterable of batches containing every key listed in
            ``NETWORK_INPUTS``.
        meters: Mapping from loss name to meter, indexed by every key of
            the reduced loss dict; ``meters["loss"].momentum`` is reused
            for the timing meters.
        **varargs: Must contain ``"epoch"``, ``"global_step"``,
            ``"loss_weights"``, ``"device"``, ``"batch_update"`` and
            ``"summary"``; ``"log_interval"`` and ``"num_epochs"`` are read
            when ``"summary"`` is not None.

    Returns:
        The global step after the last processed batch.
    """
    model.train()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])
    optimizer.zero_grad()
    global_step = varargs["global_step"]
    loss_weights = varargs["loss_weights"]

    # The timing meters share the momentum of the main loss meter
    momentum = meters["loss"].momentum
    data_time_meter = AverageMeter((), momentum)
    batch_time_meter = AverageMeter((), momentum)

    # Meters reported in the periodic log line, in display order
    logged_losses = ("loss", "obj_loss", "bbx_loss", "roi_cls_loss",
                     "roi_bbx_loss", "roi_msk_loss")

    load_start = time.time()
    for step, batch in enumerate(dataloader, start=1):
        # Replace the batch by its device-side network inputs
        batch = {
            name: batch[name].cuda(device=varargs["device"],
                                   non_blocking=True)
            for name in NETWORK_INPUTS
        }

        data_time_meter.update(torch.tensor(time.time() - load_start))

        # Advance the step counter and, if requested, the scheduler
        global_step += 1
        if varargs["batch_update"]:
            scheduler.step(global_step)

        compute_start = time.time()

        # Forward pass; the barrier is issued right after it, before any
        # loss post-processing
        losses, _ = model(**batch, do_loss=True, do_prediction=False)
        distributed.barrier()

        losses = OrderedDict(
            (name, value.mean()) for name, value in losses.items())
        # The weighted sum is computed before the "loss" key is inserted,
        # so it only covers the individual loss terms
        losses["loss"] = sum(
            weight * value
            for weight, value in zip(loss_weights, losses.values()))

        # Backward pass and parameter update
        optimizer.zero_grad()
        losses["loss"].backward()
        optimizer.step()

        # Combine the statistics of all workers
        losses = all_reduce_losses(losses)

        with torch.no_grad():
            for name, value in losses.items():
                meters[name].update(value.cpu())
        batch_time_meter.update(torch.tensor(time.time() - compute_start))

        # Release the batch and the loss tensors
        del batch, losses

        # Periodic progress report
        should_log = (varargs["summary"] is not None
                      and step % varargs["log_interval"] == 0)
        if should_log:
            stats = OrderedDict([("lr", scheduler.get_lr()[0])])
            for name in logged_losses:
                stats[name] = meters[name]
            stats["data_time"] = data_time_meter
            stats["batch_time"] = batch_time_meter
            logging.iteration(varargs["summary"], "train", global_step,
                              varargs["epoch"] + 1, varargs["num_epochs"],
                              step, len(dataloader), stats)

        load_start = time.time()

    return global_step
def validate(model, dataloader, loss_weights, **varargs):
    """Evaluate the model on ``dataloader`` and return the panoptic score.

    Every batch is run with both losses and predictions enabled. The
    function accumulates a confusion matrix, panoptic statistics and,
    when ``varargs["eval_coco"]`` is truthy, COCO prediction records. It
    then derives mIoU, detection / mask AP and panoptic scores and logs
    them.

    Args:
        model: Network called as
            ``model(**batch, do_loss=True, do_prediction=True)``; returns
            ``(losses, pred, conf)``.
        dataloader: Iterable of batches providing ``"idx"``, ``"img"``,
            ``"size"`` and every key in ``NETWORK_INPUTS`` (``"msk"``,
            ``"cat"`` and ``"iscrowd"`` are read from the uploaded batch).
            Its dataset must expose ``num_stuff``, ``num_categories`` and
            ``categories``.
        loss_weights: Weights paired (via ``zip``) with the reduced loss
            values to form the total loss.
        **varargs: Reads ``"epoch"``, ``"device"``, ``"eval_mode"``
            (compared against ``"panoptic"`` and ``"separate"``),
            ``"eval_coco"``, ``"make_panoptic"`` and ``"summary"``.
            ``"coco_gt"`` and ``"log_dir"`` are read when ``"eval_coco"``
            is truthy; ``"log_interval"``, ``"global_step"`` and
            ``"num_epochs"`` are read when ``"summary"`` is not None.

    Returns:
        The first value returned by ``get_panoptic_scores``.

    Raises:
        AssertionError: If any entry of ``batch["msk"]`` does not have
            size 1 along its first dimension.
    """
    model.eval()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])

    num_stuff = dataloader.dataset.num_stuff
    num_classes = dataloader.dataset.num_categories

    loss_meter = AverageMeter(())
    data_time_meter = AverageMeter(())
    batch_time_meter = AverageMeter(())

    # Accumulators for ap, mIoU and panoptic computation
    # panoptic_buffer receives the stacked output of panoptic_stats (dim 0)
    panoptic_buffer = torch.zeros(4, num_classes, dtype=torch.double)
    # Fixed 256x256 buffer; only its first num_classes rows / columns enter
    # the mIoU below (presumably sized so label 255 fits — TODO confirm)
    conf_mat = torch.zeros(256, 256, dtype=torch.double)
    coco_struct = []
    img_list = []

    data_time = time.time()
    for it, batch in enumerate(dataloader):
        with torch.no_grad():
            # Keep the metadata before `batch` is replaced by the uploaded
            # network inputs
            idxs = batch["idx"]
            batch_sizes = [img.shape[-2:] for img in batch["img"]]
            original_sizes = batch["size"]

            # Upload batch
            batch = {
                k: batch[k].cuda(device=varargs["device"], non_blocking=True)
                for k in NETWORK_INPUTS
            }
            # Each mask must have exactly one entry along dim 0; it is
            # squeezed away per image further down
            assert all(msk.size(0) == 1 for msk in batch["msk"]), \
                "Mask R-CNN + segmentation requires panoptic ground truth"
            data_time_meter.update(torch.tensor(time.time() - data_time))

            batch_time = time.time()

            # Run network
            losses, pred, conf = model(**batch,
                                       do_loss=True,
                                       do_prediction=True)
            losses = OrderedDict((k, v.mean()) for k, v in losses.items())
            losses = all_reduce_losses(losses)
            loss = sum(w * l for w, l in zip(loss_weights, losses.values()))

            if varargs["eval_mode"] == "separate":
                # Directly accumulate confusion matrix from the network
                conf_mat[:num_classes, :num_classes] += conf["sem_conf"].to(
                    conf_mat)

            # Update meters
            loss_meter.update(loss.cpu())
            batch_time_meter.update(torch.tensor(time.time() - batch_time))

            del loss, losses, conf

            # Accumulate COCO AP and panoptic predictions
            for i, (sem_pred, bbx_pred, cls_pred, obj_pred, msk_pred, msk_gt,
                    cat_gt, iscrowd) in enumerate(
                        zip(pred["sem_pred"], pred["bbx_pred"],
                            pred["cls_pred"], pred["obj_pred"],
                            pred["msk_pred"], batch["msk"], batch["cat"],
                            batch["iscrowd"])):
                msk_gt = msk_gt.squeeze(0)
                # Look up a category for every mask entry by using the mask
                # values as indices into cat_gt
                sem_gt = cat_gt[msk_gt]

                # Remove crowd from gt
                # cmap renumbers the non-crowd ids consecutively from 0;
                # crowd ids keep the initial value 0
                cmap = msk_gt.new_zeros(cat_gt.numel())
                cmap[~iscrowd] = torch.arange(0,
                                              (~iscrowd).long().sum().item(),
                                              dtype=cmap.dtype,
                                              device=cmap.device)
                msk_gt = cmap[msk_gt]
                cat_gt = cat_gt[~iscrowd]

                # Compute panoptic output
                panoptic_pred = varargs["make_panoptic"](sem_pred, bbx_pred,
                                                         cls_pred, obj_pred,
                                                         msk_pred, num_stuff)

                # Panoptic evaluation
                panoptic_buffer += torch.stack(panoptic_stats(
                    msk_gt, cat_gt, panoptic_pred, num_classes, num_stuff),
                                               dim=0)

                if varargs["eval_mode"] == "panoptic":
                    # Calculate confusion matrix on panoptic output
                    # panoptic_pred[1] is indexed by panoptic_pred[0] to get
                    # a per-entry class map
                    sem_pred = panoptic_pred[1][panoptic_pred[0]]

                    conf_mat_i = confusion_matrix(sem_gt.cpu(), sem_pred)
                    conf_mat += conf_mat_i.to(conf_mat)

                    # Update coco AP from panoptic output
                    # Only when at least one class id is >= num_stuff and
                    # different from 255
                    if varargs["eval_coco"] and (
                        (panoptic_pred[1] >= num_stuff) &
                        (panoptic_pred[1] != 255)).any():
                        coco_struct += coco_ap.process_panoptic_prediction(
                            panoptic_pred, num_stuff, idxs[i], batch_sizes[i],
                            original_sizes[i])
                        img_list.append(idxs[i])
                elif varargs["eval_mode"] == "separate":
                    # Update coco AP from detection output
                    if varargs["eval_coco"] and bbx_pred is not None:
                        coco_struct += coco_ap.process_prediction(
                            bbx_pred, cls_pred + num_stuff, obj_pred, msk_pred,
                            batch_sizes[i], idxs[i], original_sizes[i])
                        img_list.append(idxs[i])

            del pred, batch

            # Log batch
            if varargs["summary"] is not None and (
                    it + 1) % varargs["log_interval"] == 0:
                logging.iteration(
                    None, "val", varargs["global_step"], varargs["epoch"] + 1,
                    varargs["num_epochs"], it + 1, len(dataloader),
                    OrderedDict([("loss", loss_meter),
                                 ("data_time", data_time_meter),
                                 ("batch_time", batch_time_meter)]))

            data_time = time.time()

    # Finalize mIoU computation
    # The matrix is moved to the target device for the cross-worker sum,
    # then brought back to the CPU and cut to the first num_classes rows
    conf_mat = conf_mat.to(device=varargs["device"])
    distributed.all_reduce(conf_mat, distributed.ReduceOp.SUM)
    conf_mat = conf_mat.cpu()[:num_classes, :]
    # Per-class IoU: diagonal / (row sum + column sum - diagonal)
    miou = conf_mat.diag() / (conf_mat.sum(dim=1) + conf_mat.sum(
        dim=0)[:num_classes] - conf_mat.diag())

    # Finalize AP computation
    # det_map / msk_map are only bound when eval_coco is truthy; the logging
    # code below reads them under the same condition
    if varargs["eval_coco"]:
        det_map, msk_map = coco_ap.summarize_mp(coco_struct,
                                                varargs["coco_gt"], img_list,
                                                varargs["log_dir"], True)

    # Finalize panoptic computation
    panoptic_score, stuff_pq, thing_pq = get_panoptic_scores(
        panoptic_buffer, varargs["device"], num_stuff)

    # Log results
    log_info("Validation done")
    if varargs["summary"] is not None:
        metrics = OrderedDict()
        metrics["loss"] = loss_meter.mean.item()
        if varargs["eval_coco"]:
            metrics["det_map"] = det_map
            metrics["msk_map"] = msk_map
        metrics["miou"] = miou.mean().item()
        metrics["panoptic"] = panoptic_score
        metrics["stuff_pq"] = stuff_pq
        metrics["thing_pq"] = thing_pq
        metrics["data_time"] = data_time_meter.mean.item()
        metrics["batch_time"] = batch_time_meter.mean.item()

        logging.iteration(varargs["summary"], "val", varargs["global_step"],
                          varargs["epoch"] + 1, varargs["num_epochs"],
                          len(dataloader), len(dataloader), metrics)

    # Per-class mIoU report
    log_miou(miou, dataloader.dataset.categories)

    return panoptic_score