Esempio n. 1
0
def main(_run, _config):
    """Train a point-cloud classifier on ModelNet40 and test it after every epoch.

    Builds the network and the train/test dataloaders described by ``_config``,
    then runs ``_config["training"]["epoch_nbr"]`` epochs, each made of one
    training pass and one test pass. After every epoch it overwrites
    ``<savedir>/checkpoint.pth``, appends one line of scores to
    ``<savedir>/logs.txt`` and logs the same scores to Sacred.

    Args:
        _run: Sacred run object; only ``log_scalar`` is used.
        _config: nested configuration mapping (Sacred config).

    Raises:
        ValueError: if ``_config["dataset"]["name"]`` is not one of the
            supported dataset names.
    """
    print(_config)
    savedir_root = _config["training"]["savedir"]
    device = torch.device(_config["misc"]["device"])

    # save the config file in the directory to restore the configuration
    os.makedirs(savedir_root, exist_ok=True)
    # NOTE(review): eval(str(_config)) rebuilds the config as plain Python
    # containers by evaluating its repr. Only acceptable because the config is
    # the experiment's own; never apply this to untrusted input.
    save_config_file(eval(str(_config)), os.path.join(savedir_root, "config.yaml"))

    # parameters for training
    N_LABELS = 40  # number of output classes of the classifier
    input_channels = 1

    print("Creating network...", end="", flush=True)

    def network_function():
        # Factory used here and also handed to the datasets; presumably they
        # need a network instance to precompute the "net_indices" /
        # "net_support" entries of each batch (TODO confirm).
        return get_network(
            _config["network"]["model"],
            input_channels,
            N_LABELS,
            _config["network"]["backend_conv"],
            _config["network"]["backend_search"],
        )

    net = network_function()
    net.to(device)
    print("done")
    print("Number of parameters", count_parameters(net))

    print("get the data path...", end="", flush=True)
    rootdir = _config["dataset"]["dir"]
    print("done")

    training_transformations = [
        lcp_transfo.UnitBallNormalize(),
        lcp_transfo.RandomSubSample(_config["dataset"]["npoints"]),
        lcp_transfo.NormalPerturbation(sigma=0.01),
    ]
    test_transformations = [
        lcp_transfo.UnitBallNormalize(),
        lcp_transfo.RandomSubSample(_config["dataset"]["npoints"]),
    ]

    print("Creating dataloaders...", end="", flush=True)
    dataset_name = _config["dataset"]["name"]
    if dataset_name == "Modelnet40_normal_resampled":
        Dataset = Modelnet40_normal_resampled
    elif dataset_name == "Modelnet40_ply_hdf5_2048":
        Dataset = Modelnet40_ply_hdf5_2048
    else:
        # Fail early with a clear message instead of an UnboundLocalError on
        # `Dataset` a few lines below.
        raise ValueError(f"Unknown dataset name: {dataset_name!r}")
    ds = Dataset(
        rootdir,
        split='training',
        network_function=network_function,
        transformations_points=training_transformations,
    )
    train_loader = torch.utils.data.DataLoader(
        ds,
        batch_size=_config["training"]["batchsize"],
        shuffle=True,
        num_workers=_config["misc"]["threads"],
    )
    ds_test = Dataset(
        rootdir,
        split='test',
        network_function=network_function,
        transformations_points=test_transformations,
    )
    test_loader = torch.utils.data.DataLoader(
        ds_test,
        batch_size=_config["training"]["batchsize"],
        shuffle=False,
        num_workers=_config["misc"]["threads"],
    )
    print("done")

    print("Creating optimizer...", end="")
    optimizer = torch.optim.Adam(net.parameters(), lr=_config["training"]["lr_start"])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, _config["training"]["milestones"], gamma=0.5
    )
    print("done")

    def get_data(data):
        """Move one batch dict to ``device``.

        Returns ``(pts, features, targets, net_ids, net_support)`` where the
        last two are lists holding one device tensor per entry of the batch's
        ``"net_indices"`` / ``"net_support"`` lists.
        """
        pts = data["pts"].to(device)
        features = data["features"].to(device)
        targets = data["target"].to(device)
        net_ids = [ids.to(device) for ids in data["net_indices"]]
        net_support = [support.to(device) for support in data["net_support"]]
        return pts, features, targets, net_ids, net_support

    for epoch in range(_config["training"]["epoch_nbr"]):

        # ---- training pass ----
        net.train()
        error = 0
        cm = np.zeros((N_LABELS, N_LABELS))

        train_aloss = "0"
        train_oa = "0"
        train_aa = "0"
        train_aiou = "0"

        t = tqdm(
            train_loader,
            desc="Epoch " + str(epoch),
            ncols=130,
            disable=_config["misc"]["disable_tqdm"],
        )
        for data in t:

            pts, features, targets, net_ids, net_support = get_data(data)

            optimizer.zero_grad()
            outputs = net(features, pts, support_points=net_support, indices=net_ids)
            loss = F.cross_entropy(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute scores
            output_np = np.argmax(outputs.cpu().detach().numpy(), axis=1)
            target_np = targets.cpu().numpy()
            cm += confusion_matrix(
                target_np.ravel(), output_np.ravel(), labels=list(range(N_LABELS))
            )
            # cross_entropy returns the batch *mean*; weight it by the number of
            # targets so that error / cm.sum() (cm.sum() counts targets) is the
            # mean loss per sample rather than mean loss / batch size.
            error += loss.item() * target_np.size

            # point wise scores on training
            train_oa = "{:.5f}".format(metrics.stats_overall_accuracy(cm))
            train_aa = "{:.5f}".format(metrics.stats_accuracy_per_class(cm)[0])
            train_aiou = "{:.5f}".format(metrics.stats_iou_per_class(cm)[0])
            train_aloss = "{:.5e}".format(error / cm.sum())

            t.set_postfix(OA=train_oa, AA=train_aa, AIOU=train_aiou, ALoss=train_aloss)

        # ---- test pass ----
        net.eval()
        error = 0
        cm = np.zeros((N_LABELS, N_LABELS))
        test_aloss = "0"
        test_oa = "0"
        test_aa = "0"
        test_aiou = "0"
        with torch.no_grad():

            t = tqdm(
                test_loader,
                desc="  Test " + str(epoch),
                ncols=100,
                disable=_config["misc"]["disable_tqdm"],
            )
            for data in t:

                pts, features, targets, net_ids, net_support = get_data(data)

                outputs = net(
                    features, pts, support_points=net_support, indices=net_ids
                )
                loss = F.cross_entropy(outputs, targets)

                pred_labels = np.argmax(outputs.cpu().detach().numpy(), axis=1)
                target_np = targets.cpu().numpy()
                cm += confusion_matrix(
                    target_np, pred_labels, labels=list(range(N_LABELS))
                )
                # same per-sample weighting as in the training pass
                error += loss.item() * target_np.size

                # point-wise scores on testing
                test_oa = "{:.5f}".format(metrics.stats_overall_accuracy(cm))
                test_aa = "{:.5f}".format(metrics.stats_accuracy_per_class(cm)[0])
                test_aiou = "{:.5f}".format(metrics.stats_iou_per_class(cm)[0])
                test_aloss = "{:.5e}".format(error / cm.sum())

                t.set_postfix(OA=test_oa, AA=test_aa, AIOU=test_aiou, ALoss=test_aloss)

        scheduler.step()

        # create the root folder
        os.makedirs(savedir_root, exist_ok=True)

        # save the checkpoint (overwritten every epoch)
        torch.save(
            {
                "epoch": epoch + 1,
                "state_dict": net.state_dict(),
                "optimizer": optimizer.state_dict(),
            },
            os.path.join(savedir_root, "checkpoint.pth"),
        )

        # write the logs; `with` guarantees the file is closed even on error
        with open(os.path.join(savedir_root, "logs.txt"), "a+") as logs:
            logs.write(
                " ".join(
                    [
                        str(epoch),
                        train_aloss,
                        train_oa,
                        train_aa,
                        train_aiou,
                        test_aloss,
                        test_oa,
                        test_aa,
                        test_aiou,
                    ]
                )
                + "\n"
            )

        # log for Sacred; the scores above are formatted strings, so convert
        # them back to numbers for the metrics API
        for metric_name, value in (
            ("trainOA", train_oa),
            ("trainAA", train_aa),
            ("trainAIoU", train_aiou),
            ("trainLoss", train_aloss),
            ("testOA", test_oa),
            ("testAA", test_aa),
            ("testAIoU", test_aiou),
            ("testLoss", test_aloss),
        ):
            _run.log_scalar(metric_name, float(value), epoch)
Esempio n. 2
0
def main(_run, _config):
    """Train a part-segmentation network, with a monitoring test pass per epoch.

    Builds the network and train/test dataloaders described by ``_config`` and
    trains for epochs ``[epoch_start, _config["training"]["epoch_nbr"])``.
    Predictions are restricted to the part sub-range of each shape's object
    category (``Dataset.category_range``) before scoring. After every epoch it
    overwrites ``<savedir>/checkpoint.pth``, appends one line of scores to
    ``<savedir>/log.txt`` and logs the scores to Sacred.

    The per-epoch test pass is only for monitoring; the final evaluation is
    done by a dedicated script.

    Args:
        _run: Sacred run object; only ``log_scalar`` is used.
        _config: nested configuration mapping (Sacred config).
    """
    print(_config)

    savedir_root = _config["training"]["savedir"]
    device = torch.device(_config["misc"]["device"])

    # save the config file
    os.makedirs(savedir_root, exist_ok=True)
    # NOTE(review): eval(str(_config)) rebuilds the config as plain Python
    # containers by evaluating its repr. Only acceptable because the config is
    # the experiment's own; never apply this to untrusted input.
    save_config_file(eval(str(_config)), os.path.join(savedir_root, "config.yaml"))

    print("get the data path...", end="", flush=True)
    rootdir = _config["dataset"]["dir"]
    print("done")

    # Size of the network output; each object category owns a sub-range of
    # these classes (see `category_range` below).
    N_CLASSES = 50

    print("Creating network...", end="", flush=True)

    def network_function():
        return Network(
            1, N_CLASSES,
            get_conv(_config["network"]["backend_conv"]),
            get_search(_config["network"]["backend_search"]),
        )

    net = network_function()
    net.to(device)
    print("Done")
    network_parameters = count_parameters(net)
    print("parameters", network_parameters)

    training_transformations = [
        lcp_transfo.UnitBallNormalize(),
        lcp_transfo.RandomSubSample(_config["dataset"]["npoints"]),
        lcp_transfo.NormalPerturbation(sigma=0.001),
    ]
    test_transformations = [
        lcp_transfo.UnitBallNormalize(),
        lcp_transfo.RandomSubSample(_config["dataset"]["npoints"]),
    ]

    print("Creating dataloader...", end="", flush=True)
    ds = Dataset(
        rootdir,
        'training',
        network_function=network_function,
        transformations_points=training_transformations,
    )
    train_loader = torch.utils.data.DataLoader(
        ds,
        batch_size=_config["training"]["batchsize"],
        shuffle=True,
        num_workers=_config["misc"]["threads"],
    )
    ds_test = Dataset(
        rootdir,
        'test',
        network_function=network_function,
        transformations_points=test_transformations,
    )
    test_loader = torch.utils.data.DataLoader(
        ds_test,
        batch_size=_config["training"]["batchsize"],
        shuffle=False,
        num_workers=_config["misc"]["threads"],
    )
    print("Done")

    # class weights for the training loss
    print("Computing weights...", end="", flush=True)
    weights = torch.from_numpy(ds.get_weights()).float().to(device)
    print("Done")

    print("Creating optimizer...", end="", flush=True)
    optimizer = torch.optim.Adam(net.parameters(), lr=_config["training"]["lr_start"], eps=1e-3)
    epoch_start = 0
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        _config["training"]["milestones"],
        gamma=_config["training"]["gamma"],
        last_epoch=epoch_start - 1,
    )
    print("Done")

    def get_data(data):
        """Move one batch dict to ``device``.

        Returns ``(pts, features, seg, labels, net_ids, net_pts)``. ``labels``
        stays on the CPU: it is only used to index ``category_range``.
        """
        pts = data["pts"].to(device)
        features = data["features"].to(device)
        seg = data["seg"].to(device)
        labels = data["label"]
        net_ids = [ids.to(device) for ids in data["net_indices"]]
        net_pts = [support.to(device) for support in data["net_support"]]
        return pts, features, seg, labels, net_ids, net_pts

    def mask_to_category(scores, labels, category_range):
        """Set to -1e7 the scores of classes outside each sample's category.

        ``scores`` is modified in place and returned, so that an argmax over
        axis 1 can only pick a part of the sample's own object category.
        """
        for b_id in range(scores.shape[0]):
            part_start, part_end = category_range[labels[b_id]]
            scores[b_id, :part_start] = -1e7
            scores[b_id, part_end:] = -1e7
        return scores

    for epoch in range(epoch_start, _config["training"]["epoch_nbr"]):

        # Defaults so that logging below works even if a loader yields nothing.
        oa = aa = iou = "0"
        oa_test = aa_test = iou_test = "0"

        # train
        net.train()
        cm = np.zeros((N_CLASSES, N_CLASSES))
        t = tqdm(
            train_loader,
            ncols=120,
            desc=f"Epoch {epoch}",
            disable=_config["misc"]["disable_tqdm"],
        )
        for data in t:

            pts, features, seg, labels, net_ids, net_pts = get_data(data)

            optimizer.zero_grad()
            outputs = net(features, pts, support_points=net_pts, indices=net_ids)
            loss = F.cross_entropy(outputs, seg, weight=weights)
            loss.backward()
            optimizer.step()

            outputs_np = mask_to_category(
                outputs.cpu().detach().numpy(), labels, ds.category_range
            )
            output_np = np.argmax(outputs_np, axis=1)
            target_np = seg.cpu().numpy()

            cm += confusion_matrix(
                target_np.ravel(), output_np.ravel(), labels=list(range(N_CLASSES))
            )

            oa = "{:.3f}".format(metrics.stats_overall_accuracy(cm))
            aa = "{:.3f}".format(metrics.stats_accuracy_per_class(cm)[0])
            iou = "{:.3f}".format(metrics.stats_iou_per_class(cm)[0])

            t.set_postfix(OA=oa, AA=aa, IOU=iou)

        # eval (this is not the final evaluation, see dedicated evaluation)
        net.eval()
        with torch.no_grad():
            cm = np.zeros((N_CLASSES, N_CLASSES))
            t = tqdm(
                test_loader,
                ncols=120,
                desc=f"Test {epoch}",
                disable=_config["misc"]["disable_tqdm"],
            )
            for data in t:

                pts, features, seg, labels, net_ids, net_pts = get_data(data)

                outputs = net(features, pts, support_points=net_pts, indices=net_ids)

                # Only confusion-matrix scores are reported for the test pass.
                outputs_np = mask_to_category(
                    outputs.cpu().detach().numpy(), labels, ds_test.category_range
                )
                output_np = np.argmax(outputs_np, axis=1)
                target_np = seg.cpu().numpy()

                cm += confusion_matrix(
                    target_np.ravel(), output_np.ravel(), labels=list(range(N_CLASSES))
                )

                oa_test = "{:.3f}".format(metrics.stats_overall_accuracy(cm))
                aa_test = "{:.3f}".format(metrics.stats_accuracy_per_class(cm)[0])
                iou_test = "{:.3f}".format(metrics.stats_iou_per_class(cm)[0])

                t.set_postfix(OA=oa_test, AA=aa_test, IOU=iou_test)

        # scheduler update
        scheduler.step()

        # save the model (overwritten every epoch)
        os.makedirs(savedir_root, exist_ok=True)
        torch.save(
            {
                "epoch": epoch + 1,
                "state_dict": net.state_dict(),
                "optimizer": optimizer.state_dict(),
            },
            os.path.join(savedir_root, "checkpoint.pth"),
        )

        # write the logs; `with` closes the file even on error
        with open(os.path.join(savedir_root, "log.txt"), "a+") as logs:
            logs.write(f"{epoch} {oa} {aa} {iou} {oa_test} {aa_test} {iou_test} \n")

        # Sacred metrics: convert the formatted score strings back to numbers
        for metric_name, value in (
            ("trainOA", oa),
            ("trainAA", aa),
            ("trainIoU", iou),
            ("testOA", oa_test),
            ("testAA", aa_test),
            ("testIoU", iou_test),
        ):
            _run.log_scalar(metric_name, float(value), epoch)
