def main_function(experiment_directory, continue_from, batch_split, device):

    logging.debug("running " + experiment_directory)

    specs = ws.load_experiment_specifications(experiment_directory)

    logging.info("Experiment description: \n" + specs["Description"][0])

    data_source = specs["DataSource"]
    train_split_file = specs["TrainSplit"]

    arch = __import__("networks." + specs["NetworkArch"], fromlist=["Decoder"])

    logging.debug(specs["NetworkSpecs"])

    latent_size = specs["CodeLength"]

    checkpoints = list(
        range(
            specs["SnapshotFrequency"],
            specs["NumEpochs"] + 1,
            specs["SnapshotFrequency"],
        )
    )
    for checkpoint in specs["AdditionalSnapshots"]:
        checkpoints.append(checkpoint)
    checkpoints.sort()

    lr_schedules = get_learning_rate_schedules(specs)

    grad_clip = get_spec_with_default(specs, "GradientClipNorm", None)
    if grad_clip is not None:
        logging.debug("clipping gradients to max norm {}".format(grad_clip))

    def save_latest(epoch):
        save_model(experiment_directory, "latest.pth", decoder, epoch)
        save_optimizer(experiment_directory, "latest.pth", optimizer_all, epoch)
        save_latent_vectors(experiment_directory, "latest.pth", lat_vecs, epoch)

    def save_checkpoints(epoch):
        save_model(experiment_directory, str(epoch) + ".pth", decoder, epoch)
        save_optimizer(experiment_directory, str(epoch) + ".pth", optimizer_all, epoch)
        save_latent_vectors(experiment_directory, str(epoch) + ".pth", lat_vecs, epoch)

    def signal_handler(sig, frame):
        logging.info("Stopping early...")
        sys.exit(0)

    def adjust_learning_rate(lr_schedules, optimizer, epoch):
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)

    def empirical_stat(latent_vecs, indices):
        lat_mat = torch.zeros(0).cuda()
        for ind in indices:
            lat_mat = torch.cat([lat_mat, latent_vecs[ind]], 0)
        mean = torch.mean(lat_mat, 0)
        var = torch.var(lat_mat, 0)
        return mean, var

    signal.signal(signal.SIGINT, signal_handler)

    num_samp_per_scene = specs["SamplesPerScene"]
    scene_per_batch = specs["ScenesPerBatch"]
    clamp_dist = specs["ClampingDistance"]
    minT = -clamp_dist
    maxT = clamp_dist
    enforce_minmax = True

    do_code_regularization = get_spec_with_default(specs, "CodeRegularization", True)
    code_reg_lambda = get_spec_with_default(specs, "CodeRegularizationLambda", 1e-4)
    code_bound = get_spec_with_default(specs, "CodeBound", None)

    decoder = arch.Decoder(latent_size, **specs["NetworkSpecs"]).to(device)

    # Parallelize if GPUs available
    if torch.cuda.is_available():
        logging.info("training with {} GPU(s)".format(torch.cuda.device_count()))
        decoder = torch.nn.DataParallel(decoder)

    num_epochs = specs["NumEpochs"]
    log_frequency = get_spec_with_default(specs, "LogFrequency", 10)

    with open(train_split_file, "r") as f:
        train_split = json.load(f)

    sdf_dataset = deep_sdf.data.SDFSamples(
        data_source, train_split, num_samp_per_scene, load_ram=False
    )

    num_data_loader_threads = get_spec_with_default(specs, "DataLoaderThreads", 1)
    logging.debug("loading data with {} threads".format(num_data_loader_threads))

    sdf_loader = data_utils.DataLoader(
        sdf_dataset,
        batch_size=scene_per_batch,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True,
    )

    logging.debug("torch num_threads: {}".format(torch.get_num_threads()))

    num_scenes = len(sdf_dataset)
    logging.info("There are {} scenes".format(num_scenes))

    logging.debug(decoder)

    lat_vecs = torch.nn.Embedding(num_scenes, latent_size, max_norm=code_bound)
    torch.nn.init.normal_(
        lat_vecs.weight.data,
        0.0,
        get_spec_with_default(specs, "CodeInitStdDev", 1.0) / math.sqrt(latent_size),
    )

    logging.debug(
        "initialized with mean magnitude {}".format(
            get_mean_latent_vector_magnitude(lat_vecs)
        )
    )

    loss_l1 = torch.nn.L1Loss(reduction="sum")

    optimizer_all = torch.optim.Adam(
        [
            {
                "params": decoder.parameters(),
                "lr": lr_schedules[0].get_learning_rate(0),
            },
            {
                "params": lat_vecs.parameters(),
                "lr": lr_schedules[1].get_learning_rate(0),
            },
        ]
    )

    loss_log = []
    lr_log = []
    lat_mag_log = []
    timing_log = []
    param_mag_log = {}

    start_epoch = 1

    if continue_from is not None:
        logging.info('continuing from "{}"'.format(continue_from))

        lat_epoch = load_latent_vectors(
            experiment_directory, continue_from + ".pth", lat_vecs, device
        )
        model_epoch = ws.load_model_parameters(
            experiment_directory, continue_from, decoder, device
        )
        optimizer_epoch = load_optimizer(
            experiment_directory, continue_from + ".pth", optimizer_all
        )
        loss_log, lr_log, timing_log, lat_mag_log, param_mag_log, log_epoch = load_logs(
            experiment_directory
        )

        if not log_epoch == model_epoch:
            loss_log, lr_log, timing_log, lat_mag_log, param_mag_log = clip_logs(
                loss_log, lr_log, timing_log, lat_mag_log, param_mag_log, model_epoch
            )

        if not (model_epoch == optimizer_epoch and model_epoch == lat_epoch):
            raise RuntimeError(
                "epoch mismatch: {} vs {} vs {} vs {}".format(
                    model_epoch, optimizer_epoch, lat_epoch, log_epoch
                )
            )

        start_epoch = model_epoch + 1

        logging.debug("loaded")

    logging.info("starting from epoch {}".format(start_epoch))

    logging.info(
        "Number of decoder parameters: {}".format(
            sum(p.data.nelement() for p in decoder.parameters())
        )
    )
    logging.info(
        "Number of shape code parameters: {} (# codes {}, code dim {})".format(
            lat_vecs.num_embeddings * lat_vecs.embedding_dim,
            lat_vecs.num_embeddings,
            lat_vecs.embedding_dim,
        )
    )

    for epoch in range(start_epoch, num_epochs + 1):

        start = time.time()

        logging.info("epoch {}...".format(epoch))

        decoder.train()

        adjust_learning_rate(lr_schedules, optimizer_all, epoch)

        for sdf_data, indices in sdf_loader:

            # Process the input data
            sdf_data = sdf_data.reshape(-1, 4)

            num_sdf_samples = sdf_data.shape[0]

            sdf_data.requires_grad = False

            xyz = sdf_data[:, 0:3]
            sdf_gt = sdf_data[:, 3].unsqueeze(1)

            if enforce_minmax:
                sdf_gt = torch.clamp(sdf_gt, minT, maxT)

            xyz = torch.chunk(xyz, batch_split)
            indices = torch.chunk(
                indices.unsqueeze(-1).repeat(1, num_samp_per_scene).view(-1),
                batch_split,
            )
            sdf_gt = torch.chunk(sdf_gt, batch_split)

            batch_loss = 0.0

            optimizer_all.zero_grad()

            for i in range(batch_split):

                batch_vecs = lat_vecs(indices[i])

                input = torch.cat([batch_vecs, xyz[i]], dim=1)

                # NN optimization
                pred_sdf = decoder(input)

                if enforce_minmax:
                    pred_sdf = torch.clamp(pred_sdf, minT, maxT)

                chunk_loss = loss_l1(pred_sdf, sdf_gt[i].to(device=device)) / num_sdf_samples

                if do_code_regularization:
                    l2_size_loss = torch.sum(torch.norm(batch_vecs, dim=1))
                    reg_loss = (
                        code_reg_lambda * min(1, epoch / 100) * l2_size_loss
                    ) / num_sdf_samples

                    chunk_loss = chunk_loss + reg_loss.to(device=device)

                chunk_loss.backward()

                batch_loss += chunk_loss.item()

            logging.debug("loss = {}".format(batch_loss))

            loss_log.append(batch_loss)

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(decoder.parameters(), grad_clip)

            optimizer_all.step()

        end = time.time()

        seconds_elapsed = end - start
        timing_log.append(seconds_elapsed)

        lr_log.append([schedule.get_learning_rate(epoch) for schedule in lr_schedules])

        lat_mag_log.append(get_mean_latent_vector_magnitude(lat_vecs))

        append_parameter_magnitudes(param_mag_log, decoder)

        if epoch in checkpoints:
            save_checkpoints(epoch)

        if epoch % log_frequency == 0:
            save_latest(epoch)
            save_logs(
                experiment_directory,
                loss_log,
                lr_log,
                timing_log,
                lat_mag_log,
                param_mag_log,
                epoch,
            )
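# ---------------------------------------------------------------------------
# Illustrative sketch only. The small helpers used above (get_spec_with_default,
# get_mean_latent_vector_magnitude) are defined elsewhere in this repository and
# are not shown in this section; the versions below are minimal plausible
# implementations consistent with how they are called, under hypothetical
# "_example_"-prefixed names so they do not shadow the real ones.
# ---------------------------------------------------------------------------
def _example_get_spec_with_default(specs, key, default):
    # Look up `key` in the experiment specs, falling back to `default` if absent.
    try:
        return specs[key]
    except KeyError:
        return default


def _example_get_mean_latent_vector_magnitude(lat_vecs):
    # Mean L2 norm over all rows of the latent-code embedding table
    # (the torch.nn.Embedding variant used by the function above).
    return torch.mean(torch.norm(lat_vecs.weight.data.detach(), dim=1))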
def main_function(experiment_directory, continue_from, batch_split, finetune_from):

    logging.debug("running " + experiment_directory)

    specs = ws.load_experiment_specifications(experiment_directory)

    logging.info("Experiment description: \n" + str(specs["Description"]))

    data_source = specs["DataSource"]
    train_split_file = specs["TrainSplit"]

    arch = __import__("networks." + specs["NetworkArch"], fromlist=["Decoder"])

    logging.debug(specs["NetworkSpecs"])

    latent_size = specs["CodeLength"]

    checkpoints = list(
        range(
            specs["SnapshotFrequency"],
            specs["NumEpochs"] + 1,
            specs["SnapshotFrequency"],
        )
    )
    for checkpoint in specs["AdditionalSnapshots"]:
        checkpoints.append(checkpoint)
    checkpoints.sort()

    lr_schedules = get_learning_rate_schedules(specs)

    grad_clip = get_spec_with_default(specs, "GradientClipNorm", None)
    if grad_clip is not None:
        logging.debug("clipping gradients to max norm {}".format(grad_clip))

    def save_latest(epoch):
        save_model(experiment_directory, "latest.pth", decoder, epoch)
        save_optimizer(experiment_directory, "latest.pth", optimizer_all, epoch)
        save_latent_vectors(experiment_directory, "latest.pth", lat_vecs, epoch)

    def save_checkpoints(epoch):
        save_model(experiment_directory, str(epoch) + ".pth", decoder, epoch)
        save_optimizer(experiment_directory, str(epoch) + ".pth", optimizer_all, epoch)
        save_latent_vectors(experiment_directory, str(epoch) + ".pth", lat_vecs, epoch)

    def signal_handler(sig, frame):
        logging.info("Stopping early...")
        sys.exit(0)

    def adjust_learning_rate(lr_schedules, optimizer, epoch):
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)

    def latent_size_regul(latent, indices):
        latent_loss = 0.0
        for ind in indices:
            latent_loss += torch.mean(latent[ind].pow(2))
        return latent_loss / len(indices)

    def empirical_stat(latent_vecs, indices):
        lat_mat = torch.zeros(0).cuda()
        for ind in indices:
            lat_mat = torch.cat([lat_mat, latent_vecs[ind]], 0)
        mean = torch.mean(lat_mat, 0)
        var = torch.var(lat_mat, 0)
        return mean, var

    signal.signal(signal.SIGINT, signal_handler)

    num_samp_per_scene = specs["SamplesPerScene"]
    scene_per_batch = specs["ScenesPerBatch"]
    clamp_dist = specs["ClampingDistance"]
    minT = -clamp_dist
    maxT = clamp_dist
    enforce_minmax = True

    if not (scene_per_batch % batch_split) == 0:
        raise RuntimeError("Unequal batch splitting is not supported.")

    scene_per_subbatch = scene_per_batch // batch_split

    min_vec = torch.ones(num_samp_per_scene * scene_per_subbatch, 1).cuda() * minT
    max_vec = torch.ones(num_samp_per_scene * scene_per_subbatch, 1).cuda() * maxT

    do_code_regularization = get_spec_with_default(specs, "CodeRegularization", True)
    code_reg_lambda = get_spec_with_default(specs, "CodeRegularizationLambda", 1e-4)
    code_bound = get_spec_with_default(specs, "CodeBound", None)

    decoder = arch.Decoder(latent_size, **specs["NetworkSpecs"]).cuda()

    logging.info("training with {} GPU(s)".format(torch.cuda.device_count()))

    # if torch.cuda.device_count() > 1:
    decoder = torch.nn.DataParallel(decoder)

    num_epochs = specs["NumEpochs"]
    log_frequency = get_spec_with_default(specs, "LogFrequency", 10)

    with open(train_split_file, "r") as f:
        train_split = json.load(f)

    sdf_dataset = deep_sdf.data.SDFSamples(
        data_source, train_split, num_samp_per_scene, load_ram=False
    )

    num_data_loader_threads = get_spec_with_default(specs, "DataLoaderThreads", 1)
    logging.debug("loading data with {} threads".format(num_data_loader_threads))

    sdf_loader = data_utils.DataLoader(
        sdf_dataset,
        batch_size=scene_per_subbatch,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True,
    )

    logging.debug("torch num_threads: {}".format(torch.get_num_threads()))

    num_scenes = len(sdf_dataset)
    logging.info("There are {} scenes".format(num_scenes))

    logging.debug(decoder)

    lat_vecs = []
    for _i in range(num_scenes):
        vec = (
            torch.ones(1, latent_size)
            .normal_(0, get_spec_with_default(specs, "CodeInitStdDev", 1.0))
            .cuda()
        )
        vec.requires_grad = True
        lat_vecs.append(vec)

    logging.debug(
        "initialized with mean magnitude {}".format(
            get_mean_latent_vector_magnitude(lat_vecs)
        )
    )

    loss_l1 = torch.nn.L1Loss()

    optimizer_all = torch.optim.Adam(
        [
            {
                "params": decoder.parameters(),
                "lr": lr_schedules[0].get_learning_rate(0),
            },
            {"params": lat_vecs, "lr": lr_schedules[1].get_learning_rate(0)},
        ]
    )

    loss_log = []
    lr_log = []
    lat_mag_log = []
    timing_log = []
    param_mag_log = {}

    start_epoch = 1

    if continue_from is not None:
        logging.info('continuing from "{}"'.format(continue_from))

        lat_epoch = load_latent_vectors(
            experiment_directory, continue_from + ".pth", lat_vecs
        )
        model_epoch = ws.load_model_parameters(
            experiment_directory, continue_from, decoder
        )
        optimizer_epoch = load_optimizer(
            experiment_directory, continue_from + ".pth", optimizer_all
        )
        loss_log, lr_log, timing_log, lat_mag_log, param_mag_log, log_epoch = load_logs(
            experiment_directory
        )

        if not log_epoch == model_epoch:
            loss_log, lr_log, timing_log, lat_mag_log, param_mag_log = clip_logs(
                loss_log, lr_log, timing_log, lat_mag_log, param_mag_log, model_epoch
            )

        if not (model_epoch == optimizer_epoch and model_epoch == lat_epoch):
            raise RuntimeError(
                "epoch mismatch: {} vs {} vs {} vs {}".format(
                    model_epoch, optimizer_epoch, lat_epoch, log_epoch
                )
            )

        start_epoch = model_epoch + 1

        logging.debug("loaded")

    if finetune_from is not None:
        logging.info('Finetuning from "{}"'.format(finetune_from))
        if not os.path.isfile(finetune_from):
            raise Exception(
                'model state dict "{}" does not exist'.format(finetune_from)
            )
        data = torch.load(finetune_from)
        decoder.load_state_dict(data["model_state_dict"])
        logging.debug("loaded on epoch {}".format(data["epoch"]))

    logging.info("starting from epoch {}".format(start_epoch))

    for epoch in range(start_epoch, num_epochs + 1):

        start = time.time()

        logging.info("epoch {}...".format(epoch))

        decoder.train()
        adjust_learning_rate(lr_schedules, optimizer_all, epoch)

        for sdf_data, indices in sdf_loader:

            batch_loss = 0.0

            optimizer_all.zero_grad()

            for _subbatch in range(batch_split):

                # Process the input data
                latent_inputs = torch.zeros(0).cuda()

                sdf_data.requires_grad = False
                sdf_data = (sdf_data.cuda()).reshape(
                    num_samp_per_scene * scene_per_subbatch, 4
                )

                xyz = sdf_data[:, 0:3]
                sdf_gt = sdf_data[:, 3].unsqueeze(1)

                for ind in indices.numpy():
                    latent_ind = lat_vecs[ind]
                    latent_repeat = latent_ind.expand(num_samp_per_scene, -1)
                    latent_inputs = torch.cat([latent_inputs, latent_repeat], 0)

                inputs = torch.cat([latent_inputs, xyz], 1)

                if enforce_minmax:
                    sdf_gt = deep_sdf.utils.threshold_min_max(sdf_gt, min_vec, max_vec)

                if latent_size == 0:
                    inputs = xyz

                # NN optimization
                pred_sdf = decoder(inputs)

                if enforce_minmax:
                    pred_sdf = deep_sdf.utils.threshold_min_max(
                        pred_sdf, min_vec, max_vec
                    )

                loss = loss_l1(pred_sdf, sdf_gt)

                if do_code_regularization:
                    l2_size_loss = latent_size_regul(lat_vecs, indices.numpy())
                    loss += code_reg_lambda * min(1, epoch / 100) * l2_size_loss

                loss.backward()

                batch_loss += loss.item()

            loss_log.append(batch_loss)

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(decoder.parameters(), grad_clip)

            optimizer_all.step()

            # Project latent vectors onto sphere
            if code_bound is not None:
                deep_sdf.utils.project_vecs_onto_sphere(lat_vecs, code_bound)

        end = time.time()

        seconds_elapsed = end - start
        timing_log.append(seconds_elapsed)

        lr_log.append([schedule.get_learning_rate(epoch) for schedule in lr_schedules])

        lat_mag_log.append(get_mean_latent_vector_magnitude(lat_vecs))

        append_parameter_magnitudes(param_mag_log, decoder)

        if epoch in checkpoints:
            save_checkpoints(epoch)

        if epoch % log_frequency == 0:
            save_latest(epoch)
            save_logs(
                experiment_directory,
                loss_log,
                lr_log,
                timing_log,
                lat_mag_log,
                param_mag_log,
                epoch,
            )
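# ---------------------------------------------------------------------------
# Illustrative sketch only. deep_sdf.utils.threshold_min_max and
# deep_sdf.utils.project_vecs_onto_sphere belong to the deep_sdf package and are
# not shown in this section; the functions below are minimal plausible versions
# matching how they are called above (elementwise clamping of SDF values, and
# rescaling any per-scene code whose norm exceeds code_bound back onto that
# sphere), under hypothetical "_example_"-prefixed names so they do not shadow
# the real implementations.
# ---------------------------------------------------------------------------
def _example_threshold_min_max(tensor, min_vec, max_vec):
    # Clamp every entry of `tensor` into [min_vec, max_vec] elementwise.
    return torch.min(max_vec, torch.max(tensor, min_vec))


def _example_project_vecs_onto_sphere(vectors, radius):
    # Shrink any latent code whose L2 norm exceeds `radius` back onto the sphere
    # of that radius; codes already inside the ball are left unchanged.
    for i in range(len(vectors)):
        length = torch.norm(vectors[i]).detach()
        if length.item() > radius:
            vectors[i].data = vectors[i].data * (radius / length)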