def main_function(experiment_directory, data_source, continue_from, batch_split):

    logging.info("running " + experiment_directory)

    # back up the current source code into the experiment directory for reproducibility
    now = datetime.datetime.now()
    code_bk_path = os.path.join(
        experiment_directory, 'code_bk_%s.tar.gz' % now.strftime('%Y_%m_%d_%H_%M_%S'))
    ws.create_code_snapshot('./', code_bk_path,
                            extensions=('.py', '.json', '.cpp', '.cu', '.h', '.sh'),
                            exclude=('examples', 'third-party', 'bin'))

    specs = ws.load_experiment_specifications(experiment_directory)

    logging.info("Experiment description: \n" + specs["Description"])

    # data_source = specs["DataSource"]
    train_split_file = specs["TrainSplit"]

    arch = __import__("networks." + specs["NetworkArch"], fromlist=["Decoder"])

    logging.info(specs["NetworkSpecs"])

    latent_size = specs["CodeLength"]

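    # Checkpoint epochs: every SnapshotFrequency epochs up to NumEpochs, plus any AdditionalSnapshots.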
    checkpoints = list(
        range(
            specs["SnapshotFrequency"],
            specs["NumEpochs"] + 1,
            specs["SnapshotFrequency"],
        )
    )

    for checkpoint in specs["AdditionalSnapshots"]:
        checkpoints.append(checkpoint)
    checkpoints.sort()

    lr_schedules = get_learning_rate_schedules(specs)

    grad_clip = get_spec_with_default(specs, "GradientClipNorm", None)
    if grad_clip is not None:
        logging.debug("clipping gradients to max norm {}".format(grad_clip))

    def save_latest(epoch):

        ws.save_model(experiment_directory, "latest.pth", decoder, epoch)
        ws.save_optimizer(experiment_directory, "latest.pth", optimizer_all, epoch)
        ws.save_latent_vectors(experiment_directory, "latest.pth", lat_vecs, epoch)

    def save_checkpoints(epoch):

        ws.save_model(experiment_directory, str(epoch) + ".pth", decoder, epoch)
        ws.save_optimizer(experiment_directory, str(epoch) + ".pth", optimizer_all, epoch)
        ws.save_latent_vectors(experiment_directory, str(epoch) + ".pth", lat_vecs, epoch)

    def signal_handler(sig, frame):
        logging.info("Stopping early...")
        sys.exit(0)

    def adjust_learning_rate(lr_schedules, optimizer, epoch):

        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)

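    # Per-dimension mean and variance of a set of latent codes (not used in this function).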
    def empirical_stat(latent_vecs, indices):
        lat_mat = torch.zeros(0).cuda()
        for ind in indices:
            lat_mat = torch.cat([lat_mat, latent_vecs[ind]], 0)
        mean = torch.mean(lat_mat, 0)
        var = torch.var(lat_mat, 0)
        return mean, var

    signal.signal(signal.SIGINT, signal_handler)

    num_samp_per_scene = specs["SamplesPerScene"]
    scene_per_batch = specs["ScenesPerBatch"]
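    # Ground-truth and predicted SDF values are clamped to [-ClampingDistance, ClampingDistance].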
    clamp_dist = specs["ClampingDistance"]
    minT = -clamp_dist
    maxT = clamp_dist
    enforce_minmax = True

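    # Each batch of scene_per_batch scenes is processed in batch_split sub-batches; gradients are
    # accumulated across sub-batches before each optimizer step, so the split must divide the batch.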
    assert scene_per_batch % batch_split == 0, "ScenesPerBatch must be divisible by batch_split"
    scene_per_split = scene_per_batch // batch_split

    do_code_regularization = get_spec_with_default(specs, "CodeRegularization", True)
    code_reg_lambda = get_spec_with_default(specs, "CodeRegularizationLambda", 1e-4)

    code_bound = get_spec_with_default(specs, "CodeBound", None)

    decoder = arch.Decoder(latent_size, **specs["NetworkSpecs"]).cuda()

    logging.info("training with {} GPU(s)".format(torch.cuda.device_count()))

    # if torch.cuda.device_count() > 1:
    decoder = torch.nn.DataParallel(decoder)

    num_epochs = specs["NumEpochs"]
    log_frequency = get_spec_with_default(specs, "LogFrequency", 10)

    with open(train_split_file, "r") as f:
        train_split = json.load(f)

    sdf_dataset = deep_sdf.data.SDFSamples(
        data_source, train_split, num_samp_per_scene, load_ram=True
    )

    if sdf_dataset.load_ram:
        num_data_loader_threads = 0
    else:
        num_data_loader_threads = get_spec_with_default(specs, "DataLoaderThreads", 1)
    logging.debug("loading data with {} threads".format(num_data_loader_threads))

    sdf_loader = data_utils.DataLoader(
        sdf_dataset,
        batch_size=scene_per_batch,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True,
    )

    logging.debug("torch num_threads: {}".format(torch.get_num_threads()))

    num_scenes = len(sdf_dataset)

    logging.info("There are {} scenes".format(num_scenes))

    logging.info(decoder)

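    # One learnable latent code per training scene, initialized from a zero-mean Gaussian
    # (std CodeInitStdDev / sqrt(latent_size)) and optionally norm-bounded by CodeBound.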
    lat_vecs = torch.nn.Embedding(num_scenes, latent_size, max_norm=code_bound)
    torch.nn.init.normal_(
        lat_vecs.weight.data,
        0.0,
        get_spec_with_default(specs, "CodeInitStdDev", 1.0) / math.sqrt(latent_size),
    )

    logging.debug(
        "initialized with mean magnitude {}".format(
            get_mean_latent_vector_magnitude(lat_vecs)
        )
    )

    loss_l1 = torch.nn.L1Loss(reduction="sum")
    loss_l1_soft = loss.SoftL1Loss(reduction="sum")
    loss_lp = torch.nn.DataParallel(loss.LipschitzLoss(k=0.5, reduction="sum"))
    huber_fn = loss.HuberFunc(reduction="sum")

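    # Separate learning-rate schedules for the warper, the SDF decoder, and the latent codes.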
    optimizer_all = torch.optim.Adam(
        [
            {
                "params": decoder.module.warper.parameters(),
                "lr": lr_schedules[0].get_learning_rate(0),
            },
            {
                "params": decoder.module.sdf_decoder.parameters(),
                "lr": lr_schedules[1].get_learning_rate(0),
            },
            {
                "params": lat_vecs.parameters(),
                "lr": lr_schedules[2].get_learning_rate(0),
            },
        ]
    )

    tensorboard_saver = ws.create_tensorboard_saver(experiment_directory)

    loss_log = []
    lr_log = []
    lat_mag_log = []
    timing_log = []
    param_mag_log = {}

    start_epoch = 1

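    # Optionally resume training: the latent-code, model, and optimizer checkpoints
    # must all exist for the requested tag.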
    if continue_from is not None:
        if not os.path.exists(os.path.join(experiment_directory, ws.latent_codes_subdir, continue_from + ".pth")) or \
                not os.path.exists(os.path.join(experiment_directory, ws.model_params_subdir, continue_from + ".pth")) or \
                not os.path.exists(os.path.join(experiment_directory, ws.optimizer_params_subdir, continue_from + ".pth")):
            logging.warning('"{}" does not exist! Ignoring this argument...'.format(continue_from))
        else:
            logging.info('continuing from "{}"'.format(continue_from))

            lat_epoch = ws.load_latent_vectors(
                experiment_directory, continue_from + ".pth", lat_vecs
            )

            model_epoch = ws.load_model_parameters(
                experiment_directory, continue_from, decoder
            )

            optimizer_epoch = ws.load_optimizer(
                experiment_directory, continue_from + ".pth", optimizer_all
            )

            loss_log, lr_log, timing_log, lat_mag_log, param_mag_log, log_epoch = ws.load_logs(
                experiment_directory
            )

            if log_epoch != model_epoch:
                loss_log, lr_log, timing_log, lat_mag_log, param_mag_log = ws.clip_logs(
                    loss_log, lr_log, timing_log, lat_mag_log, param_mag_log, model_epoch
                )

            if not (model_epoch == optimizer_epoch and model_epoch == lat_epoch):
                raise RuntimeError(
                    "epoch mismatch: {} vs {} vs {} vs {}".format(
                        model_epoch, optimizer_epoch, lat_epoch, log_epoch
                    )
                )

            start_epoch = model_epoch + 1

            logging.debug("loaded")

    logging.info("starting from epoch {}".format(start_epoch))

    logging.info(
        "Number of decoder parameters: {}".format(
            sum(p.data.nelement() for p in decoder.parameters())
        )
    )
    logging.info(
        "Number of shape code parameters: {} (# codes {}, code dim {})".format(
            lat_vecs.num_embeddings * lat_vecs.embedding_dim,
            lat_vecs.num_embeddings,
            lat_vecs.embedding_dim,
        )
    )

    use_curriculum = get_spec_with_default(specs, "UseCurriculum", False)

    use_pointwise_loss = get_spec_with_default(specs, "UsePointwiseLoss", False)
    pointwise_loss_weight = get_spec_with_default(specs, "PointwiseLossWeight", 0.0)

    use_pointpair_loss = get_spec_with_default(specs, "UsePointpairLoss", False)
    pointpair_loss_weight = get_spec_with_default(specs, "PointpairLossWeight", 0.0)

    logging.info("pointwise_loss_weight = {}, pointpair_loss_weight = {}".format(
        pointwise_loss_weight, pointpair_loss_weight))

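    # Main training loop: per-epoch learning-rate adjustment, gradient-accumulated batches,
    # logging, and periodic checkpointing.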
    for epoch in range(start_epoch, num_epochs + 1):

        start = time.time()

        logging.info("epoch {}...".format(epoch))

        decoder.train()

        adjust_learning_rate(lr_schedules, optimizer_all, epoch)

        batch_num = len(sdf_loader)
        for bi, (sdf_data, indices) in enumerate(sdf_loader):

            # Process the input data
            sdf_data = sdf_data.reshape(-1, 4)

            num_sdf_samples = sdf_data.shape[0]

            sdf_data.requires_grad = False

            xyz = sdf_data[:, 0:3]
            sdf_gt = sdf_data[:, 3].unsqueeze(1)

            if enforce_minmax:
                sdf_gt = torch.clamp(sdf_gt, minT, maxT)

            xyz = torch.chunk(xyz, batch_split)
            indices = torch.chunk(
                indices.unsqueeze(-1).repeat(1, num_samp_per_scene).view(-1),
                batch_split,
            )

            sdf_gt = torch.chunk(sdf_gt, batch_split)

            batch_loss_sdf = 0.0
            batch_loss_pw = 0.0
            batch_loss_reg = 0.0
            batch_loss_pp = 0.0
            batch_loss = 0.0

            optimizer_all.zero_grad()

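            # Process the batch in batch_split sub-batches, accumulating gradients
            # before the single optimizer step below.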
            for i in range(batch_split):

                batch_vecs = lat_vecs(indices[i])

                decoder_input = torch.cat([batch_vecs, xyz[i]], dim=1)
                xyz_ = xyz[i].cuda()

                # Forward pass: the decoder returns lists of warped points and SDF
                # predictions (the last entry of each list is the final output)
                warped_xyz_list, pred_sdf_list, _ = decoder(
                    decoder_input, output_warped_points=True, output_warping_param=True)

                if enforce_minmax:
                    # pred_sdf = pred_sdf * clamp_dist * 1.0
                    for k in range(len(pred_sdf_list)):
                        pred_sdf_list[k] = torch.clamp(pred_sdf_list[k], minT, maxT)

                if use_curriculum:
                    sdf_loss = apply_curriculum_l1_loss(
                        pred_sdf_list, sdf_gt[i].cuda(), loss_l1_soft, num_sdf_samples)
                else:
                    sdf_loss = loss_l1(pred_sdf_list[-1], sdf_gt[i].cuda()) / num_sdf_samples
                batch_loss_sdf += sdf_loss.item()
                chunk_loss = sdf_loss

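                # L2-norm regularization on the latent codes, ramped up linearly over the first 100 epochs.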
                if do_code_regularization:
                    l2_size_loss = torch.sum(torch.norm(batch_vecs, dim=1))
                    reg_loss = l2_size_loss / num_sdf_samples
                    chunk_loss += code_reg_lambda * min(1.0, epoch / 100) * reg_loss.cuda()
                    batch_loss_reg += reg_loss.item()

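                # Point-wise regularization: Huber penalty between warped and original coordinates;
                # its weight decays from roughly 10x toward 1x during early training.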
                if use_pointwise_loss:
                    if use_curriculum:
                        pw_loss = apply_pointwise_reg(warped_xyz_list, xyz_, huber_fn, num_sdf_samples)
                    else:
                        pw_loss = apply_pointwise_reg(warped_xyz_list[-1:], xyz_, huber_fn, num_sdf_samples)
                    batch_loss_pw += pw_loss.item()
                    chunk_loss = chunk_loss + pw_loss.cuda() * pointwise_loss_weight * max(1.0, 10.0 * (1 - epoch / 100))

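                # Point-pair regularization: Lipschitz-style penalty on warped point pairs (computed
                # per scene), ramped up linearly over the first 100 epochs.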
                if use_pointpair_loss:
                    if use_curriculum:
                        lp_loss = apply_pointpair_reg(warped_xyz_list, xyz_, loss_lp, scene_per_split, num_sdf_samples)
                    else:
                        lp_loss = apply_pointpair_reg(warped_xyz_list[-1:], xyz_, loss_lp, scene_per_split, num_sdf_samples)
                    batch_loss_pp += lp_loss.item()
                    chunk_loss += lp_loss.cuda() * pointpair_loss_weight * min(1.0, epoch / 100)

                chunk_loss.backward()
                batch_loss += chunk_loss.item()

            logging.debug("sdf_loss = {:.9f}, reg_loss = {:.9f}, pw_loss = {:.9f}, pp_loss = {:.9f}".format(
                batch_loss_sdf, batch_loss_reg, batch_loss_pw, batch_loss_pp))

            ws.save_tensorboard_logs(
                tensorboard_saver, epoch*batch_num + bi,
                loss_sdf=batch_loss_sdf, loss_pw=batch_loss_pw, loss_reg=batch_loss_reg,
                loss_pp=batch_loss_pp, loss_=batch_loss)

            loss_log.append(batch_loss)

            if grad_clip is not None:

                torch.nn.utils.clip_grad_norm_(decoder.parameters(), grad_clip)

            optimizer_all.step()

            # release memory; pw_loss / lp_loss are only bound when the corresponding loss is enabled
            del warped_xyz_list, pred_sdf_list, sdf_loss, chunk_loss, \
                batch_loss_sdf, batch_loss_reg, batch_loss_pp, batch_loss_pw, batch_loss
            if use_pointwise_loss:
                del pw_loss
            if use_pointpair_loss:
                del lp_loss

        end = time.time()

        seconds_elapsed = end - start
        timing_log.append(seconds_elapsed)

        lr_log.append([schedule.get_learning_rate(epoch) for schedule in lr_schedules])

        lat_mag_log.append(get_mean_latent_vector_magnitude(lat_vecs))

        append_parameter_magnitudes(param_mag_log, decoder)

        if epoch in checkpoints:
            save_checkpoints(epoch)

        if epoch % log_frequency == 0:

            save_latest(epoch)
            ws.save_logs(
                experiment_directory,
                loss_log,
                lr_log,
                timing_log,
                lat_mag_log,
                param_mag_log,
                epoch,
            )


def code_to_mesh(results_folder, checkpoint, keep_normalized=False):

    specs_filename = os.path.join(results_folder, "specs.json")

    if not os.path.isfile(specs_filename):
        raise Exception(
            'The experiment directory does not include specifications file "specs.json"'
        )

    with open(specs_filename, "r") as f:
        specs = json.load(f)

    arch = __import__("networks." + specs["NetworkArch"], fromlist=["Decoder"])

    latent_size = specs["CodeLength"]

    decoder = arch.Decoder(latent_size, **specs["NetworkSpecs"])

    decoder = torch.nn.DataParallel(decoder)

    saved_model_state = torch.load(
        os.path.join(results_folder, ws.model_params_subdir, checkpoint + ".pth")
    )
    saved_model_epoch = saved_model_state["epoch"]

    decoder.load_state_dict(saved_model_state["model_state_dict"])

    decoder = decoder.module.cuda()

    decoder.eval()

    latent_vectors = ws.load_latent_vectors(results_folder, checkpoint)

    train_split_file = specs["TrainSplit"]

    with open(train_split_file, "r") as f:
        train_split = json.load(f)

    data_source = specs["DataSource"]

    instance_filenames = deep_sdf.data.get_instance_filenames(data_source, train_split)

    print(len(instance_filenames), " vs ", len(latent_vectors))

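    # Reconstruct one mesh per training instance from its learned latent code.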
    for i, latent_vector in enumerate(latent_vectors):

        dataset_name, class_name, instance_name = instance_filenames[i].split("/")
        instance_name = instance_name.split(".")[0]

        print("{} {} {}".format(dataset_name, class_name, instance_name))

        mesh_dir = os.path.join(
            results_folder,
            ws.training_meshes_subdir,
            str(saved_model_epoch),
            dataset_name,
            class_name,
        )
        print(mesh_dir)

        os.makedirs(mesh_dir, exist_ok=True)

        mesh_filename = os.path.join(mesh_dir, instance_name)

        print(instance_filenames[i])

        offset = None
        scale = None

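        # Unless keep_normalized is set, load the stored normalization parameters so the
        # reconstructed mesh is mapped back to its original (un-normalized) coordinate frame.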
        if not keep_normalized:

            normalization_params = np.load(
                ws.get_normalization_params_filename(
                    data_source, dataset_name, class_name, instance_name
                )
            )
            offset = normalization_params["offset"]
            scale = normalization_params["scale"]

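        # Extract the mesh from the decoder's SDF on an N^3 grid; max_batch limits
        # how many points are evaluated at once.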
        with torch.no_grad():
            deep_sdf.mesh.create_mesh(
                decoder,
                latent_vector,
                mesh_filename,
                N=256,
                max_batch=int(2 ** 18),
                offset=offset,
                scale=scale,
            )