Code example #1
import argparse
import os

import deep_sdf

# NOTE: this excerpt began mid-help-string; the preamble above and the start of
# the "--experiment" argument below are inferred from the use of
# args.experiment_directory further down and may differ from the original file.
if __name__ == "__main__":

    arg_parser = argparse.ArgumentParser(description="Train a DeepSDF autodecoder.")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        required=True,
        help="The experiment directory. This directory should include " +
        "experiment specifications in 'specs.json', and logging will be " +
        "done in this directory as well.",
    )
    arg_parser.add_argument(
        "--continue",
        "-c",
        dest="continue_from",
        help="A snapshot to continue from. This can be 'latest' to continue" +
        "from the latest running snapshot, or an integer corresponding to " +
        "an epochal snapshot.",
    )
    arg_parser.add_argument(
        "--batch_split",
        dest="batch_split",
        default=1,
        help="This splits the batch into separate subbatches which are " +
        "processed separately, with gradients accumulated across all " +
        "subbatches. This allows for training with large effective batch " +
        "sizes in memory constrained environments.",
    )
    arg_parser.add_argument("--gpu", type=str, default="0")

    deep_sdf.add_common_args(arg_parser)

    args = arg_parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    deep_sdf.configure_logging(args)

    main_function(args.experiment_directory, args.continue_from,
                  int(args.batch_split))
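
The "--batch_split" help text above describes gradient accumulation: the batch
is split into subbatches whose gradients are summed before a single optimizer
step. A minimal sketch of that pattern (model, loader, loss_fn, optimizer, and
batch_split are illustrative names, not the project's actual training loop):

def train_epoch(model, loader, loss_fn, optimizer, batch_split):
    for xyz, sdf_gt in loader:
        optimizer.zero_grad()
        for xyz_sub, sdf_sub in zip(xyz.chunk(batch_split),
                                    sdf_gt.chunk(batch_split)):
            # Average over subbatches so the accumulated gradient matches the
            # gradient of the unsplit batch.
            loss = loss_fn(model(xyz_sub), sdf_sub) / batch_split
            loss.backward()  # .grad buffers accumulate across subbatches
        optimizer.step()

This trades a few extra forward/backward passes for roughly 1/batch_split of
the peak activation memory, which is what makes large effective batch sizes
fit in memory-constrained environments.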
Code example #2
File: reconstruct.py Project: edgar-tr/patchnets
def main_function_reconstruction(args):

    deep_sdf.configure_logging(args.loglevel, args.logfile)

    def empirical_stat(latent_vecs, indices):
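        # Stack the selected latent codes and return their per-dimension
        # empirical mean and variance.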
        lat_mat = torch.zeros(0).cuda()
        for ind in indices:
            lat_mat = torch.cat([lat_mat, latent_vecs[ind]], 0)
        mean = torch.mean(lat_mat, 0)
        var = torch.var(lat_mat, 0)
        return mean, var

    specs_filename = os.path.join(args.results_folder, "specs.json")

    if not os.path.isfile(specs_filename):
        raise Exception(
            'The experiment directory does not include specifications file "specs.json"'
        )

    with open(specs_filename) as specs_file:
        # Strip "//" comment lines (not valid JSON) before parsing.
        specs = "\n".join(
            line for line in specs_file.readlines() if line.strip()[:2] != "//"
        )
    specs = json.loads(specs)

    arch = __import__("networks." + specs["NetworkArch"], fromlist=["Decoder"])

    patch_latent_size = specs["PatchCodeLength"]
    mixture_latent_size = specs["MixtureCodeLength"]

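    # NOTE: grid_encoder_param is not defined anywhere in this excerpt; it is
    # presumably assigned earlier in the original reconstruct.py.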
    decoder = arch.Decoder(patch_latent_size=patch_latent_size,
                           mixture_latent_size=mixture_latent_size,
                           encoder=grid_encoder_param,
                           **specs["NetworkSpecs"])

    decoder = torch.nn.DataParallel(decoder)

    saved_model_state = torch.load(
        os.path.join(args.results_folder, ws.model_params_subdir,
                     args.checkpoint + ".pth"))
    saved_model_epoch = saved_model_state["epoch"]

    decoder.load_state_dict(saved_model_state["model_state_dict"])

    decoder = decoder.module.cuda()

    with open(args.split_filename, "r") as f:
        split = json.load(f)

    npz_filenames = deep_sdf.data.get_instance_filenames(
        args.data_source, split)
    npz_filenames = [(npz,
                      os.path.join(args.data_source, ws.sdf_samples_subdir,
                                   npz)) for npz in npz_filenames]

    #random.shuffle(npz_filenames)

    logging.debug(decoder)

    err_sum = 0.0
    repeat = 1
    save_latvec_only = False
    rerun = 0

    reconstruction_dir = os.path.join(
        args.output_folder, ws.reconstructions_subdir, "videos",
        str(saved_model_epoch) + "_" + ("opt" if args.optimize else "noopt"))

    if not os.path.isdir(reconstruction_dir):
        os.makedirs(reconstruction_dir)

    reconstruction_meshes_dir = os.path.join(reconstruction_dir,
                                             ws.reconstruction_meshes_subdir)

    print(reconstruction_meshes_dir, flush=True)

    if not os.path.isdir(reconstruction_meshes_dir):
        os.makedirs(reconstruction_meshes_dir)

    reconstruction_codes_dir = os.path.join(reconstruction_dir,
                                            ws.reconstruction_codes_subdir)
    if not os.path.isdir(reconstruction_codes_dir):
        os.makedirs(reconstruction_codes_dir)

    with open(os.path.join(reconstruction_meshes_dir, "MeshFiles"), "w") as mesh_files:
        for npz, _ in npz_filenames:
            mesh_files.write(npz[:-4] + ".obj\n")

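    # For each instance: optimize a latent code against its SDF samples with
    # reconstruct(), then extract a mesh via deep_sdf.mesh.create_mesh().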
    for ii, (npz, full_filename) in enumerate(npz_filenames):

        if "npz" not in npz:
            continue

        logging.debug("loading {}".format(npz))

        data_sdf = deep_sdf.data.read_sdf_samples_into_ram(full_filename)

        for k in range(repeat):

            if rerun > 1:
                mesh_filename = os.path.join(reconstruction_meshes_dir,
                                             npz[:-4] + "-" + str(k + rerun))
                latent_filename = os.path.join(
                    reconstruction_codes_dir,
                    npz[:-4] + "-" + str(k + rerun) + ".pth")
            else:
                mesh_filename = os.path.join(reconstruction_meshes_dir,
                                             npz[:-4])
                latent_filename = os.path.join(reconstruction_codes_dir,
                                               npz[:-4] + ".pth")

            if (args.skip and os.path.isfile(mesh_filename + ".obj")
                    and os.path.isfile(latent_filename)):
                continue

            logging.info("reconstructing {}".format(npz))

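            # Randomly permute both SDF sample tensors so each run draws a
            # differently ordered sample set.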
            data_sdf[0] = data_sdf[0][torch.randperm(data_sdf[0].shape[0])]
            data_sdf[1] = data_sdf[1][torch.randperm(data_sdf[1].shape[0])]

            start = time.time()

            err, latent = reconstruct(
                decoder,
                int(args.iterations),
                mixture_latent_size,
                data_sdf,
                0.01,  # [emp_mean,emp_var],
                0.1,
                num_samples=8000,
                lr=5e-3,
                l2reg=True)

            logging.debug("reconstruct time: {}".format(time.time() - start))
            err_sum += err
            logging.debug("current_error avg: {}".format((err_sum / (ii + 1))))
            logging.debug(ii)

            logging.debug("latent: {}".format(latent.detach().cpu().numpy()))

            decoder.eval()

            if not os.path.exists(os.path.dirname(mesh_filename)):
                os.makedirs(os.path.dirname(mesh_filename))

            if not save_latvec_only:
                start = time.time()
                with torch.no_grad():
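                    # Evaluate the SDF grid in chunks of at most max_batch
                    # points to bound GPU memory during mesh extraction.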
                    max_batch = 2**18
                    if hasattr(decoder, "num_patches"):
                        max_batch //= decoder.num_patches + 2
                    deep_sdf.mesh.create_mesh(decoder,
                                              latent,
                                              mesh_filename,
                                              N=256,
                                              max_batch=max_batch)
                logging.debug("total time: {}".format(time.time() - start))

            if not os.path.exists(os.path.dirname(latent_filename)):
                os.makedirs(os.path.dirname(latent_filename))

            torch.save(latent.unsqueeze(0), latent_filename)
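
The reconstruct() call above is not shown in this excerpt. In DeepSDF-style
auto-decoders it optimizes a latent code for a frozen decoder against the SDF
samples; a rough sketch under that assumption (the signature and the decoder
input convention are illustrative, not the project's actual API):

import torch

def reconstruct_sketch(decoder, num_iterations, latent_size, sdf_xyz, sdf_gt,
                       num_samples=8000, lr=5e-3, l2reg=True):
    # Only the latent code is optimized; the decoder weights stay fixed.
    latent = torch.zeros(latent_size, requires_grad=True, device="cuda")
    optimizer = torch.optim.Adam([latent], lr=lr)
    for _ in range(num_iterations):
        idx = torch.randint(sdf_xyz.shape[0], (num_samples,))
        xyz, gt = sdf_xyz[idx].cuda(), sdf_gt[idx].cuda()
        # Assumed input convention: latent code concatenated with each query point.
        pred = decoder(torch.cat([latent.expand(num_samples, -1), xyz], dim=1))
        loss = torch.nn.functional.l1_loss(pred.squeeze(-1), gt)
        if l2reg:
            loss = loss + 1e-4 * latent.pow(2).sum()  # keep codes near the origin
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item(), latent.detach()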