Example #1
# Imports required by this snippet. `get_network`, `count_parameters`,
# `save_config_file`, `metrics`, `lcp_transfo` and the two Modelnet40 dataset
# classes come from the surrounding project; their exact module paths depend
# on the repository layout and are not shown here.
import os

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

# `_run` and `_config` are injected by Sacred when this function is registered
# with an experiment (e.g. via @ex.automain; the decorator is not shown).
def main(_run, _config):

    print(_config)
    savedir_root = _config["training"]["savedir"]
    device = torch.device(_config["misc"]["device"])

    # save the config file in the save directory so the run can be restored later
    os.makedirs(savedir_root, exist_ok=True)
    # eval(str(...)) round-trips Sacred's read-only config into a plain dict
    save_config_file(eval(str(_config)), os.path.join(savedir_root, "config.yaml"))

    # parameters for training: ModelNet40 has 40 classes and the network
    # consumes a single input feature channel
    N_LABELS = 40
    input_channels = 1

    print("Creating network...", end="", flush=True)

    def network_function():
        return get_network(
            _config["network"]["model"],
            input_channels,
            N_LABELS,
            _config["network"]["backend_conv"],
            _config["network"]["backend_search"],
        )
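    # NB: this factory is also handed to the Dataset below, which (in the
    # LightConvPoint pipeline) uses it to precompute the support points and
    # neighborhood indices returned in each batch as data["net_support"] and
    # data["net_indices"]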

    net = network_function()
    net.to(device)
    print("done")
    print("Number of parameters", count_parameters(net))

    print("get the data path...", end="", flush=True)
    rootdir = os.path.join(_config["dataset"]["dir"])
    print("done")

    training_transformations = [
        lcp_transfo.UnitBallNormalize(),
        lcp_transfo.RandomSubSample(_config["dataset"]["npoints"]),
        lcp_transfo.NormalPerturbation(sigma=0.01)
    ]
    test_transformations = [
        lcp_transfo.UnitBallNormalize(),
        lcp_transfo.RandomSubSample(_config["dataset"]["npoints"]),
    ]
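    # NormalPerturbation (Gaussian jitter) is a training-time augmentation
    # only; test samples are just normalized and randomly subsampled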


    print("Creating dataloaders...", end="", flush=True)
    if _config['dataset']['name'] == "Modelnet40_normal_resampled":
        Dataset = Modelnet40_normal_resampled
    elif _config['dataset']['name'] == "Modelnet40_ply_hdf5_2048":
        Dataset = Modelnet40_ply_hdf5_2048
    else:
        raise ValueError(f"Unknown dataset {_config['dataset']['name']}")
    ds = Dataset(
        rootdir,
        split='training',
        network_function=network_function,
        transformations_points=training_transformations,
    )
    train_loader = torch.utils.data.DataLoader(
        ds,
        batch_size=_config["training"]["batchsize"],
        shuffle=True,
        num_workers=_config["misc"]["threads"],
    )
    ds_test = Dataset(
        rootdir,
        split='test',
        network_function=network_function,
        transformations_points=test_transformations,
    )
    test_loader = torch.utils.data.DataLoader(
        ds_test,
        batch_size=_config["training"]["batchsize"],
        shuffle=False,
        num_workers=_config["misc"]["threads"],
    )
    print("done")

    print("Creating optimizer...", end="")
    optimizer = torch.optim.Adam(net.parameters(), lr=_config["training"]["lr_start"])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, _config["training"]["milestones"], gamma=0.5
    )
    print("done")

    def get_data(data):
        # unpack a batch and move all tensors, including the precomputed
        # support points and neighborhood indices, to the target device
        pts = data["pts"]
        features = data["features"]
        targets = data["target"]
        net_ids = data["net_indices"]
        net_support = data["net_support"]

        features = features.to(device)
        pts = pts.to(device)
        targets = targets.to(device)
        for i in range(len(net_ids)):
            net_ids[i] = net_ids[i].to(device)
        for i in range(len(net_support)):
            net_support[i] = net_support[i].to(device)

        return pts, features, targets, net_ids, net_support

    for epoch in range(_config["training"]["epoch_nbr"]):

        net.train()
        error = 0
        cm = np.zeros((N_LABELS, N_LABELS))

        train_aloss = "0"
        train_oa = "0"
        train_aa = "0"
        train_aiou = "0"

        t = tqdm(
            train_loader,
            desc="Epoch " + str(epoch),
            ncols=130,
            disable=_config["misc"]["disable_tqdm"],
        )
        for data in t:

            pts, features, targets, net_ids, net_support = get_data(data)

            optimizer.zero_grad()
            outputs = net(features, pts, support_points=net_support, indices=net_ids)
            loss = F.cross_entropy(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute scores
            output_np = np.argmax(outputs.cpu().detach().numpy(), axis=1)
            target_np = targets.cpu().numpy()
            cm_ = confusion_matrix(
                target_np.ravel(), output_np.ravel(), labels=list(range(N_LABELS))
            )
            cm += cm_
            error += loss.item()

            # point-wise scores on training, accumulated over the epoch:
            # OA = overall accuracy, AA = average per-class accuracy,
            # AIoU = average intersection over union
            train_oa = "{:.5f}".format(metrics.stats_overall_accuracy(cm))
            train_aa = "{:.5f}".format(metrics.stats_accuracy_per_class(cm)[0])
            train_aiou = "{:.5f}".format(metrics.stats_iou_per_class(cm)[0])
            train_aloss = "{:.5e}".format(error / cm.sum())

            t.set_postfix(OA=train_oa, AA=train_aa, AIOU=train_aiou, ALoss=train_aloss)

        net.eval()
        error = 0
        cm = np.zeros((N_LABELS, N_LABELS))
        test_aloss = "0"
        test_oa = "0"
        test_aa = "0"
        test_aiou = "0"
        with torch.no_grad():

            t = tqdm(
                test_loader,
                desc="  Test " + str(epoch),
                ncols=100,
                disable=_config["misc"]["disable_tqdm"],
            )
            for data in t:

                pts, features, targets, net_ids, net_support = get_data(data)

                outputs = net(
                    features, pts, support_points=net_support, indices=net_ids
                )
                loss = F.cross_entropy(outputs, targets)

                outputs_np = outputs.cpu().detach().numpy()
                pred_labels = np.argmax(outputs_np, axis=1)
                cm_ = confusion_matrix(
                    targets.cpu().numpy(), pred_labels, labels=list(range(N_LABELS))
                )
                cm += cm_
                error += loss.item()

                # point-wise scores on testing
                test_oa = "{:.5f}".format(metrics.stats_overall_accuracy(cm))
                test_aa = "{:.5f}".format(metrics.stats_accuracy_per_class(cm)[0])
                test_aiou = "{:.5f}".format(metrics.stats_iou_per_class(cm)[0])
                test_aloss = "{:.5e}".format(error / cm.sum())

                t.set_postfix(OA=test_oa, AA=test_aa, AIOU=test_aiou, ALoss=test_aloss)

        scheduler.step()

        # create the root folder
        os.makedirs(savedir_root, exist_ok=True)

        # save the checkpoint
        torch.save(
            {
                "epoch": epoch + 1,
                "state_dict": net.state_dict(),
                "optimizer": optimizer.state_dict(),
            },
            os.path.join(savedir_root, "checkpoint.pth"),
        )

        # write the logs
        with open(os.path.join(savedir_root, "logs.txt"), "a+") as logs:
            logs.write(str(epoch) + " ")
            logs.write(train_aloss + " ")
            logs.write(train_oa + " ")
            logs.write(train_aa + " ")
            logs.write(train_aiou + " ")
            logs.write(test_aloss + " ")
            logs.write(test_oa + " ")
            logs.write(test_aa + " ")
            logs.write(test_aiou + "\n")

        # log for Sacred
        _run.log_scalar("trainOA", train_oa, epoch)
        _run.log_scalar("trainAA", train_aa, epoch)
        _run.log_scalar("trainAIoU", train_aiou, epoch)
        _run.log_scalar("trainLoss", train_aloss, epoch)
        _run.log_scalar("testOA", test_oa, epoch)
        _run.log_scalar("testAA", test_aa, epoch)
        _run.log_scalar("testAIoU", test_aiou, epoch)
        _run.log_scalar("testLoss", test_aloss, epoch)
Example #2
File: train.py  Project: valeoai/FKAConv
# Imports as in Example #1; `Network`, `get_conv`, `get_search` and the
# ShapeNet `Dataset` class come from the FKAConv project (module paths not
# shown in this snippet).
def main(_run, _config):

    print(_config)

    savedir_root = _config["training"]["savedir"]
    device = torch.device(_config["misc"]["device"])

    # save the config file
    os.makedirs(savedir_root, exist_ok=True)
    save_config_file(eval(str(_config)), os.path.join(savedir_root, "config.yaml"))

    print("get the data path...", end="", flush=True)
    rootdir = _config["dataset"]["dir"]
    print("done")

    # ShapeNet part segmentation: 50 part labels over 16 shape categories
    N_CLASSES = 50

    print("Creating network...", end="", flush=True)

    def network_function():
        return Network(
            1, N_CLASSES,
            get_conv(_config["network"]["backend_conv"]),
            get_search(_config["network"]["backend_search"]),
        )

    net = network_function()
    net.to(device)
    print("done")
    network_parameters = count_parameters(net)
    print("parameters", network_parameters)

    training_transformations = [
        lcp_transfo.UnitBallNormalize(),
        lcp_transfo.RandomSubSample(_config["dataset"]["npoints"]),
        lcp_transfo.NormalPerturbation(sigma=0.001)
    ]
    test_transformations = [
        lcp_transfo.UnitBallNormalize(),
        lcp_transfo.RandomSubSample(_config["dataset"]["npoints"]),
    ]

    print("Creating dataloader...", end="", flush=True)
    ds = Dataset(
        rootdir,
        'training',
        network_function=network_function,
        transformations_points=training_transformations
    )
    train_loader = torch.utils.data.DataLoader(
        ds,
        batch_size=_config["training"]["batchsize"],
        shuffle=True,
        num_workers=_config["misc"]["threads"],
    )
    ds_test = Dataset(
        rootdir,
        'test',
        network_function=network_function,
        transformations_points=test_transformations
    )
    test_loader = torch.utils.data.DataLoader(
        ds_test,
        batch_size=_config["training"]["batchsize"],
        shuffle=False,
        num_workers=_config["misc"]["threads"],
    )
    print("Done")


    # class-balancing weights for the loss
    print("Computing weights...", end="", flush=True)
    weights = torch.from_numpy(ds.get_weights()).float().to(device)
    print("Done")

    print("Creating optimizer...", end="", flush=True)
    optimizer = torch.optim.Adam(net.parameters(), lr=_config["training"]["lr_start"], eps=1e-3)
    epoch_start = 0
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        _config["training"]["milestones"],
        gamma=_config["training"]["gamma"],
        last_epoch=epoch_start - 1,
    )
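    # last_epoch=epoch_start - 1 keeps the milestone schedule aligned if
    # training is resumed from a checkpoint (epoch_start would then be
    # restored from the checkpoint instead of being fixed to 0)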
    print("Done")


    def get_data(data):

        pts = data["pts"].to(device)
        features = data["features"].to(device)
        seg = data["seg"].to(device)
        labels = data["label"]
        net_ids = data["net_indices"]
        net_pts = data["net_support"]
        for i in range(len(net_ids)):
            net_ids[i] = net_ids[i].to(device)
        for i in range(len(net_pts)):
            net_pts[i] = net_pts[i].to(device)

        return pts, features, seg, labels, net_ids, net_pts


    # training loop
    for epoch in range(epoch_start, _config["training"]["epoch_nbr"]):

        # train
        net.train()
        cm = np.zeros((N_CLASSES, N_CLASSES))
        t = tqdm(
            train_loader,
            ncols=120,
            desc=f"Epoch {epoch}",
            disable=_config["misc"]["disable_tqdm"],
        )
        for data in t:

            pts, features, seg, labels, net_ids, net_pts = get_data(data)

            optimizer.zero_grad()
            outputs = net(features, pts, support_points=net_pts, indices=net_ids)
            loss = F.cross_entropy(outputs, seg, weight=weights)
            loss.backward()
            optimizer.step()

            outputs_np = outputs.cpu().detach().numpy()
            for i in range(pts.size(0)):
                # mask out logits of parts outside this shape category's
                # range so the argmax below cannot select them
                object_label = labels[i]
                part_start, part_end = ds.category_range[object_label]

                outputs_np[i, :part_start] = -1e7
                outputs_np[i, part_end:] = -1e7

            output_np = np.argmax(outputs_np, axis=1).copy()
            target_np = seg.cpu().numpy().copy()

            cm_ = confusion_matrix(
                target_np.ravel(), output_np.ravel(), labels=list(range(N_CLASSES))
            )
            cm += cm_

            oa = "{:.3f}".format(metrics.stats_overall_accuracy(cm))
            aa = "{:.3f}".format(metrics.stats_accuracy_per_class(cm)[0])
            iou = "{:.3f}".format(metrics.stats_iou_per_class(cm)[0])

            t.set_postfix(OA=oa, AA=aa, IOU=iou)

        # eval (this is not the final evaluation; see the dedicated evaluation script)
        net.eval()
        with torch.no_grad():
            cm = np.zeros((N_CLASSES, N_CLASSES))
            t = tqdm(
                test_loader,
                ncols=120,
                desc=f"Test {epoch}",
                disable=_config["misc"]["disable_tqdm"],
            )
            for data in t:

                pts, features, seg, labels, net_ids, net_pts = get_data(data)

                outputs = net(features, pts, support_points=net_pts, indices=net_ids)
                loss = 0

                for i in range(pts.size(0)):
                    # restrict the loss to this shape category's parts
                    object_label = labels[i]
                    part_start, part_end = ds_test.category_range[object_label]

                    outputs_ = (outputs[i, part_start:part_end]).unsqueeze(0)
                    seg_ = (seg[i] - part_start).unsqueeze(0)

                    loss = loss + weights[object_label] * F.cross_entropy(
                        outputs_, seg_
                    )
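                    # the logits are sliced to [part_start, part_end) and the
                    # targets shifted by part_start, so cross_entropy sees
                    # local part indices in [0, n_parts) for this category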

                outputs_np = outputs.cpu().detach().numpy()
                for i in range(pts.size(0)):
                    # mask out logits of parts outside this shape category's range
                    object_label = labels[i]
                    part_start, part_end = ds_test.category_range[object_label]

                    outputs_np[i, :part_start] = -1e7
                    outputs_np[i, part_end:] = -1e7

                output_np = np.argmax(outputs_np, axis=1).copy()
                target_np = seg.cpu().numpy().copy()

                cm_ = confusion_matrix(
                    target_np.ravel(), output_np.ravel(), labels=list(range(N_CLASSES))
                )
                cm += cm_

                oa_test = "{:.3f}".format(metrics.stats_overall_accuracy(cm))
                aa_test = "{:.3f}".format(metrics.stats_accuracy_per_class(cm)[0])
                iou_test = "{:.3f}".format(metrics.stats_iou_per_class(cm)[0])

                t.set_postfix(OA=oa_test, AA=aa_test, IOU=iou_test)

        # scheduler update
        scheduler.step()

        # save the model
        os.makedirs(savedir_root, exist_ok=True)
        torch.save(
            {
                "epoch": epoch + 1,
                "state_dict": net.state_dict(),
                "optimizer": optimizer.state_dict(),
            },
            os.path.join(savedir_root, "checkpoint.pth"),
        )

        # write the logs
        with open(os.path.join(savedir_root, "log.txt"), "a+") as logs:
            logs.write(f"{epoch} {oa} {aa} {iou} {oa_test} {aa_test} {iou_test}\n")

        _run.log_scalar("trainOA", oa, epoch)
        _run.log_scalar("trainAA", aa, epoch)
        _run.log_scalar("trainIoU", iou, epoch)
        _run.log_scalar("testOA", oa_test, epoch)
        _run.log_scalar("testAA", aa_test, epoch)
        _run.log_scalar("testIoU", iou_test, epoch)

Example #3
# Imports as in Example #1; `data_utils` and the ShapeNet `Dataset` class come
# from the surrounding project (module paths not shown in this snippet).
def main(_run, _config):

    print(_config)

    savedir_root = _config["training"]["savedir"]
    device = torch.device(_config["misc"]["device"])

    # save the config file
    os.makedirs(savedir_root, exist_ok=True)
    save_config_file(eval(str(_config)), os.path.join(savedir_root, "config.yaml"))

    print("get the data path...", end="", flush=True)
    rootdir = os.path.join(_config["dataset"]["datasetdir"], _config["dataset"]["dataset"])
    print("done")

    filelist_train = os.path.join(rootdir, "train_files.txt")
    filelist_val = os.path.join(rootdir, "val_files.txt")
    filelist_test = os.path.join(rootdir, "test_files.txt")

    N_CLASSES = 50

    shapenet_labels = [
        ["Airplane", 4],
        ["Bag", 2],
        ["Cap", 2],
        ["Car", 4],
        ["Chair", 4],
        ["Earphone", 3],
        ["Guitar", 3],
        ["Knife", 2],
        ["Lamp", 4],
        ["Laptop", 2],
        ["Motorbike", 6],
        ["Mug", 2],
        ["Pistol", 3],
        ["Rocket", 3],
        ["Skateboard", 3],
        ["Table", 3],
    ]
    # cumulative part ranges: category i owns the global part labels
    # [part_start, part_end)
    category_range = []
    count = 0
    for element in shapenet_labels:
        part_start = count
        count += element[1]
        part_end = count
        category_range.append([part_start, part_end])
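    # e.g. Airplane (4 parts) -> [0, 4), Bag (2 parts) -> [4, 6),
    # Cap (2 parts) -> [6, 8), ..., Table (3 parts) -> [47, 50)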

    # Prepare inputs
    print("Preparing datasets...", end="", flush=True)
    (
        data_train,
        labels_shape_train,
        data_num_train,
        labels_pts_train,
        _,
    ) = data_utils.load_seg(filelist_train)
    data_val, labels_shape_val, data_num_val, labels_pts_val, _ = data_utils.load_seg(
        filelist_val
    )
    (
        data_test,
        labels_shape_test,
        data_num_test,
        labels_pts_test,
        _,
    ) = data_utils.load_seg(filelist_test)
    # train on the union of the train and val splits
    data_train = np.concatenate([data_train, data_val], axis=0)
    labels_shape_train = np.concatenate([labels_shape_train, labels_shape_val], axis=0)
    data_num_train = np.concatenate([data_num_train, data_num_val], axis=0)
    labels_pts_train = np.concatenate([labels_pts_train, labels_pts_val], axis=0)
    print("Done", data_train.shape)

    # define class-balancing weights: inverse shape-category frequency,
    # normalized by the per-category part count, then repeated so that each
    # of the 50 part labels receives its category's weight
    print("Computing weights...", end="", flush=True)
    frequencies = np.zeros(len(shapenet_labels))
    for i in range(len(shapenet_labels)):
        frequencies[i] = (labels_shape_train == i).sum() / shapenet_labels[i][1]
    frequencies = frequencies.mean() / frequencies
    repeat_factor = [sh[1] for sh in shapenet_labels]
    frequencies = np.repeat(frequencies, repeat_factor)
    weights = torch.from_numpy(frequencies).float().to(device)
    print("Done")

    print("Creating network...", end="", flush=True)

    def network_function():
        return get_network(
            _config["network"]["model"],
            in_channels=1,
            out_channels=N_CLASSES,
            backend_conv=_config["network"]["backend_conv"],
            backend_search=_config["network"]["backend_search"],
        )

    net = network_function()
    net.to(device)
    print("done")
    network_parameters = count_parameters(net)
    print("parameters", network_parameters)

    print("Creating dataloader...", end="", flush=True)
    ds = Dataset(
        data_train,
        data_num_train,
        labels_pts_train,
        labels_shape_train,
        npoints=_config["dataset"]["npoints"],
        training=True,
        network_function=network_function,
    )
    train_loader = torch.utils.data.DataLoader(
        ds,
        batch_size=_config["training"]["batchsize"],
        shuffle=True,
        num_workers=_config["misc"]["threads"],
    )
    ds_test = Dataset(
        data_test,
        data_num_test,
        labels_pts_test,
        labels_shape_test,
        npoints=_config["dataset"]["npoints"],
        training=False,
        network_function=network_function,
    )
    test_loader = torch.utils.data.DataLoader(
        ds_test,
        batch_size=_config["training"]["batchsize"],
        shuffle=False,
        num_workers=_config["misc"]["threads"],
    )
    print("Done")

    print("Creating optimizer...", end="", flush=True)
    optimizer = torch.optim.Adam(net.parameters(), lr=_config["training"]["lr_start"], eps=1e-3)
    epoch_start = 0
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        _config["training"]["milestones"],
        gamma=_config["training"]["gamma"],
        last_epoch=epoch_start - 1,
    )
    print("Done")

    # training loop
    for epoch in range(epoch_start, _config["training"]["epoch_nbr"]):

        # train
        net.train()
        cm = np.zeros((N_CLASSES, N_CLASSES))
        t = tqdm(
            train_loader,
            ncols=120,
            desc=f"Epoch {epoch}",
            disable=_config["misc"]["disable_tqdm"],
        )
        for data in t:

            pts = data["pts"].to(device)
            features = data["features"].to(device)
            seg = data["seg"].to(device)
            labels = data["label"]
            net_ids = data["net_indices"]
            net_pts = data["net_support"]
            for i in range(len(net_ids)):
                net_ids[i] = net_ids[i].to(device)
            for i in range(len(net_pts)):
                net_pts[i] = net_pts[i].to(device)

            optimizer.zero_grad()
            outputs = net(features, pts, support_points=net_pts, indices=net_ids)
            loss = F.cross_entropy(outputs, seg, weight=weights)

            loss.backward()
            optimizer.step()

            outputs_np = outputs.cpu().detach().numpy()
            for i in range(pts.size(0)):
                # mask out logits of parts outside this shape category's range
                object_label = labels[i]
                part_start, part_end = category_range[object_label]

                outputs_np[i, :part_start] = -1e7
                outputs_np[i, part_end:] = -1e7

            output_np = np.argmax(outputs_np, axis=1).copy()
            target_np = seg.cpu().numpy().copy()

            cm_ = confusion_matrix(
                target_np.ravel(), output_np.ravel(), labels=list(range(N_CLASSES))
            )
            cm += cm_

            oa = "{:.3f}".format(metrics.stats_overall_accuracy(cm))
            aa = "{:.3f}".format(metrics.stats_accuracy_per_class(cm)[0])
            iou = "{:.3f}".format(metrics.stats_iou_per_class(cm)[0])

            t.set_postfix(OA=oa, AA=aa, IOU=iou)

        # eval (this is not the final evaluation; see the dedicated evaluation script)
        net.eval()
        with torch.no_grad():
            cm = np.zeros((N_CLASSES, N_CLASSES))
            t = tqdm(
                test_loader,
                ncols=120,
                desc=f"Test {epoch}",
                disable=_config["misc"]["disable_tqdm"],
            )
            for data in t:
                pts = data["pts"].to(device)
                features = data["features"].to(device)
                seg = data["seg"].to(device)
                labels = data["label"]
                net_ids = data["net_indices"]
                net_pts = data["net_support"]
                for i in range(len(net_ids)):
                    net_ids[i] = net_ids[i].to(device)
                for i in range(len(net_pts)):
                    net_pts[i] = net_pts[i].to(device)

                outputs = net(features, pts, support_points=net_pts, indices=net_ids)
                loss = 0

                for i in range(pts.size(0)):
                    # restrict the loss to this shape category's parts
                    object_label = labels[i]
                    part_start, part_end = category_range[object_label]

                    outputs_ = (outputs[i, part_start:part_end]).unsqueeze(0)
                    seg_ = (seg[i] - part_start).unsqueeze(0)

                    loss = loss + weights[object_label] * F.cross_entropy(
                        outputs_, seg_
                    )

                outputs_np = outputs.cpu().detach().numpy()
                for i in range(pts.size(0)):
                    # mask out logits of parts outside this shape category's range
                    object_label = labels[i]
                    part_start, part_end = category_range[object_label]

                    outputs_np[i, :part_start] = -1e7
                    outputs_np[i, part_end:] = -1e7

                output_np = np.argmax(outputs_np, axis=1).copy()
                target_np = seg.cpu().numpy().copy()

                cm_ = confusion_matrix(
                    target_np.ravel(), output_np.ravel(), labels=list(range(N_CLASSES))
                )
                cm += cm_

                oa_test = "{:.3f}".format(metrics.stats_overall_accuracy(cm))
                aa_test = "{:.3f}".format(metrics.stats_accuracy_per_class(cm)[0])
                iou_test = "{:.3f}".format(metrics.stats_iou_per_class(cm)[0])

                t.set_postfix(OA=oa_test, AA=aa_test, IOU=iou_test)

        # scheduler update
        scheduler.step()

        # save the model
        os.makedirs(savedir_root, exist_ok=True)
        torch.save(
            {
                "epoch": epoch + 1,
                "state_dict": net.state_dict(),
                "optimizer": optimizer.state_dict(),
            },
            os.path.join(savedir_root, "checkpoint.pth"),
        )

        # write the logs
        with open(os.path.join(savedir_root, "log.txt"), "a+") as logs:
            logs.write(f"{epoch} {oa} {aa} {iou} {oa_test} {aa_test} {iou_test}\n")

        _run.log_scalar("trainOA", oa, epoch)
        _run.log_scalar("trainAA", aa, epoch)
        _run.log_scalar("trainIoU", iou, epoch)
        _run.log_scalar("testOA", oa_test, epoch)
        _run.log_scalar("testAA", aa_test, epoch)
        _run.log_scalar("testIoU", iou_test, epoch)
