Exemplo n.º 1
0
def _process_single_cat_file(file, target_dir):
    """
    Processes a single file of the cat dataset

    Reads the landmark annotation stored next to the image
    (``file + ".cat"``), exports the landmarks as a ``.pts`` file into the
    target directory, moves the image into the target directory and finally
    deletes the original ``.cat`` annotation file.

    The exported files are named ``<parent dir name>_<file name>`` so that
    images with equal names from different sub-directories do not collide.

    Parameters
    ----------
    file : str
        the image file to process; its annotation is read from
        ``file + ".cat"``
    target_dir : str
        the target directory

    """

    file = os.path.abspath(file)
    target_dir = os.path.abspath(target_dir)

    pd_frame = pd.read_csv(str(file) + ".cat", sep=' ', header=None)

    # ``DataFrame.as_matrix`` was removed in pandas 1.0; ``.values`` is
    # available in old and new pandas versions alike.
    # Use the first row without its first and last entry and group the
    # remaining values into coordinate pairs.
    coords = pd_frame.values[0][1:-1].reshape((-1, 2))

    # switch xy; the copy yields an independent, writable array instead of
    # modifying (a possibly read-only view of) the frame's data in place
    landmarks = coords[:, ::-1].copy()

    parent_dir_name = os.path.split(os.path.split(file)[0])[-1]
    target_file = os.path.join(
        target_dir,
        parent_dir_name + "_" + os.path.split(file)[-1])

    # export landmarks (same name as the image, but with ".pts" extension)
    pts_exporter(landmarks, str(target_file.rsplit(".", 1)[0]) + ".pts")

    # move image file and drop the now obsolete annotation file
    shutil.move(file, target_file)
    os.remove(file + ".cat")
Exemplo n.º 2
0
def test_io():
    """
    Checks that landmarks survive an export/import round trip

    The example landmarks are written with the ljson and the pts exporter,
    read back with the matching importers and compared against the
    original values. The temporary files are removed afterwards, even if
    one of the checks fails.

    """
    lmks = np.loadtxt(
        os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            "example_files", "lenna.txt"))

    ljson_file = "./lmks.ljson"
    pts_file = "./lmks.pts"

    try:
        ljson_exporter(lmks, ljson_file)
        assert os.path.isfile(ljson_file)

        pts_exporter(lmks, pts_file)
        assert os.path.isfile(pts_file)

        lmks_ljson = ljson_importer(ljson_file)
        assert (lmks == lmks_ljson).all()

        # compare against the pts result (previously the ljson result was
        # checked a second time, leaving the pts round trip untested)
        lmks_pts = pts_importer(pts_file)
        # NOTE(review): a text-based round trip is presumably not bit-exact,
        # hence the tolerance-based comparison - adjust the tolerance to the
        # precision the pts format actually stores
        assert np.allclose(lmks, lmks_pts)

    finally:
        # always clean up, otherwise a failing assertion leaves files behind
        for tmp_file in (ljson_file, pts_file):
            if os.path.isfile(tmp_file):
                os.remove(tmp_file)