Esempio n. 3
0
def main(_config):
    """Evaluate a trained part-segmentation checkpoint on the test split.

    Loads ``<savedir>/checkpoint.pth`` and runs the network on the test loader.
    It accumulates the scores of every sampled point back onto its source
    shape. Points that were never sampled get their scores from
    ``nearest_correspondance(..., K=1)``. It then prints the mean part IoU
    averaged over object categories and over all shapes, followed by one value
    per category.

    Args:
        _config: nested configuration mapping (Sacred config).
    """
    print(_config)

    savedir_root = _config["training"]["savedir"]
    device = torch.device(_config["misc"]["device"])

    print("get the data path...", end="", flush=True)
    rootdir = _config["dataset"]["dir"]
    print("done")

    # Size of the network output; each object category owns a sub-range of
    # these classes (see `category_range` below).
    N_CLASSES = 50

    print("Creating network...", end="", flush=True)

    def network_function():
        return Network(
            1,
            N_CLASSES,
            get_conv(_config["network"]["backend_conv"]),
            get_search(_config["network"]["backend_search"]),
        )

    net = network_function()
    checkpoint = torch.load(
        os.path.join(savedir_root, "checkpoint.pth"), map_location=device
    )
    net.load_state_dict(checkpoint["state_dict"])
    net.to(device)
    net.eval()
    print("Done")

    print("Creating dataloader...", end="", flush=True)
    test_transformations = [
        lcp_transfo.UnitBallNormalize(),
        lcp_transfo.RandomSubSample(_config["dataset"]["npoints"]),
    ]
    ds_test = Dataset(
        rootdir,
        'test',
        network_function=network_function,
        transformations_points=test_transformations,
        # NOTE: hard-coded to 1; _config["test"]["num_iter_per_shape"] is not used.
        iter_per_shape=1)
    test_loader = torch.utils.data.DataLoader(
        ds_test,
        batch_size=_config["test"]["batchsize"],
        shuffle=False,
        num_workers=_config["misc"]["threads"],
    )
    print("Done")

    # Per-shape accumulators indexed [shape, point, ...]: summed class scores,
    # and a 0/1 flag marking the points that were sampled at least once.
    results = torch.zeros(ds_test.data.shape[0], ds_test.data.shape[1],
                          N_CLASSES)
    results_count = torch.zeros(ds_test.data.shape[0], ds_test.data.shape[1])

    with torch.no_grad():
        for data in tqdm(test_loader, ncols=100, desc="Inference"):

            pts = data["pts"].to(device)
            features = data["features"].to(device)
            choices = data["choice"]
            labels = data["label"]
            indices = data["index"]
            net_ids = [ids.to(device) for ids in data["net_indices"]]
            net_pts = [support.to(device) for support in data["net_support"]]

            outputs = net(features,
                          pts,
                          support_points=net_pts,
                          indices=net_ids)
            outputs = outputs.to(torch.device("cpu"))

            for b_id in range(outputs.shape[0]):

                # Mask the classes outside this sample's object category.
                # Every index here must be b_id: masking another sample would
                # add -1e7 to that shape's accumulated in-category scores.
                object_label = labels[b_id]
                part_start, part_end = ds_test.category_range[object_label]
                outputs[b_id, :part_start] = -1e7
                outputs[b_id, part_end:] = -1e7

                shape_id = indices[b_id]
                choice = choices[b_id]

                # outputs[b_id] is laid out (classes, sampled points); transpose
                # so rows line up with the shape's points selected by `choice`.
                results_shape = results[shape_id]
                results_shape[choice] += outputs[b_id].transpose(0, 1)
                results[shape_id] = results_shape

                results_count_shape = results_count[shape_id]
                results_count_shape[choice] = 1
                results_count[shape_id] = results_count_shape

    Confs = []
    for s_id in tqdm(range(ds_test.size()), ncols=100, desc="Conf. matrices"):

        shape_label = ds_test.labels_shape[s_id]
        # get the number of points
        npts = ds_test.data_num[s_id]

        # Ground truth shifted into 0..(n_parts - 1). Subtract out of place:
        # `-=` on this slice would rewrite ds_test.labels_pts itself.
        part_start, part_end = ds_test.category_range[shape_label]
        label_gt = ds_test.labels_pts[s_id, :npts] - part_start

        # get the results (a view into `results`)
        res_shape = results[s_id, :npts, part_start:part_end]

        # extend results to points never sampled during inference
        mask = results_count[s_id, :npts].cpu().numpy() == 1
        unseen = np.logical_not(mask)
        if unseen.any():
            res_shape_mask = res_shape[mask]
            pts_src = torch.from_numpy(
                ds_test.data[s_id, :npts][mask]).transpose(0, 1)
            pts_dest = torch.from_numpy(
                ds_test.data[s_id, :npts][unseen]).transpose(0, 1)
            res_shape_unseen = nearest_correspondance(
                pts_src, pts_dest, res_shape_mask.transpose(0, 1), K=1
            ).transpose(0, 1)
            res_shape[unseen] = res_shape_unseen

        label_pred = np.argmax(res_shape.numpy(), axis=1)
        Confs.append(confusion_matrix(label_gt,
                                      label_pred,
                                      labels=list(range(part_end - part_start))))

    # compute IoU per shape
    print("Computing IoUs...", end="", flush=True)
    IoUs_per_shape = np.array([
        metrics.stats_iou_per_class(Confs[i])[0]
        for i in range(ds_test.labels_shape.shape[0])
    ])

    # compute object category average
    obj_IoUs = np.zeros(len(ds_test.label_names))
    for i in range(len(ds_test.label_names)):
        obj_IoUs[i] = IoUs_per_shape[ds_test.labels_shape == i].mean()
    print("Done")

    print("Objs | Inst | Air  Bag  Cap  Car  Cha  Ear  Gui  "
          "Kni  Lam  Lap  Mot  Mug  Pis  Roc  Ska  Tab")
    print("-----|------|-------------------------------"
          "-------------------------------------------------")
    s = "{:3.1f} | {:3.1f} | ".format(100 * obj_IoUs.mean(),
                                      100 * np.mean(IoUs_per_shape))
    for AmIoU in obj_IoUs:
        s += "{:3.1f} ".format(100 * AmIoU)
    print(s + "\n")