def get_model(specs, device):
    model_type = specs["ModelType"]
    latent_size = specs["LatentSize"]
    nb_classes = get_spec_with_default(specs["NetworkSpecs"], "num_class", 6)
    classifier_branch = get_spec_with_default(specs, "ClassifierBranch", False)

    if model_type == "PC_2encoder1decoder_VAE":
        # input_type = 'point_cloud'
        # With two encoders, each encoder produces a latent vector with half
        # of the total latent size.
        half_latent_size = int(latent_size / 2)
        # print("Point cloud encoder, each branch has latent size", half_latent_size)
        encoder_obj = arch.ResnetPointnet(c_dim=half_latent_size, hidden_dim=256)
        # The hand encoder outputs the full latent_size (2x the object branch):
        # half for the mean and half for the variance of the VAE posterior.
        encoder_hand = arch.ResnetPointnet(c_dim=latent_size, hidden_dim=256,
                                           cond_dim=latent_size)

        combined_decoder = arch.CombinedDecoder(latent_size, **specs["NetworkSpecs"],
                                                use_classifier=classifier_branch)

        encoderDecoder = arch.ModelTwoEncodersOneDecoderVAE(
            encoder_hand, encoder_obj, combined_decoder,
            nb_classes, specs["SamplesPerScene"],
            classifier_branch
        )
        encoderDecoder = torch.nn.DataParallel(encoderDecoder)

        # Load weights. `args` is assumed to be a module-level argparse
        # namespace with a `model_directory` attribute.
        saved_model_state = torch.load(
            os.path.join(args.model_directory, "model.pth")
        )
        saved_model_epoch = saved_model_state["epoch"]
        # logging.info("using model from epoch {}".format(saved_model_epoch))
        encoderDecoder.load_state_dict(saved_model_state["model_state_dict"])
        encoderDecoder = encoderDecoder.to(device)  # .cuda()
        return encoderDecoder  # loaded_model
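# Example usage of get_model (a minimal sketch; it assumes the module-level
# `args.model_directory` used above contains both specs.json and model.pth --
# the exact layout is an assumption, not confirmed by this script):
#
#   specs = utils.misc.load_experiment_specifications(args.model_directory)
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = get_model(specs, device)
#   model.eval()  # inference only; the weights were already loaded inside get_model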
def shape_assembly_main_function(experiment_directory, continue_from, batch_split):

    def save_latest(epoch):
        save_model(experiment_directory, "latest.pth", encoder_decoder, epoch)
        save_optimizer(experiment_directory, "latest.pth", optimizer_all, epoch)

    def save_checkpoints(epoch):
        save_model(experiment_directory, str(epoch) + ".pth", encoder_decoder, epoch)
        save_optimizer(experiment_directory, str(epoch) + ".pth", optimizer_all, epoch)

    def signal_handler(sig, frame):
        logging.info("Stopping early...")
        sys.exit(0)

    def adjust_learning_rate(lr_schedules, optimizer, epoch):
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)

    # Register the Ctrl-C handler (mirrors main_function below).
    signal.signal(signal.SIGINT, signal_handler)

    logging.debug("running " + experiment_directory)
    print(experiment_directory)
    specs = utils.misc.load_experiment_specifications(experiment_directory)
    logging.info("Experiment description: \n" + specs["Description"])

    data_source = Path(specs["DataSource"])
    train_split_file = specs["TrainSplit"]
    val_split_file = specs["ValSplit"]
    subsample = specs["SamplesPerScene"]
    scene_per_batch = specs["ScenesPerBatch"]
    latent_size = specs["LatentSize"]
    num_epochs = specs["NumEpochs"]
    num_data_loader_threads = get_spec_with_default(specs, "DataLoaderThreads", 8)
    nb_classes = get_spec_with_default(specs["NetworkSpecs"], "num_class", 6)

    lr_schedules = get_learning_rate_schedules(specs)
    log_frequency = get_spec_with_default(specs, "LogFrequency", 5)
    log_frequency_step = get_spec_with_default(specs, "LogFrequencyStep", 100)

    # NOTE: `device`, `loss_l1`, `loss_l1_avg`, `alpha` and `translation_only`
    # were free variables in the original; `device` is assumed to be a
    # module-level torch.device, and the definitions below are assumptions
    # (the two spec keys are hypothetical).
    loss_l1 = torch.nn.L1Loss()
    loss_l1_avg = torch.nn.L1Loss(reduction="mean")
    alpha = get_spec_with_default(specs, "TransformationLossWeight", 1.0)
    translation_only = get_spec_with_default(specs, "TranslationOnly", False)

    checkpoints = list(
        range(
            specs["SnapshotFrequency"],
            specs["NumEpochs"] + 1,
            specs["SnapshotFrequency"],
        )
    )
    for checkpoint in specs["AdditionalSnapshots"]:
        checkpoints.append(checkpoint)
    checkpoints.sort()

    with open(train_split_file, "r") as f:
        train_split = json.load(f)
    with open(val_split_file, "r") as f:
        val_split = json.load(f)

    sdf_dataset = utils.data.SDFAssemblySamples(data_source, train_split, subsample)
    sdf_val_dataset = utils.data.SDFAssemblySamples(data_source, val_split, subsample)

    sdf_loader = data_utils.DataLoader(
        sdf_dataset,
        batch_size=scene_per_batch,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True,
    )
    sdf_val_loader = data_utils.DataLoader(
        sdf_val_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True,
    )

    # Each point cloud encoder produces half of the total latent vector.
    half_latent_size = int(latent_size / 2)
    encoder_part1 = arch.ResnetPointnet(c_dim=half_latent_size, hidden_dim=256).to(device)
    encoder_part2 = arch.ResnetPointnet(c_dim=half_latent_size, hidden_dim=256).to(device)
    print("Point cloud encoder, each branch has latent size", half_latent_size)

    decoder = arch.ShapeAssemblyDecoder(latent_size, **specs["NetworkSpecs"])

    encoder_decoder = arch.ModelShapeAssemblyEncoderDecoderVAE(
        encoder_part1,
        encoder_part2,
        decoder,
        nb_classes,
        subsample,
    )
    encoder_decoder = encoder_decoder.to(device)
    encoder_decoder.share_memory()

    logging.info("training with {} GPU(s)".format(torch.cuda.device_count()))
    # encoder_decoder = torch.nn.DataParallel(encoder_decoder)

    logging.debug("torch num_threads: {}".format(torch.get_num_threads()))
    logging.debug(encoder_decoder)

    optimizer_all = torch.optim.Adam(
        [
            {
                "params": encoder_decoder.parameters(),
            }
        ]
    )

    writer = SummaryWriter(os.path.join(experiment_directory, 'log'))

    start_epoch = 1
    # Continue from the latest checkpoint if one exists.
    if (continue_from is None
            and utils.misc.is_checkpoint_exist(experiment_directory, 'latest')):
        continue_from = 'latest'

    if continue_from is not None:
        logging.info('continuing from "{}"'.format(continue_from))
        model_epoch = utils.misc.load_model_parameters(
            experiment_directory, continue_from, encoder_decoder
        )
        optimizer_epoch = load_optimizer(
            experiment_directory, continue_from + ".pth", optimizer_all
        )
        start_epoch = model_epoch + 1
        logging.debug("loaded")

    # Training loop.
    logging.info("starting from epoch {}".format(start_epoch))
    for epoch in range(start_epoch, num_epochs + 1):
        start = time.time()
        logging.info("epoch {}...".format(epoch))

        # Signify to the model that we are training.
        encoder_decoder.train()
        adjust_learning_rate(lr_schedules, optimizer_all, epoch)

        total_sdf_loss = 0
        total_transformation_loss = 0

        for i, (part1, part2,
                gt_transformed_part1_points, gt_transform) in enumerate(sdf_loader):
            batch_loss = 0.0
            optimizer_all.zero_grad()

            for _subbatch in range(batch_split):
                samples = part2["sdf_samples"]
                samples.requires_grad = False

                sdf_data = (samples.to(device)).reshape(
                    subsample * scene_per_batch, 5
                )
                # part1_transform_vec = torch.cat((part1["center"], part1["quaternion"]), 1).to(device)
                xyzs = sdf_data[:, 0:3]
                sdf_gt_part1 = sdf_data[:, 3].unsqueeze(1)
                sdf_gt_part2 = sdf_data[:, 4].unsqueeze(1)

                (sdf_pred_part1, sdf_pred_part2,
                 predicted_translation, predicted_rotation) = encoder_decoder(
                    part1["surface_points"].to(device),
                    part2["surface_points"].to(device),
                    xyzs
                )

                predicted_translation = predicted_translation.to(device)
                predicted_rotation = predicted_rotation.to(device)
                part1["surface_points"] = part1["surface_points"].to(device)

                predicted_translation = predicted_translation.reshape(
                    (predicted_translation.shape[0], 1, 3))
                if translation_only:
                    predicted_transformed_part1_points = torch.add(
                        part1["surface_points"], predicted_translation)
                else:
                    predicted_rotation = predicted_rotation.reshape(
                        (predicted_rotation.shape[0], 1, 4))
                    predicted_rotated_part1_points = quaternion_apply(
                        predicted_rotation, part1["surface_points"])
                    predicted_transformed_part1_points = torch.add(
                        predicted_rotated_part1_points, predicted_translation)

                # Compute losses.
                loss_sdf_part1 = loss_l1(sdf_pred_part1, sdf_gt_part1)
                loss_sdf_part2 = loss_l1(sdf_pred_part2, sdf_gt_part2)
                loss_transformation = alpha * loss_l1_avg(
                    gt_transformed_part1_points.to(device),
                    predicted_transformed_part1_points)
                loss = loss_sdf_part1 + loss_sdf_part2 + loss_transformation
                # loss = loss_l1_avg(gt_transformed_part1_points.to(device), predicted_transformed_part1_points)

                loss.backward()
                batch_loss += loss.item()
                print('step {}, loss {:.5f}'.format(
                    (epoch - 1) * len(sdf_loader) + i, loss.item()))
                total_sdf_loss += (loss_sdf_part1 + loss_sdf_part2).item()
                total_transformation_loss += loss_transformation.item()

            optimizer_all.step()

        writer.add_scalar('training_sdf_loss',
                          total_sdf_loss / len(sdf_loader), (epoch - 1))
        writer.add_scalar('training_transformation_loss',
                          total_transformation_loss / len(sdf_loader), (epoch - 1))

        end = time.time()
        seconds_elapsed = end - start
        print("time used:", seconds_elapsed)

        for idx, schedule in enumerate(lr_schedules):
            writer.add_scalar('learning_rate_' + str(idx),
                              schedule.get_learning_rate(epoch),
                              epoch * len(sdf_loader))

        # Save the latest model.
        save_latest(epoch)
        print("save at {}".format(epoch))
        if epoch in checkpoints:
            save_checkpoints(epoch)

        # Signify to the model that we are evaluating, not training.
        encoder_decoder.eval()
        avg_val_loss = 0.0
        avg_val_sdf_loss = 0.0
        avg_val_reprojection_loss = 0.0

        # Run validation; gradients are not needed here.
        with torch.no_grad():
            for i, (part1, part2,
                    gt_transformed_part1_points, gt_transform) in enumerate(sdf_val_loader):
                reprojection_loss, sdf_loss = shape_assembly_val_function(
                    part1, part2, gt_transformed_part1_points,
                    encoder_decoder, subsample, scene_per_batch)
                avg_val_sdf_loss += sdf_loss
                avg_val_reprojection_loss += reprojection_loss
                avg_val_loss += reprojection_loss + sdf_loss

        avg_val_sdf_loss = avg_val_sdf_loss / len(sdf_val_loader)
        avg_val_reprojection_loss = avg_val_reprojection_loss / len(sdf_val_loader)
        avg_val_loss = avg_val_loss / len(sdf_val_loader)

        writer.add_scalar('validation_loss', avg_val_loss, (epoch - 1))
        logging.info(f"Epoch {epoch}: Validation Loss: {avg_val_loss}")
        writer.add_scalar('validation_reprojection_loss',
                          avg_val_reprojection_loss, (epoch - 1))
        writer.add_scalar('validation_sdf_loss', avg_val_sdf_loss, (epoch - 1))
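# For reference, a self-contained sketch of the transformation (reprojection)
# loss computed in the training loop above. `quaternion_apply` is assumed to
# be pytorch3d.transforms.quaternion_apply (quaternions in (w, x, y, z)
# order); the helper name itself is hypothetical and not part of the script.
def _transformation_loss_sketch(part1_points, pred_rotation, pred_translation,
                                gt_transformed_points, weight=1.0):
    # part1_points: (B, N, 3) surface points of part 1
    # pred_rotation: (B, 1, 4) unit quaternions; pred_translation: (B, 1, 3)
    # gt_transformed_points: (B, N, 3) ground-truth transformed points
    from pytorch3d.transforms import quaternion_apply
    rotated = quaternion_apply(pred_rotation, part1_points)  # broadcasts over N
    transformed = rotated + pred_translation
    # Mean L1 distance between predicted and ground-truth point locations,
    # mirroring `alpha * loss_l1_avg(...)` above.
    return weight * torch.nn.functional.l1_loss(transformed, gt_transformed_points)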
def main_function(experiment_directory, continue_from, batch_split):
    logging.debug("running " + experiment_directory)
    print(experiment_directory)
    specs = utils.misc.load_experiment_specifications(experiment_directory)
    logging.info("Experiment description: \n" + specs["Description"])

    data_source = specs["DataSource"]
    image_source = specs["ImageSource"]
    train_split_file = specs["TrainSplit"]
    val_split_file = get_spec_with_default(specs, "ValSplit", None)

    is_fhb = get_spec_with_default(specs, "FHB", False)
    if is_fhb:
        print("FHB dataset")
    check_file = get_spec_with_default(specs, "CheckFile", True)

    logging.debug(specs["NetworkSpecs"])
    dataset_name = get_spec_with_default(specs, "Dataset", "obman")

    ### Model type
    model_type = get_spec_with_default(specs, "ModelType", "1encoder2decoder")
    obj_center = get_spec_with_default(specs, "ObjectCenter", False)

    hand_branch = get_spec_with_default(specs, "HandBranch", True)
    obj_branch = get_spec_with_default(specs, "ObjectBranch", True)
    print("Hand branch:", hand_branch)
    print("Object branch:", obj_branch)
    assert hand_branch or obj_branch

    classifier_branch = get_spec_with_default(specs, "ClassifierBranch", False)
    classifier_weight = get_spec_with_default(specs, "ClassifierWeight", 0.1)
    print("Classifier Weight:", classifier_weight)

    use_gaussian_reconstruction_weight = get_spec_with_default(specs, "GaussianWeightLoss", False)

    do_penetration_loss = get_spec_with_default(specs, "PenetrationLoss", False)
    penetration_loss_weight = get_spec_with_default(specs, "PenetrationLossWeight", 15.0)  # 1000.0
    start_additional_loss = get_spec_with_default(specs, "AdditionalLossStart", 200000)  # 500

    do_contact_loss = get_spec_with_default(specs, "ContactLoss", False)
    contact_loss_weight = get_spec_with_default(specs, "ContactLossWeight", 0.005)
    contact_loss_sigma = get_spec_with_default(specs, "ContactLossSigma", 0.005)

    print("Penetration Loss:", do_penetration_loss)
    print("Penetration Loss Weight:", penetration_loss_weight)
    print("Additional Loss start at epoch:", start_additional_loss)
    print("Contact Loss:", do_contact_loss)
    print("Contact Loss Weight:", contact_loss_weight)
    print("Contact Loss Sigma (m):", contact_loss_sigma)

    latent_size = specs["LatentSize"]

    checkpoints = list(
        range(
            specs["SnapshotFrequency"],
            specs["NumEpochs"] + 1,
            specs["SnapshotFrequency"],
        )
    )
    for checkpoint in specs["AdditionalSnapshots"]:
        checkpoints.append(checkpoint)
    checkpoints.sort()

    lr_schedules = get_learning_rate_schedules(specs)

    grad_clip = get_spec_with_default(specs, "GradientClipNorm", None)
    if grad_clip is not None:
        logging.debug("clipping gradients to max norm {}".format(grad_clip))

    def save_latest(epoch):
        save_model(experiment_directory, "latest.pth", encoderDecoder, epoch)
        save_optimizer(experiment_directory, "latest.pth", optimizer_all, epoch)
        # save_latent_vectors(experiment_directory, "latest.pth", lat_vecs, epoch)

    def save_checkpoints(epoch):
        save_model(experiment_directory, str(epoch) + ".pth", encoderDecoder, epoch)
        save_optimizer(experiment_directory, str(epoch) + ".pth", optimizer_all, epoch)
        # save_latent_vectors(experiment_directory, str(epoch) + ".pth", lat_vecs, epoch)

    def signal_handler(sig, frame):
        logging.info("Stopping early...")
        sys.exit(0)

    def adjust_learning_rate(lr_schedules, optimizer, epoch):
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)

    signal.signal(signal.SIGINT, signal_handler)

    # If true, use the data as-is. If false, multiply and offset the object
    # location with normalized parameters.
    indep_obj_scale = get_spec_with_default(specs, "IndependentObjScale", False)
    print("Independent Obj Scale:", indep_obj_scale)

    # Ignore points from the other mesh at the beginning when training a
    # single decoder.
    ignore_other = get_spec_with_default(specs, "IgnorePointFromOtherMesh", False)
    print("Ignore other:", ignore_other)

    num_samp_per_scene = specs["SamplesPerScene"]
    scene_per_batch = specs["ScenesPerBatch"]
    scene_per_subbatch = scene_per_batch
    clamp_dist = specs["ClampingDistance"]
    minT = -clamp_dist
    maxT = clamp_dist
    enforce_minmax = True

    nb_classes = get_spec_with_default(specs["NetworkSpecs"], "num_class", 6)
    print("nb_label_class: ", nb_classes)

    ## Define the model.
    # NOTE: only PC_2encoder1decoder_VAE is constructed in this script;
    # `input_type`, `same_point` and `kl_schedules` are undefined for other
    # model types, even though the forward dispatch below still names them.
    if model_type == "PC_2encoder1decoder_VAE":
        kl_schedules = get_kl_weight_schedules(specs)
        input_type = 'point_cloud'
        same_point = True
        # With two encoders, each encoder produces a latent vector with half
        # of the total latent size.
        half_latent_size = int(latent_size / 2)
        print("Point cloud encoder, each branch has latent size", half_latent_size)
        encoder_obj = arch.ResnetPointnet(c_dim=half_latent_size, hidden_dim=256)

        use_sampling_trick = False
        if use_sampling_trick:
            encoder_hand = arch.ResnetPointnet(c_dim=latent_size, hidden_dim=256)
        else:
            encoder_hand = arch.ResnetPointnet(c_dim=latent_size, hidden_dim=256,
                                               cond_dim=latent_size)
        encoder_hand = encoder_hand.to(device)
        encoder_obj = encoder_obj.to(device)

        combined_decoder = arch.CombinedDecoder(
            latent_size, **specs["NetworkSpecs"],
            use_classifier=classifier_branch).to(device)

        encoderDecoder = arch.ModelTwoEncodersOneDecoderVAE(
            encoder_hand, encoder_obj, combined_decoder,
            nb_classes, num_samp_per_scene,
            classifier_branch
        )
        encoderDecoder = encoderDecoder.to(device)

    encoder_input_source = data_source if input_type == 'point_cloud' else image_source

    logging.info("training with {} GPU(s)".format(torch.cuda.device_count()))
    encoderDecoder = torch.nn.DataParallel(encoderDecoder)

    num_epochs = specs["NumEpochs"]
    log_frequency = get_spec_with_default(specs, "LogFrequency", 5)
    log_frequency_step = get_spec_with_default(specs, "LogFrequencyStep", 100)
    logging.debug("torch num_threads: {}".format(torch.get_num_threads()))
    logging.debug(encoderDecoder)

    if "1decoder" in model_type and ignore_other:
        loss_l1 = torch.nn.L1Loss(reduction='sum')
    elif use_gaussian_reconstruction_weight:
        loss_l1 = torch.nn.L1Loss(reduction='none')
    else:
        loss_l1 = torch.nn.L1Loss()
    criterion_ce = torch.nn.CrossEntropyLoss(ignore_index=-1)
    if "VAE" in model_type:
        hand_latent_reg_l2 = torch.nn.MSELoss()

    optimizer_all = torch.optim.Adam(
        [
            {
                "params": encoderDecoder.parameters(),
            }
        ]
    )

    # Tensorboard summary
    writer = SummaryWriter(os.path.join(experiment_directory, 'log'))
    # writer.add_graph(encoderDecoder)

    start_epoch = 1
    # global_step = 0

    # Continue from the latest checkpoint if one exists.
    if (continue_from is None
            and utils.misc.is_checkpoint_exist(experiment_directory, 'latest')):
        continue_from = 'latest'

    if continue_from is not None:
        logging.info('continuing from "{}"'.format(continue_from))
        model_epoch = utils.misc.load_model_parameters(
            experiment_directory, continue_from, encoderDecoder
        )
        optimizer_epoch = load_optimizer(
            experiment_directory, continue_from + ".pth", optimizer_all
        )
        start_epoch = model_epoch + 1
        logging.debug("loaded")

    # Data loader
    filter_dist = False
    if start_epoch >= start_additional_loss:
        same_point = True
        filter_dist = True

    with open(train_split_file, "r") as f:
        train_split = json.load(f)

    sdf_dataset = utils.data.SDFSamples(
        input_type,
        data_source,
        train_split,
        num_samp_per_scene,
        dataset_name=dataset_name,
        image_source=image_source,
        hand_branch=hand_branch,
        obj_branch=obj_branch,
        indep_obj_scale=indep_obj_scale,
        same_point=same_point,
        filter_dist=filter_dist,
        clamp=clamp_dist,
        load_ram=False,
        check_file=check_file,
        fhb=is_fhb,
        model_type=model_type,
        obj_center=obj_center
    )

    num_data_loader_threads = get_spec_with_default(specs, "DataLoaderThreads", 8)
    logging.debug("loading data with {} threads".format(num_data_loader_threads))

    sdf_loader = data_utils.DataLoader(
        sdf_dataset,
        batch_size=scene_per_subbatch,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True,
    )

    # Training loop
    logging.info("starting from epoch {}".format(start_epoch))
    for epoch in range(start_epoch, num_epochs + 1):
        start = time.time()
        logging.info("epoch {}...".format(epoch))
        encoderDecoder.train()
        adjust_learning_rate(lr_schedules, optimizer_all, epoch)

        if 'VAE' in model_type:
            kl_weight = kl_schedules.get_weight(epoch)

        # Rebuild sdf_loader so that the SDFs to both hand and object are
        # sampled at the same points.
        # print("same_point", same_point)
        if epoch == start_additional_loss:  # and not same_point:
            same_point = True
            filter_dist = True
            sdf_dataset = utils.data.SDFSamples(
                input_type,
                data_source,
                train_split,
                num_samp_per_scene,
                dataset_name=dataset_name,
                image_source=image_source,
                hand_branch=hand_branch,
                obj_branch=obj_branch,
                indep_obj_scale=indep_obj_scale,
                same_point=same_point,
                filter_dist=filter_dist,
                clamp=clamp_dist,
                load_ram=False,  # True
                check_file=check_file,
                fhb=is_fhb,
                model_type=model_type,
                obj_center=obj_center
            )
            sdf_loader = data_utils.DataLoader(
                sdf_dataset,
                batch_size=scene_per_subbatch,
                shuffle=True,
                num_workers=num_data_loader_threads,
                drop_last=True,
            )

        for i, (hand_samples, hand_labels, obj_samples, obj_labels,
                scale, offset, encoder_input_hand, encoder_input_obj,
                idx) in enumerate(sdf_loader):
            batch_loss = 0.0
            optimizer_all.zero_grad()

            for _subbatch in range(batch_split):
                if input_type == 'image':
                    encoder_input_hand = encoder_input_hand.to(device)
                elif input_type == 'point_cloud':
                    encoder_input_hand = encoder_input_hand.to(device)
                    encoder_input_obj = encoder_input_obj.to(device)
                elif input_type == 'image+point_cloud':
                    encoder_input_hand = encoder_input_hand.to(device)
                    encoder_input_obj = encoder_input_obj.to(device)

                if '1decoder' in model_type:
                    # Using the same points for both branches.
                    if hand_branch and obj_branch:
                        samples = torch.cat([hand_samples, obj_samples], 1)
                        labels = torch.cat([hand_labels, obj_labels], 1)
                        # Ignore points from the other shape at the beginning
                        # of training.
                        if ignore_other or epoch < start_additional_loss:
                            mask_hand = torch.cat(
                                [torch.ones(hand_samples.size()[:2]),
                                 torch.zeros(obj_samples.size()[:2])], 1)
                            mask_hand = (mask_hand.to(device)).reshape(
                                num_samp_per_scene * scene_per_subbatch).unsqueeze(1)
                            mask_obj = torch.cat(
                                [torch.zeros(hand_samples.size()[:2]),
                                 torch.ones(obj_samples.size()[:2])], 1)
                            mask_obj = (mask_obj.to(device)).reshape(
                                num_samp_per_scene * scene_per_subbatch).unsqueeze(1)
                        else:
                            mask_hand = torch.ones(
                                num_samp_per_scene * scene_per_subbatch).unsqueeze(1).to(device)
                            mask_obj = torch.ones(
                                num_samp_per_scene * scene_per_subbatch).unsqueeze(1).to(device)
                    elif hand_branch:
                        samples = hand_samples
                        labels = hand_labels
                    elif obj_branch:
                        samples = obj_samples
                        labels = obj_labels

                    samples.requires_grad = False
                    labels.requires_grad = False

                    sdf_data = (samples.to(device)).reshape(
                        num_samp_per_scene * scene_per_subbatch, 5
                    )
                    labels = (labels.to(device).to(torch.long)).reshape(
                        num_samp_per_scene * scene_per_subbatch)
                    xyz_hand = sdf_data[:, 0:3]
                    xyz_obj = xyz_hand
                    sdf_gt_hand = sdf_data[:, 3].unsqueeze(1)
                    sdf_gt_obj = sdf_data[:, 4].unsqueeze(1)
                else:
                    hand_samples.requires_grad = False
                    hand_labels.requires_grad = False
                    obj_samples.requires_grad = False
                    obj_labels.requires_grad = False

                    # Separated points - hand
                    if same_point:
                        samples = torch.cat([hand_samples, obj_samples], 1)
                        labels = torch.cat([hand_labels, obj_labels], 1)
                        sdf_data = (samples.to(device)).reshape(
                            num_samp_per_scene * scene_per_subbatch, 5
                        )
                        labels = (labels.to(device).to(torch.long)).reshape(
                            num_samp_per_scene * scene_per_subbatch)
                        hand_labels = labels
                        obj_labels = labels
                        xyz_hand = sdf_data[:, 0:3]
                        xyz_obj = xyz_hand
                        sdf_gt_hand = sdf_data[:, 3].unsqueeze(1)
                        sdf_gt_obj = sdf_data[:, 4].unsqueeze(1)
                    else:
                        sdf_data_hand = (hand_samples.to(device)).reshape(
                            num_samp_per_scene * scene_per_subbatch, 5
                        )
                        hand_labels = (hand_labels.to(device).to(torch.long)).reshape(
                            num_samp_per_scene * scene_per_subbatch)
                        xyz_hand = sdf_data_hand[:, 0:3]
                        sdf_gt_hand = sdf_data_hand[:, 3].unsqueeze(1)

                        # Object
                        sdf_data_obj = (obj_samples.to(device)).reshape(
                            num_samp_per_scene * scene_per_subbatch, 5
                        )
                        obj_labels = (obj_labels.to(device).to(torch.long)).reshape(
                            num_samp_per_scene * scene_per_subbatch)
                        xyz_obj = sdf_data_obj[:, 0:3]
                        sdf_gt_obj = sdf_data_obj[:, 4].unsqueeze(1)

                # Scale
                scale = scale.to(device).repeat_interleave(num_samp_per_scene, dim=0)

                if enforce_minmax:
                    if hand_branch:
                        sdf_gt_hand = torch.clamp(sdf_gt_hand, minT, maxT)
                    if obj_branch:
                        sdf_gt_obj = torch.clamp(sdf_gt_obj, minT, maxT)

                if model_type == 'PC_2encoder1decoder_VAE':
                    pred_sdf_hand, pred_sdf_obj, pred_class, kl_loss, z_hand = encoderDecoder(
                        encoder_input_hand, encoder_input_obj, xyz_hand)
                elif model_type == 'pc+1encoder1decoder':
                    pred_sdf_hand, pred_sdf_obj, pred_class = encoderDecoder(
                        encoder_input_hand, encoder_input_obj, xyz_hand)
                elif '2encoder' in model_type and '1decoder' in model_type:
                    # Same points for both encoders.
                    pred_sdf_hand, pred_sdf_obj, pred_class = encoderDecoder(
                        encoder_input_hand, encoder_input_obj, xyz_hand)
                elif '1decoder' in model_type:
                    pred_sdf_hand, pred_sdf_obj, pred_class = encoderDecoder(
                        encoder_input_hand, xyz_hand)
                else:
                    pred_sdf_hand, pred_class_hand, \
                        pred_sdf_obj, pred_class_obj = encoderDecoder(
                            encoder_input_hand, xyz_hand, xyz_obj)

                if enforce_minmax:
                    if hand_branch:
                        pred_sdf_hand = torch.clamp(pred_sdf_hand, minT, maxT)
                    if obj_branch:
                        pred_sdf_obj = torch.clamp(pred_sdf_obj, minT, maxT)

                ## Compute losses
                sigma_recon = 0.005 * 10.0
                if hand_branch:
                    if "1decoder" in model_type and ignore_other:
                        pred_sdf_hand = torch.mul(pred_sdf_hand, mask_hand)
                        loss_hand = loss_l1(pred_sdf_hand, sdf_gt_hand) / mask_hand.sum()
                    else:
                        loss_hand = loss_l1(pred_sdf_hand, sdf_gt_hand)
                else:
                    loss_hand = 0.

                if obj_branch:
                    if "1decoder" in model_type and ignore_other:
                        pred_sdf_obj = torch.mul(pred_sdf_obj, mask_obj)
                        loss_obj = loss_l1(pred_sdf_obj, sdf_gt_obj) / mask_obj.sum()
                    else:
                        loss_obj = loss_l1(pred_sdf_obj, sdf_gt_obj)
                else:
                    loss_obj = 0.
                if classifier_branch:
                    if '1decoder' not in model_type:
                        loss_ce = (criterion_ce(pred_class_hand, hand_labels)
                                   + criterion_ce(pred_class_obj, obj_labels)) * classifier_weight
                    else:
                        loss_ce = criterion_ce(pred_class, labels) * classifier_weight
                else:
                    loss_ce = 0

                loss = loss_hand + loss_obj
                if epoch >= start_additional_loss:
                    loss = loss + loss_ce

                if 'VAE' in model_type:
                    # KL-divergence
                    kl_loss_raw = kl_loss.mean()
                    # print("kl loss after mean", kl_loss.size())
                    kl_loss = kl_weight * kl_loss_raw
                    loss = loss + kl_loss

                if hand_branch:
                    scaled_pred_sdf_hand = pred_sdf_hand
                if obj_branch:
                    scaled_pred_sdf_obj = pred_sdf_obj

                if do_penetration_loss:
                    pen_loss = torch.max(
                        -(scaled_pred_sdf_hand + scaled_pred_sdf_obj),
                        torch.Tensor([0]).to(device)).mean() * penetration_loss_weight
                    if epoch >= start_additional_loss:
                        loss = loss + pen_loss

                if do_contact_loss:
                    alpha = 1. / contact_loss_sigma**2
                    contact_loss = torch.min(
                        alpha * (scaled_pred_sdf_hand**2 + scaled_pred_sdf_obj**2),
                        torch.Tensor([1]).to(device)).mean() * contact_loss_weight
                    if epoch >= start_additional_loss:
                        loss = loss + contact_loss

                loss.backward()
                batch_loss += loss.item()

                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(encoderDecoder.parameters(), grad_clip)

                if ((epoch - 1) * len(sdf_loader) + i) % log_frequency_step == 0:
                    loss_hand_out = loss_hand.item() if hand_branch else 0
                    loss_obj_out = loss_obj.item() if obj_branch else 0
                    loss_ce_out = loss_ce.item() if classifier_branch else 0
                    pen_loss_out = pen_loss.item() if do_penetration_loss else 0
                    contact_loss_out = contact_loss.item() if do_contact_loss else 0
                    print('step {}, loss {:.5f}, hand loss {:.5f}, object loss {:.5f}, '
                          'classifier loss {:.5f}, penetration loss {:.5f}, '
                          'contact loss {:.5f}'.format(
                              (epoch - 1) * len(sdf_loader) + i,
                              loss.item(), loss_hand_out, loss_obj_out,
                              loss_ce_out, pen_loss_out, contact_loss_out))
                    if 'VAE' in model_type:
                        print('KL loss {:.5f}'.format(kl_loss.item()))
                        writer.add_scalar('KL_loss_1e-3', kl_loss.item() * 1000.0,
                                          (epoch - 1) * len(sdf_loader) + i)
                        writer.add_scalar('KL_loss_raw_1e-3', kl_loss_raw.item() * 1000.0,
                                          (epoch - 1) * len(sdf_loader) + i)

                    writer.add_scalar('training_loss_1e-3', loss.item() * 1000.0,
                                      (epoch - 1) * len(sdf_loader) + i)
                    writer.add_scalar('loss_hand_1e-3', loss_hand_out * 1000.0,
                                      (epoch - 1) * len(sdf_loader) + i)
                    writer.add_scalar('loss_object_1e-3', loss_obj_out * 1000.0,
                                      (epoch - 1) * len(sdf_loader) + i)
                    writer.add_scalar('loss_classifier_1e-3', loss_ce_out * 1000.0,
                                      (epoch - 1) * len(sdf_loader) + i)
                    writer.add_scalar('loss_penetration_1e-3', pen_loss_out * 1000.0,
                                      (epoch - 1) * len(sdf_loader) + i)
                    writer.add_scalar('loss_contact_1e-3', contact_loss_out * 1000.0,
                                      (epoch - 1) * len(sdf_loader) + i)

            optimizer_all.step()

        end = time.time()
        seconds_elapsed = end - start
        print("time used:", seconds_elapsed)

        for idx, schedule in enumerate(lr_schedules):
            writer.add_scalar('learning_rate_' + str(idx),
                              schedule.get_learning_rate(epoch),
                              epoch * len(sdf_loader))

        recon_scale = 0.5 if not indep_obj_scale else 1.0
        if epoch in checkpoints and val_split_file:
            save_checkpoints(epoch)
            print("reconstruct mesh at {}".format(epoch))
            recon_st = time.time()
            reconstruct.reconstruct_training(experiment_directory,
                                             val_split_file,
                                             input_type,
                                             encoder_input_source,
                                             encoderDecoder,
                                             epoch,
                                             specs,
                                             hand_branch,
                                             obj_branch,
                                             model_type=model_type,
                                             scale=recon_scale,
                                             # cube_dim=128,
                                             fhb=is_fhb,
                                             dataset_name=dataset_name)
            print("- Reconstruction used {}".format(time.time() - recon_st))

            # chamfer_st = time.time()
            # object_type, chamfer_mean_list, chamfer_med_list = evaluate.evaluate(
            #     experiment_directory,
            #     str(epoch),
            #     data_source,
            #     val_split_file,
            # )
            # print("calculate chamfer dist used {}".format(time.time() - chamfer_st))
            # print(" - Chamfer distance:")
            # for i, obj_type in enumerate(object_type):
            #     print("{}: mean: {:.5f}, med: {:.5f}".format(
            #         obj_type, chamfer_mean_list[i], chamfer_med_list[i]))
            #     writer.add_scalar(obj_type + '_val_chamfer_mean_x1000',
            #                       chamfer_mean_list[i] * 1000.0, epoch)
            #     writer.add_scalar(obj_type + '_val_chamfer_med_x1000',
            #                       chamfer_med_list[i] * 1000.0, epoch)

        if epoch % log_frequency == 0:
            save_latest(epoch)
            print("save at {}".format(epoch))

    # End of training
    if val_split_file:
        print("Final reconstruct mesh at {}".format(num_epochs))
        recon_st = time.time()
        reconstruct.reconstruct_training(experiment_directory,
                                         val_split_file,
                                         input_type,
                                         encoder_input_source,
                                         encoderDecoder,
                                         num_epochs,
                                         specs,
                                         hand_branch,
                                         obj_branch,
                                         model_type=model_type,
                                         scale=recon_scale,
                                         # cube_dim=256,
                                         fhb=is_fhb,
                                         dataset_name=dataset_name)
        print("- Final Reconstruction used {}".format(time.time() - recon_st))

        # chamfer_st = time.time()
        # object_type, chamfer_mean_list, chamfer_med_list = evaluate.evaluate(
        #     experiment_directory,
        #     str(num_epochs),
        #     data_source,
        #     val_split_file,
        # )
        # print("calculate final chamfer dist used {}".format(time.time() - chamfer_st))
        # print(" - Chamfer distance:")
        # for i, obj_type in enumerate(object_type):
        #     print("{}: mean: {:.5f}, med: {:.5f}".format(
        #         obj_type, chamfer_mean_list[i], chamfer_med_list[i]))
        #     writer.add_scalar(obj_type + '_val_chamfer_mean_x1000',
        #                       chamfer_mean_list[i] * 1000.0, num_epochs)
        #     writer.add_scalar(obj_type + '_val_chamfer_med_x1000',
        #                       chamfer_med_list[i] * 1000.0, num_epochs)

    writer.close()
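# A standalone restatement of the penetration and contact losses computed in
# main_function above (the helper names are hypothetical; the default weights
# and sigma mirror the spec defaults read there):
def _penetration_loss_sketch(pred_sdf_hand, pred_sdf_obj, weight=15.0):
    # A point inside both shapes has two negative SDFs, so -(sdf_h + sdf_o)
    # is positive there; everywhere else the max with 0 zeroes the penalty out.
    zero = torch.zeros_like(pred_sdf_hand)
    return torch.max(-(pred_sdf_hand + pred_sdf_obj), zero).mean() * weight


def _contact_loss_sketch(pred_sdf_hand, pred_sdf_obj, sigma=0.005, weight=0.005):
    # Encourages contact: the scaled quadratic is ~0 only where both SDFs are
    # near zero (both surfaces pass close to the point); the min with 1 caps
    # the penalty for points far from either surface.
    alpha = 1.0 / sigma ** 2
    one = torch.ones_like(pred_sdf_hand)
    return torch.min(alpha * (pred_sdf_hand ** 2 + pred_sdf_obj ** 2), one).mean() * weight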
def run_visualization(experiment_directory, checkpoint):
    specs = utils.misc.load_experiment_specifications(experiment_directory)

    data_source = Path(specs["DataSource"])
    train_split_file = specs["TrainSplit"]
    val_split_file = get_spec_with_default(specs, "ValSplit", None)
    subsample = specs["SamplesPerScene"]
    scene_per_batch = specs["ScenesPerBatch"]
    latent_size = specs["LatentSize"]
    num_epochs = specs["NumEpochs"]
    num_data_loader_threads = get_spec_with_default(specs, "DataLoaderThreads", 8)
    nb_classes = get_spec_with_default(specs["NetworkSpecs"], "num_class", 6)
    log_frequency = get_spec_with_default(specs, "LogFrequency", 5)
    log_frequency_step = get_spec_with_default(specs, "LogFrequencyStep", 100)

    with open(train_split_file, "r") as f:
        train_split = json.load(f)
    with open(val_split_file, "r") as f:
        val_split = json.load(f)

    sdf_dataset = utils.data.SDFAssemblySamples(data_source, train_split, subsample)
    sdf_val_dataset = utils.data.SDFAssemblySamples(data_source, val_split, subsample)

    sdf_loader = data_utils.DataLoader(
        sdf_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True,
    )
    sdf_val_loader = data_utils.DataLoader(
        sdf_val_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True,
    )

    # Each point cloud encoder produces half of the total latent vector.
    half_latent_size = int(latent_size / 2)
    encoder_part1 = arch.ResnetPointnet(c_dim=half_latent_size, hidden_dim=256).to(device)
    encoder_part2 = arch.ResnetPointnet(c_dim=half_latent_size, hidden_dim=256).to(device)
    print("Point cloud encoder, each branch has latent size", half_latent_size)

    decoder = arch.ShapeAssemblyDecoder(latent_size, **specs["NetworkSpecs"])

    encoder_decoder = arch.ModelShapeAssemblyEncoderDecoderVAE(
        encoder_part1,
        encoder_part2,
        decoder,
        nb_classes,
        subsample
    )
    encoder_decoder = encoder_decoder.to(device)
    encoder_decoder.share_memory()

    model_epoch = utils.misc.load_model_parameters(
        experiment_directory, checkpoint, encoder_decoder
    )
    logging.info("Loaded weights!")
    # Switch to eval mode for inference.
    encoder_decoder.eval()

    total_validation_error = 0
    # Run visualization over the validation set.
    for i, (part1, part2,
            gt_transformed_part1_points, gt_transform) in enumerate(sdf_val_loader):
        shape_assembly_visualization_function(part1, part2, gt_transform,
                                              encoder_decoder, subsample,
                                              scene_per_batch)
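# Example invocation of run_visualization (a sketch; the experiment path and
# checkpoint name are placeholders, and in the original script they presumably
# come from argparse):
#
#   run_visualization("experiments/shape_assembly", "latest")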