Exemplo n.º 3
0
def predict():
    """
    Predicts file directory with network specified by files to output path

    Command line arguments (parsed from ``sys.argv``):

    * ``-v/--visualize``: additionally save a plot of each prediction
    * ``-d/--in_path``: directory containing the input images
    * ``-s/--out_path``: output directory (defaults to ``./outputs``)
    * ``-w/--weight_file``: either a TorchScript file or a checkpoint
      containing the network's state dict
    * ``-c/--config_file``: configuration file describing the network

    For every ``.png``/``.jpg`` file found in the input directory the
    sample is saved in PTS format together with the predicted landmarks
    (``<name>_pred.pts``). With ``--visualize`` the predictions go to
    ``<out_path>/pred`` and the plots to ``<out_path>/visualization``.

    """

    import numpy as np
    import torch
    from tqdm import tqdm
    import os
    import warnings
    from matplotlib import pyplot as plt

    from shapenet.utils.load_config_file import Config
    from shapenet.layer import HomogeneousShapeLayer
    from shapenet.networks import SingleShapeNetwork
    from shapedata.single_shape import SingleShapeDataProcessing, \
                                        SingleShapeSingleImage2D
    from shapedata.io import pts_exporter
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-v",
                        "--visualize",
                        action="store_true",
                        help="If Flag is specified, results will be plotted")
    # The three path arguments below are needed unconditionally; marking
    # them as required yields a clear usage error instead of a TypeError
    # raised by ``os.path.abspath(None)``.
    parser.add_argument("-d",
                        "--in_path",
                        type=str,
                        required=True,
                        help="Input Data Dir")
    parser.add_argument("-s",
                        "--out_path",
                        default="./outputs",
                        type=str,
                        help="Output Data Dir")
    parser.add_argument("-w",
                        "--weight_file",
                        type=str,
                        required=True,
                        help="Model Weights")
    parser.add_argument("-c",
                        "--config_file",
                        type=str,
                        required=True,
                        help="Configuration")

    args = parser.parse_args()
    config = Config()
    config_dict = config(os.path.abspath(args.config_file))

    weight_file = os.path.abspath(args.weight_file)

    try:
        # first try to interpret the weight file as a TorchScript module
        net = torch.jit.load(weight_file)
        net.eval()
        net.cpu()

    except RuntimeError:
        # fall back to rebuilding the network from the configuration and
        # loading a state dict into it
        net_layer = HomogeneousShapeLayer

        if config_dict["training"].pop("mixed_prec", False):
            # mixed precision setup is best effort only; do not swallow
            # KeyboardInterrupt/SystemExit as a bare ``except`` would
            try:
                from apex import amp
                amp.init()
            except Exception as exc:
                warnings.warn(
                    "Could not initialize apex mixed precision: %s" % exc)

        shapes = np.load(
            os.path.abspath(config_dict["layer"].pop("pca_path"))
        )["shapes"][:config_dict["layer"].pop("num_shape_params") + 1]

        net = SingleShapeNetwork(net_layer, {
            "shapes": shapes,
            **config_dict["layer"]
        },
                                 img_size=config_dict["data"]["img_size"],
                                 **config_dict["network"])

        # checkpoints either wrap the weights in ["state_dict"]["model"] or
        # contain the plain state dict
        try:
            net.load_state_dict(
                torch.load(weight_file)["state_dict"]["model"])
        except KeyError:
            net.load_state_dict(torch.load(weight_file))

        net = net.to("cpu")
        net = net.eval()

    data = SingleShapeDataProcessing._get_files(os.path.abspath(args.in_path),
                                                extensions=[".png", ".jpg"])

    def process_sample(sample, img_size, net, device, crop=0.1):
        """
        Predicts the landmarks of a single sample

        A square region around the sample's landmark bounds (enlarged by
        the factor ``1 + crop``) is cropped, converted to grayscale,
        resized to ``img_size`` x ``img_size`` and fed through ``net``.
        The prediction is then scaled and shifted back into the
        coordinate system of the uncropped sample.

        Parameters
        ----------
        sample :
            the sample to process; must provide ``lmk``,
            ``get_landmark_bounds``, ``crop``, ``to_grayscale`` and
            ``resize``
        img_size : int
            edge length of the network input
        net :
            the network used for prediction
        device :
            the device the input tensor is moved to
        crop : float
            relative margin added to the landmark bounds

        Returns
        -------
        the predicted landmarks in coordinates of the original sample

        """
        lmk_bounds = sample.get_landmark_bounds(sample.lmk)
        min_y, min_x, max_y, max_x = lmk_bounds
        range_x = max_x - min_x
        range_y = max_y - min_y

        # square crop, large enough for the bigger of both extents
        max_range = max(range_x, range_y) * (1 + crop)
        half_range = max_range / 2

        center_x = min_x + range_x / 2
        center_y = min_y + range_y / 2

        tmp = sample.crop(center_y - half_range, center_x - half_range,
                          center_y + half_range, center_x + half_range)

        # move the last image axis to the front and add a batch dimension
        img_tensor = torch.from_numpy(tmp.to_grayscale().resize(
            (img_size, img_size)).img.transpose(2, 0, 1)).to(
                torch.float).unsqueeze(0).to(device)

        pred = net(img_tensor).cpu().numpy()[0]

        # undo the resizing and the cropping offset
        pred = pred * np.array([max_range / img_size, max_range / img_size])
        pred = pred + np.asarray(
            [center_y - half_range, center_x - half_range])

        return pred

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():

        # run on the GPU whenever one is available
        net = net.to(device)

        if args.visualize:
            pred_path = os.path.join(os.path.abspath(args.out_path), "pred")
            vis_path = os.path.join(os.path.abspath(args.out_path),
                                    "visualization")
            os.makedirs(vis_path, exist_ok=True)
        else:
            pred_path = os.path.abspath(args.out_path)

        os.makedirs(pred_path, exist_ok=True)
        print(data)

        for file in tqdm(data):

            _data = SingleShapeSingleImage2D.from_files(file)

            # NOTE(review): the network input size is hard-coded here;
            # presumably it should match config_dict["data"]["img_size"]
            # - confirm before changing
            pred = process_sample(_data, img_size=224, net=net, device=device)

            # file name of the image without directory and extension
            fname = os.path.split(_data.img_file)[-1].rsplit(".", 1)[0]

            if args.visualize:
                view_kwargs = {}
                if _data.is_gray:
                    view_kwargs["cmap"] = "gray"
                _data.view(True, **view_kwargs)
                # second coordinate is plotted on the horizontal axis,
                # first coordinate on the vertical axis
                plt.gca().scatter(pred[:, 1], pred[:, 0], s=5, c="C1")
                plt.gca().legend(["GT", "Pred"])
                plt.gcf().savefig(os.path.join(vis_path, fname + ".png"))
                plt.close()

            _data.save(pred_path, fname, "PTS")
            pts_exporter(pred, os.path.join(pred_path, fname + "_pred.pts"))
            print('pred:', pred)