Example #1
def get_model(specs, device):
    model_type = specs["ModelType"]
    latent_size = specs["LatentSize"]
    nb_classes = get_spec_with_default(specs["NetworkSpecs"], "num_class", 6)
    classifier_branch = get_spec_with_default(specs, "ClassifierBranch", False)

    if model_type == "PC_2encoder1decoder_VAE":
        # input_type = 'point_cloud'
        # With two encoders, each produces a latent vector of half the total size.
        half_latent_size = int(latent_size/2)
        # print("Point cloud encoder, each branch has latent size", half_latent_size)
        
        encoder_obj = arch.ResnetPointnet(c_dim=half_latent_size, hidden_dim=256)
        # The hand encoder outputs the full latent_size (2x half): half for the mean, half for the variance.
        encoder_hand = arch.ResnetPointnet(c_dim=latent_size, hidden_dim=256, cond_dim=latent_size)

        combined_decoder = arch.CombinedDecoder(latent_size, **specs["NetworkSpecs"], 
                                                use_classifier=classifier_branch)

        encoderDecoder = arch.ModelTwoEncodersOneDecoderVAE(
            encoder_hand, encoder_obj, combined_decoder, 
            nb_classes, specs["SamplesPerScene"], 
            classifier_branch
        )
    else:
        raise ValueError("Unsupported ModelType: {}".format(model_type))

    encoderDecoder = torch.nn.DataParallel(encoderDecoder)

    # Load weights (note: `args` is assumed to be parsed at module level in the original script)
    saved_model_state = torch.load(
        os.path.join(args.model_directory, "model.pth")
    )
    saved_model_epoch = saved_model_state["epoch"]
    logging.info("using model from epoch {}".format(saved_model_epoch))

    encoderDecoder.load_state_dict(saved_model_state["model_state_dict"])

    encoderDecoder = encoderDecoder.to(device)

    return encoderDecoder
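
# Minimal usage sketch for get_model (assumptions: specs.json and model.pth live
# in args.model_directory, matching the loading code above; the guard and calls
# here are illustrative, not part of the original script):
if __name__ == "__main__":
    specs = utils.misc.load_experiment_specifications(args.model_directory)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(specs, device)
    model.eval()  # switch to inference mode before running evaluation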
Example #2
def shape_assembly_main_function(experiment_directory, continue_from, batch_split):

    def save_latest(epoch):
        save_model(experiment_directory, "latest.pth", encoder_decoder, epoch)
        save_optimizer(experiment_directory, "latest.pth", optimizer_all, epoch)

    def save_checkpoints(epoch):
        save_model(experiment_directory, str(epoch) + ".pth", encoder_decoder, epoch)
        save_optimizer(experiment_directory, str(epoch) + ".pth", optimizer_all, epoch)

    def signal_handler(sig, frame):
        logging.info("Stopping early...")
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)

    def adjust_learning_rate(lr_schedules, optimizer, epoch):

        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)

    logging.debug("running " + experiment_directory)

    print(experiment_directory)

    specs = utils.misc.load_experiment_specifications(experiment_directory)

    logging.info("Experiment description: \n" + specs["Description"])

    data_source = Path(specs["DataSource"])
    train_split_file = specs["TrainSplit"]
    val_split_file = specs["ValSplit"]
    subsample = specs["SamplesPerScene"]
    scene_per_batch = specs["ScenesPerBatch"]
    latent_size = specs["LatentSize"]
    num_epochs = specs["NumEpochs"]
    num_data_loader_threads = get_spec_with_default(specs, "DataLoaderThreads", 8)
    nb_classes = get_spec_with_default(specs["NetworkSpecs"], "num_class", 6)
    lr_schedules = get_learning_rate_schedules(specs)
    log_frequency = get_spec_with_default(specs, "LogFrequency", 5)
    log_frequency_step = get_spec_with_default(specs, "LogFrequencyStep", 100)

    checkpoints = list(
        range(
            specs["SnapshotFrequency"],
            specs["NumEpochs"] + 1,
            specs["SnapshotFrequency"],
        )
    )
    for checkpoint in specs["AdditionalSnapshots"]:
        checkpoints.append(checkpoint)
    checkpoints.sort()
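
    # e.g. SnapshotFrequency=100, NumEpochs=350, AdditionalSnapshots=[50]
    # yields checkpoints == [50, 100, 200, 300]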

    with open(train_split_file, "r") as f:
        train_split = json.load(f)

    with open(val_split_file, "r") as f:
        val_split = json.load(f)

    sdf_dataset = utils.data.SDFAssemblySamples(data_source, train_split, subsample)
    sdf_val_dataset = utils.data.SDFAssemblySamples(data_source, val_split, subsample)

    sdf_loader = data_utils.DataLoader(
        sdf_dataset,
        batch_size=scene_per_batch,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True
    )
    sdf_val_loader = data_utils.DataLoader(
        sdf_val_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True
    )

    half_latent_size = int(latent_size/2)

    encoder_part1 = arch.ResnetPointnet(c_dim=half_latent_size, hidden_dim=256).to(device)
    encoder_part2 = arch.ResnetPointnet(c_dim=half_latent_size, hidden_dim=256).to(device)
    print("Point cloud encoder, each branch has latent size", half_latent_size)

    decoder = arch.ShapeAssemblyDecoder(
        latent_size, 
        **specs["NetworkSpecs"]
    )

    encoder_decoder = arch.ModelShapeAssemblyEncoderDecoderVAE(
        encoder_part1,
        encoder_part2,
        decoder,
        nb_classes,
        subsample,
    )
    encoder_decoder = encoder_decoder.to(device)
    encoder_decoder.share_memory()

    logging.info("training with {} GPU(s)".format(torch.cuda.device_count()))

    # encoder_decoder = torch.nn.DataParallel(encoder_decoder)

    logging.debug("torch num_threads: {}".format(torch.get_num_threads()))
    logging.debug(encoder_decoder)

    optimizer_all = torch.optim.Adam(
        [
            {
                "params": encoder_decoder.parameters(),
            }
        ]
    )

    writer = SummaryWriter(os.path.join(experiment_directory, 'log'))

    # The names below are used in the training loop but were not defined in this
    # function as given; minimal stand-ins, assuming plain L1 losses and
    # spec-driven settings (these spec keys and defaults are assumptions):
    loss_l1 = torch.nn.L1Loss()
    loss_l1_avg = torch.nn.L1Loss()
    alpha = get_spec_with_default(specs, "TransformationLossWeight", 1.0)
    translation_only = get_spec_with_default(specs, "TranslationOnly", False)

    start_epoch = 1

    # continue from latest checkpoint if exists
    if (continue_from is None and 
            utils.misc.is_checkpoint_exist(experiment_directory, 'latest')):
        continue_from = 'latest'

    if continue_from is not None:
        logging.info('continuing from "{}"'.format(continue_from))

        model_epoch = utils.misc.load_model_parameters(
            experiment_directory, continue_from, encoder_decoder
        )

        optimizer_epoch = load_optimizer(
            experiment_directory, continue_from + ".pth", optimizer_all
        )
        start_epoch = model_epoch + 1
        logging.debug("loaded")

    # training loop
    logging.info("starting from epoch {}".format(start_epoch))

    for epoch in range(start_epoch, num_epochs + 1):
        start = time.time()

        logging.info("epoch {}...".format(epoch))

        # signify to the model that we are training
        encoder_decoder.train()

        adjust_learning_rate(lr_schedules, optimizer_all, epoch)

        total_sdf_loss = 0
        total_transformation_loss = 0

        for i, (
            part1,
            part2,
            gt_transformed_part1_points,
            gt_transform) in enumerate(sdf_loader):

            batch_loss = 0.0
            optimizer_all.zero_grad()      

            for _subbatch in range(batch_split):
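                # the batch is processed batch_split times; gradients accumulate
                # across the passes until optimizer_all.step() below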

                samples = part2["sdf_samples"]

                samples.requires_grad = False

                sdf_data = (samples.to(device)).reshape(
                    subsample * scene_per_batch, 5
                )

                # part1_transform_vec = torch.cat((part1["center"], part1["quaternion"]), 1).to(device)
                
                xyzs = sdf_data[:, 0:3]
                sdf_gt_part1 = sdf_data[:, 3].unsqueeze(1)
                sdf_gt_part2 = sdf_data[:, 4].unsqueeze(1)
                
                sdf_pred_part1, sdf_pred_part2, predicted_translation, predicted_rotation = encoder_decoder(
                                                        part1["surface_points"].to(device), 
                                                        part2["surface_points"].to(device), 
                                                        xyzs
                                                    )

                predicted_translation = predicted_translation.to(device)
                predicted_rotation = predicted_rotation.to(device)
                part1["surface_points"] = part1["surface_points"].to(device)    
                
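                # apply the predicted rigid transform to part1's surface points:
                # an optional quaternion rotation followed by the translation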
                predicted_translation = predicted_translation.reshape((predicted_translation.shape[0], 1, 3))
                if translation_only:
                    predicted_transformed_part1_points = torch.add(part1["surface_points"], predicted_translation)
                else:
                    predicted_rotation = predicted_rotation.reshape((predicted_rotation.shape[0], 1, 4))
                    predicted_rotated_part1_points = quaternion_apply(predicted_rotation, part1["surface_points"])
                    predicted_transformed_part1_points = torch.add(predicted_rotated_part1_points, predicted_translation)

                # compute losses
                loss_sdf_part1 = loss_l1(sdf_pred_part1, sdf_gt_part1)
                loss_sdf_part2 = loss_l1(sdf_pred_part2, sdf_gt_part2)
                
                loss_transformation = alpha * loss_l1_avg(gt_transformed_part1_points.to(device), predicted_transformed_part1_points)

                loss = loss_sdf_part1 + loss_sdf_part2 + loss_transformation

                # loss = loss_l1_avg(gt_transformed_part1_points.to(device), predicted_transformed_part1_points)

                loss.backward()

                batch_loss += loss.item()

                print('step {}, loss {:.5f}'.format(
                        (epoch-1) * len(sdf_loader) + i, 
                        loss.item()
                    )
                )

                total_sdf_loss += (loss_sdf_part1 + loss_sdf_part2).item()
                total_transformation_loss += loss_transformation.item()
                

            optimizer_all.step()
        
        writer.add_scalar('training_sdf_loss', total_sdf_loss / len(sdf_loader), (epoch-1))
        writer.add_scalar('training_transformation_loss', total_transformation_loss / len(sdf_loader), (epoch-1))

        end = time.time()

        seconds_elapsed = end - start
        print("time used:", seconds_elapsed)

        for idx, schedule in enumerate(lr_schedules):
            writer.add_scalar('learning_rate_' + str(idx),
                schedule.get_learning_rate(epoch),
                epoch * len(sdf_loader)
            )

        # Save the latest model
        save_latest(epoch)
        print("save at {}".format(epoch))
        
        if epoch in checkpoints:
            save_checkpoints(epoch)
            # switch the model to evaluation mode for validation
            encoder_decoder.eval()

            avg_val_loss = 0.0
            avg_val_sdf_loss = 0.0
            avg_val_reprojection_loss = 0.0

            # Run validation
            for i, (
                part1,
                part2,
                gt_transformed_part1_points,
                gt_transform) in enumerate(sdf_val_loader):
                
                reprojection_loss, sdf_loss = shape_assembly_val_function(part1, part2, gt_transformed_part1_points, encoder_decoder, subsample, scene_per_batch)

                avg_val_sdf_loss += sdf_loss
                avg_val_reprojection_loss += reprojection_loss
                avg_val_loss += reprojection_loss + sdf_loss

            avg_val_sdf_loss = avg_val_sdf_loss / len(sdf_val_loader)
            avg_val_reprojection_loss = avg_val_reprojection_loss / len(sdf_val_loader)
            avg_val_loss = avg_val_loss / len(sdf_val_loader)

            writer.add_scalar('validation_loss', avg_val_loss, (epoch-1))
            logging.info(f"Epoch {epoch}: Validation Loss: {avg_val_loss}")

            writer.add_scalar('validation_reprojection_loss', avg_val_reprojection_loss, (epoch-1))
            writer.add_scalar('validation_sdf_loss', avg_val_sdf_loss, (epoch-1))
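
# Minimal entry-point sketch for the trainer above (the flag names and the
# argparse wiring are assumptions, not the project's actual CLI):
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train the shape-assembly model")
    parser.add_argument("--experiment", "-e", required=True,
                        help="experiment directory containing specs.json")
    parser.add_argument("--continue", dest="continue_from", default=None,
                        help="checkpoint to resume from, e.g. 'latest'")
    parser.add_argument("--batch_split", type=int, default=1)
    cli_args = parser.parse_args()

    shape_assembly_main_function(
        cli_args.experiment, cli_args.continue_from, cli_args.batch_split
    )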
Example #3
def main_function(experiment_directory, continue_from, batch_split):

    logging.debug("running " + experiment_directory)

    print(experiment_directory)

    specs = utils.misc.load_experiment_specifications(experiment_directory)

    logging.info("Experiment description: \n" + specs["Description"])

    data_source = specs["DataSource"]
    image_source = specs["ImageSource"]
    train_split_file = specs["TrainSplit"]
    val_split_file = get_spec_with_default(specs, "ValSplit", None)

    is_fhb = get_spec_with_default(specs, "FHB", False)
    if is_fhb:
        print("FHB dataset")

    check_file = get_spec_with_default(specs, "CheckFile", True)

    logging.debug(specs["NetworkSpecs"])

    dataset_name = get_spec_with_default(specs, "Dataset", "obman")

    ### Model Type
    model_type = get_spec_with_default(specs, "ModelType", "1encoder2decoder")
    obj_center = get_spec_with_default(specs, "ObjectCenter", False)

    hand_branch = get_spec_with_default(specs, "HandBranch", True)
    obj_branch = get_spec_with_default(specs, "ObjectBranch", True)

    print("Hand branch:", hand_branch)
    print("Object branch:", obj_branch)
    assert hand_branch or obj_branch

    classifier_branch = get_spec_with_default(specs, "ClassifierBranch", False)
    classifier_weight = get_spec_with_default(specs, "ClassifierWeight", 0.1)
    print("Classifier Weight:", classifier_weight)

    use_gaussian_reconstruction_weight = get_spec_with_default(specs, "GaussianWeightLoss", False)


    do_penetration_loss = get_spec_with_default(specs, "PenetrationLoss", False)
    penetration_loss_weight = get_spec_with_default(specs, "PenetrationLossWeight", 15.0)
    start_additional_loss = get_spec_with_default(specs, "AdditionalLossStart", 200000)
    do_contact_loss = get_spec_with_default(specs, "ContactLoss", False)
    contact_loss_weight = get_spec_with_default(specs, "ContactLossWeight", 0.005)
    contact_loss_sigma = get_spec_with_default(specs, "ContactLossSigma", 0.005)
    print("Penetration Loss:", do_penetration_loss)
    print("Penetration Loss Weight:", penetration_loss_weight)
    print("Additional Loss start at epoch:", start_additional_loss)
    print("Contact Loss:", do_contact_loss)
    print("Contact Loss Weight:", contact_loss_weight)
    print("Contact Loss Sigma (m):", contact_loss_sigma)

    latent_size = specs["LatentSize"]

    checkpoints = list(
        range(
            specs["SnapshotFrequency"],
            specs["NumEpochs"] + 1,
            specs["SnapshotFrequency"],
        )
    )
    for checkpoint in specs["AdditionalSnapshots"]:
        checkpoints.append(checkpoint)
    checkpoints.sort()

    lr_schedules = get_learning_rate_schedules(specs)


    grad_clip = get_spec_with_default(specs, "GradientClipNorm", None)
    if grad_clip is not None:
        logging.debug("clipping gradients to max norm {}".format(grad_clip))

    def save_latest(epoch):
        save_model(experiment_directory, "latest.pth", encoderDecoder, epoch)
        save_optimizer(experiment_directory, "latest.pth", optimizer_all, epoch)
        # save_latent_vectors(experiment_directory, "latest.pth", lat_vecs, epoch)

    def save_checkpoints(epoch):
        save_model(experiment_directory, str(epoch) + ".pth", encoderDecoder, epoch)
        save_optimizer(experiment_directory, str(epoch) + ".pth", optimizer_all, epoch)
        # save_latent_vectors(experiment_directory, str(epoch) + ".pth", lat_vecs, epoch)

    def signal_handler(sig, frame):
        logging.info("Stopping early...")
        sys.exit(0)

    def adjust_learning_rate(lr_schedules, optimizer, epoch):

        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)

    signal.signal(signal.SIGINT, signal_handler)

    # If True, use the data as-is; if False, scale and offset the object location with normalized parameters
    indep_obj_scale = get_spec_with_default(specs, "IndependentObjScale", False)
    print("Independent Obj Scale:", indep_obj_scale)
    # Ignore points from the other mesh at the beginning when training a single decoder
    ignore_other = get_spec_with_default(specs, "IgnorePointFromOtherMesh", False)
    print("Ignore other:", ignore_other)

    num_samp_per_scene = specs["SamplesPerScene"]
    scene_per_batch = specs["ScenesPerBatch"]
    scene_per_subbatch = scene_per_batch
    clamp_dist = specs["ClampingDistance"]
    minT = -clamp_dist
    maxT = clamp_dist
    enforce_minmax = True

    nb_classes = get_spec_with_default(specs["NetworkSpecs"], "num_class", 6)
    print("nb_label_class: ", nb_classes)

    ## Define Model
    if model_type == "PC_2encoder1decoder_VAE":
        kl_schedules = get_kl_weight_schedules(specs)

        input_type = 'point_cloud'
        same_point = True
        # With two encoders, each produces a latent vector of half the total size.
        half_latent_size = int(latent_size/2)
        print("Point cloud encoder, each branch has latent size", half_latent_size)
        
        encoder_obj = arch.ResnetPointnet(c_dim=half_latent_size, hidden_dim=256)
        use_sampling_trick = False
        if use_sampling_trick:
            encoder_hand = arch.ResnetPointnet(c_dim=latent_size, hidden_dim=256)
        else:
            encoder_hand = arch.ResnetPointnet(c_dim=latent_size, hidden_dim=256, cond_dim=latent_size)

        encoder_hand = encoder_hand.to(device)
        encoder_obj = encoder_obj.to(device)
        combined_decoder = arch.CombinedDecoder(latent_size, **specs["NetworkSpecs"], 
                                                use_classifier=classifier_branch).to(device)

        encoderDecoder = arch.ModelTwoEncodersOneDecoderVAE(
            encoder_hand, encoder_obj, combined_decoder, 
            nb_classes, num_samp_per_scene, 
            classifier_branch
        )
        encoderDecoder = encoderDecoder.to(device)
    else:
        raise ValueError("Unsupported ModelType: {}".format(model_type))

    encoder_input_source = data_source if input_type == 'point_cloud' else image_source

    logging.info("training with {} GPU(s)".format(torch.cuda.device_count()))

    encoderDecoder = torch.nn.DataParallel(encoderDecoder)

    num_epochs = specs["NumEpochs"]
    log_frequency = get_spec_with_default(specs, "LogFrequency", 5)
    log_frequency_step = get_spec_with_default(specs, "LogFrequencyStep", 100)

    logging.debug("torch num_threads: {}".format(torch.get_num_threads()))

    logging.debug(encoderDecoder)

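    # choose the L1 reduction: 'sum' pairs with the mask normalization applied
    # below (loss / mask.sum()); 'none' keeps per-point losses, presumably for
    # the Gaussian reconstruction weighting; otherwise use the default mean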
    if "1decoder" in model_type and ignore_other:
        loss_l1 = torch.nn.L1Loss(reduction='sum')
    elif use_gaussian_reconstruction_weight:
        loss_l1 = torch.nn.L1Loss(reduction='none')
    else:
        loss_l1 = torch.nn.L1Loss()
    criterion_ce = torch.nn.CrossEntropyLoss(ignore_index=-1)

    if "VAE" in model_type:
        hand_latent_reg_l2 = torch.nn.MSELoss()

    optimizer_all = torch.optim.Adam(
        [
            {
                "params": encoderDecoder.parameters(),
            }
        ]
    )

    # Tensorboard summary
    writer = SummaryWriter(os.path.join(experiment_directory, 'log'))
    # writer.add_graph(encoderDecoder)

    start_epoch = 1
    # global_step = 0

    # continue from latest checkpoint if exists
    if (continue_from is None and 
            utils.misc.is_checkpoint_exist(experiment_directory, 'latest')):
        continue_from = 'latest'

    if continue_from is not None:
        logging.info('continuing from "{}"'.format(continue_from))

        model_epoch = utils.misc.load_model_parameters(
            experiment_directory, continue_from, encoderDecoder
        )

        optimizer_epoch = load_optimizer(
            experiment_directory, continue_from + ".pth", optimizer_all
        )
        start_epoch = model_epoch + 1
        logging.debug("loaded")

    # Data loader
    filter_dist = False
    if start_epoch >= start_additional_loss:
        same_point = True
        filter_dist = True

    with open(train_split_file, "r") as f:
        train_split = json.load(f)

    sdf_dataset = utils.data.SDFSamples(
        input_type,
        data_source, train_split, num_samp_per_scene,
        dataset_name=dataset_name,
        image_source=image_source,
        hand_branch=hand_branch, obj_branch=obj_branch,
        indep_obj_scale=indep_obj_scale,
        same_point=same_point,
        filter_dist=filter_dist,
        clamp=clamp_dist,
        load_ram=False,
        check_file=check_file, fhb=is_fhb,
        model_type=model_type,
        obj_center=obj_center
    )

    num_data_loader_threads = get_spec_with_default(specs, "DataLoaderThreads", 8)
    logging.debug("loading data with {} threads".format(num_data_loader_threads))

    sdf_loader = data_utils.DataLoader(
        sdf_dataset,
        batch_size=scene_per_subbatch,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True
    )

    # training loop
    logging.info("starting from epoch {}".format(start_epoch))

    for epoch in range(start_epoch, num_epochs + 1):
        start = time.time()

        logging.info("epoch {}...".format(epoch))

        encoderDecoder.train()

        adjust_learning_rate(lr_schedules, optimizer_all, epoch)

        if 'VAE' in model_type:
            kl_weight = kl_schedules.get_weight(epoch)

        # Rebuild the dataset and loader so that SDF values for hand and object
        # are sampled at the same points once the additional losses kick in
        if epoch == start_additional_loss:
            same_point = True
            filter_dist = True
            sdf_dataset = utils.data.SDFSamples(
                input_type,
                data_source, train_split, num_samp_per_scene,
                dataset_name=dataset_name,
                image_source=image_source,
                hand_branch=hand_branch, obj_branch=obj_branch,
                indep_obj_scale=indep_obj_scale,
                same_point=same_point,
                filter_dist=filter_dist,
                clamp=clamp_dist,
                load_ram=False,
                check_file=check_file, fhb=is_fhb,
                model_type=model_type,
                obj_center=obj_center
            )
            sdf_loader = data_utils.DataLoader(
                sdf_dataset,
                batch_size=scene_per_subbatch,
                shuffle=True,
                num_workers=num_data_loader_threads,
                drop_last=True
            )
        
        for i, (hand_samples, hand_labels, obj_samples, obj_labels,
                scale, offset, encoder_input_hand, encoder_input_obj, idx) in enumerate(sdf_loader):

            batch_loss = 0.0
            optimizer_all.zero_grad()

            for _subbatch in range(batch_split):
                if input_type == 'image':
                    encoder_input_hand = encoder_input_hand.to(device)
                elif input_type == 'point_cloud':
                    encoder_input_hand = encoder_input_hand.to(device)
                    encoder_input_obj = encoder_input_obj.to(device)
                elif input_type == 'image+point_cloud':
                    encoder_input_hand = encoder_input_hand.to(device)
                    encoder_input_obj = encoder_input_obj.to(device)

                if '1decoder' in model_type:
                    # Using same point
                    if hand_branch and obj_branch:
                        samples = torch.cat([hand_samples, obj_samples], 1)
                        labels = torch.cat([hand_labels, obj_labels], 1)

                        # Ignore points from the other shape at the beginning of training
                        if ignore_other or epoch < start_additional_loss:
                            mask_hand = torch.cat([torch.ones(hand_samples.size()[:2]), torch.zeros(obj_samples.size()[:2])], 1)
                            mask_hand = (mask_hand.to(device)).reshape(num_samp_per_scene * scene_per_subbatch).unsqueeze(1)
                            mask_obj = torch.cat([torch.zeros(hand_samples.size()[:2]), torch.ones(obj_samples.size()[:2])], 1)
                            mask_obj = (mask_obj.to(device)).reshape(num_samp_per_scene * scene_per_subbatch).unsqueeze(1)
                        else:
                            mask_hand = torch.ones(num_samp_per_scene * scene_per_subbatch).unsqueeze(1).to(device)
                            mask_obj = torch.ones(num_samp_per_scene * scene_per_subbatch).unsqueeze(1).to(device)
                    elif hand_branch:
                        samples = hand_samples
                        labels = hand_labels
                    elif obj_branch:
                        samples = obj_samples
                        labels = obj_labels
                    
                    samples.requires_grad = False
                    labels.requires_grad = False
                    
                    sdf_data = (samples.to(device)).reshape(
                        num_samp_per_scene * scene_per_subbatch, 5 
                    )
                    labels = (labels.to(device).to(torch.long)).reshape(
                        num_samp_per_scene * scene_per_subbatch)
                    xyz_hand = sdf_data[:, 0:3]
                    xyz_obj = xyz_hand
                    sdf_gt_hand = sdf_data[:, 3].unsqueeze(1)
                    sdf_gt_obj = sdf_data[:, 4].unsqueeze(1)
                else:
                    hand_samples.requires_grad = False
                    hand_labels.requires_grad = False
                    obj_samples.requires_grad = False
                    obj_labels.requires_grad = False
                    # Separated points - Hand
                    if same_point:
                        samples = torch.cat([hand_samples, obj_samples], 1)
                        labels = torch.cat([hand_labels, obj_labels], 1)
                        sdf_data = (samples.to(device)).reshape(
                            num_samp_per_scene * scene_per_subbatch, 5
                        )
                        labels = (labels.to(device).to(torch.long)).reshape(
                            num_samp_per_scene * scene_per_subbatch)
                        hand_labels = labels
                        obj_labels = labels
                        xyz_hand = sdf_data[:, 0:3]
                        xyz_obj = xyz_hand
                        sdf_gt_hand = sdf_data[:, 3].unsqueeze(1)
                        sdf_gt_obj = sdf_data[:, 4].unsqueeze(1)

                    else:
                        sdf_data_hand = (hand_samples.to(device)).reshape(
                            num_samp_per_scene * scene_per_subbatch, 5
                        )
                        hand_labels = (hand_labels.to(device).to(torch.long)).reshape(
                            num_samp_per_scene * scene_per_subbatch)
                        xyz_hand = sdf_data_hand[:, 0:3]
                        sdf_gt_hand = sdf_data_hand[:, 3].unsqueeze(1)

                        # Object
                        sdf_data_obj = (obj_samples.to(device)).reshape(
                            num_samp_per_scene * scene_per_subbatch, 5
                        )
                        obj_labels = (obj_labels.to(device).to(torch.long)).reshape(
                            num_samp_per_scene * scene_per_subbatch)
                        xyz_obj = sdf_data_obj[:, 0:3]
                        sdf_gt_obj = sdf_data_obj[:, 4].unsqueeze(1)
                
                # broadcast the per-scene scale to one value per sample point
                scale = scale.to(device).repeat_interleave(num_samp_per_scene, dim=0)

                if enforce_minmax:
                    if hand_branch:
                        sdf_gt_hand = torch.clamp(sdf_gt_hand, minT, maxT)
                    if obj_branch:
                        sdf_gt_obj = torch.clamp(sdf_gt_obj, minT, maxT)
                
                if model_type == 'PC_2encoder1decoder_VAE':
                    pred_sdf_hand, pred_sdf_obj, pred_class, kl_loss, z_hand = encoderDecoder(encoder_input_hand, encoder_input_obj, xyz_hand)
                elif model_type == 'pc+1encoder1decoder':
                    pred_sdf_hand, pred_sdf_obj, pred_class = encoderDecoder(encoder_input_hand, encoder_input_obj, xyz_hand)
                elif '2encoder' in model_type and '1decoder' in model_type:
                    pred_sdf_hand, pred_sdf_obj, pred_class = encoderDecoder(encoder_input_hand, encoder_input_obj, xyz_hand)
                # same points
                elif '1decoder' in model_type: 
                    pred_sdf_hand, pred_sdf_obj, pred_class = encoderDecoder(encoder_input_hand, xyz_hand)
                else:
                    pred_sdf_hand, pred_class_hand,  \
                    pred_sdf_obj, pred_class_obj = encoderDecoder(encoder_input_hand, xyz_hand, xyz_obj)

                if enforce_minmax:
                    if hand_branch:
                        pred_sdf_hand = torch.clamp(pred_sdf_hand, minT, maxT)
                    if obj_branch:
                        pred_sdf_obj = torch.clamp(pred_sdf_obj, minT, maxT)

                ## Compute losses
                sigma_recon = 0.005 * 10.0
                if hand_branch:
                    if "1decoder" in model_type and ignore_other:
                        pred_sdf_hand = torch.mul(pred_sdf_hand, mask_hand)
                        loss_hand = loss_l1(pred_sdf_hand, sdf_gt_hand) / mask_hand.sum()
                    else:
                        loss_hand = loss_l1(pred_sdf_hand, sdf_gt_hand)
                else:
                    loss_hand = 0.
                if obj_branch:
                    if "1decoder" in model_type and ignore_other:
                        pred_sdf_obj = torch.mul(pred_sdf_obj, mask_obj)
                        loss_obj = loss_l1(pred_sdf_obj, sdf_gt_obj) / mask_obj.sum()
                    else:
                        loss_obj = loss_l1(pred_sdf_obj, sdf_gt_obj)
                else:
                    loss_obj = 0.
                
                if classifier_branch:
                    if '1decoder' not in model_type:
                        loss_ce = (criterion_ce(pred_class_hand, hand_labels) + 
                                   criterion_ce(pred_class_obj, obj_labels) ) * classifier_weight
                    else:
                        loss_ce = criterion_ce(pred_class, labels) * classifier_weight
                else:
                    loss_ce = 0

                loss = loss_hand + loss_obj
                if epoch >= start_additional_loss:
                    loss = loss + loss_ce

                if 'VAE' in model_type:
                    # KL-divergence
                    kl_loss_raw = kl_loss.mean()
                    # print("kl loss after mean", kl_loss.size())
                    kl_loss = kl_weight * kl_loss_raw
                    loss = loss + kl_loss

                if hand_branch:
                    scaled_pred_sdf_hand = pred_sdf_hand
                if obj_branch:
                    scaled_pred_sdf_obj = pred_sdf_obj 
                if do_penetration_loss:
                    pen_loss = torch.max(-(scaled_pred_sdf_hand + scaled_pred_sdf_obj), torch.Tensor([0]).to(device)).mean() * penetration_loss_weight
                    if epoch >= start_additional_loss:
                        loss = loss + pen_loss
                
                if do_contact_loss:
                    alpha = 1. / contact_loss_sigma**2
                    contact_loss = torch.min(alpha * (scaled_pred_sdf_hand**2 + scaled_pred_sdf_obj**2), torch.Tensor([1]).to(device)).mean() * contact_loss_weight
                    if epoch >= start_additional_loss:
                        loss = loss + contact_loss
                loss.backward()

                batch_loss += loss.item()

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(encoderDecoder.parameters(), grad_clip)

            if ((epoch-1) * len(sdf_loader) + i) % log_frequency_step == 0:
                loss_hand_out = loss_hand.item() if hand_branch else 0
                loss_obj_out = loss_obj.item() if obj_branch else 0
                loss_ce_out = loss_ce.item() if classifier_branch else 0
                pen_loss_out = pen_loss.item() if do_penetration_loss else 0
                contact_loss_out = contact_loss.item() if do_contact_loss else 0

                print('step {}, loss {:.5f}, hand loss {:.5f}, object loss {:.5f}, classifier loss {:.5f}, penetration loss {:.5f}, contact loss {:.5f}'.format(
                    (epoch-1) * len(sdf_loader) + i, 
                     loss.item(), loss_hand_out, loss_obj_out, loss_ce_out,
                     pen_loss_out, contact_loss_out))
                if 'VAE' in model_type:
                    print('KL loss {:.5f}'.format(kl_loss.item()))
                    writer.add_scalar('KL_loss_1e-3', kl_loss.item() * 1000.0, (epoch-1) * len(sdf_loader) + i)
                    writer.add_scalar('KL_loss_raw_1e-3', kl_loss_raw.item() * 1000.0, (epoch-1) * len(sdf_loader) + i)
                    
                writer.add_scalar('training_loss_1e-3', loss.item() * 1000.0, (epoch-1) * len(sdf_loader) + i)
                writer.add_scalar('loss_hand_1e-3', loss_hand_out * 1000.0, (epoch-1) * len(sdf_loader) + i)
                writer.add_scalar('loss_object_1e-3', loss_obj_out * 1000.0, (epoch-1) * len(sdf_loader) + i)
                writer.add_scalar('loss_classifier_1e-3', loss_ce_out * 1000.0, (epoch-1) * len(sdf_loader) + i)
                writer.add_scalar('loss_penetration_1e-3', pen_loss_out * 1000.0, (epoch-1) * len(sdf_loader) + i)
                writer.add_scalar('loss_contact_1e-3', contact_loss_out * 1000.0, (epoch-1) * len(sdf_loader) + i)

            optimizer_all.step()
        
        end = time.time()

        seconds_elapsed = end - start
        print("time used:", seconds_elapsed)

        for idx, schedule in enumerate(lr_schedules):
            writer.add_scalar('learning_rate_' + str(idx),
                schedule.get_learning_rate(epoch),
                epoch * len(sdf_loader)
            )

        recon_scale = 0.5 if not indep_obj_scale else 1.0
        
        if epoch in checkpoints and val_split_file:
            save_checkpoints(epoch)
            print("reconstruct mesh at {}".format(epoch))
            recon_st = time.time()
            reconstruct.reconstruct_training(experiment_directory, 
                val_split_file,
                input_type,
                encoder_input_source,
                encoderDecoder,
                epoch,
                specs,
                hand_branch,
                obj_branch,
                model_type=model_type,
                scale=recon_scale,
                cube_dim=128,
                fhb=is_fhb,
                dataset_name=dataset_name)
            print("- Reconstruction used {}".format(time.time()-recon_st))
            
            # chamfer_st = time.time()
            # object_type, chamfer_mean_list, chamfer_med_list, = evaluate.evaluate(
            #     experiment_directory,
            #     str(epoch),
            #     data_source,
            #     val_split_file,
            # )
            # print("calculate chamfer dist used {}".format(time.time()-chamfer_st))
            # print(" - Chamfer distance:")
            # for i, obj_type in enumerate(object_type):
            #     print("{}: mean: {:.5f}, med: {:.5f}".format(obj_type, chamfer_mean_list[i], chamfer_med_list[i]))
            #     writer.add_scalar(obj_type+'_val_chamfer_mean_x1000', 
            #         chamfer_mean_list[i] * 1000.0, epoch)
            #     writer.add_scalar(obj_type+'_val_chamfer_med_x1000', 
            #         chamfer_med_list[i] * 1000.0, epoch)
        

        if epoch % log_frequency == 0:
            save_latest(epoch)
            print("save at {}".format(epoch))
    

    # End of training
    if val_split_file:
        print("Final reconstruct mesh at {}".format(num_epochs))
        recon_st = time.time()
        reconstruct.reconstruct_training(experiment_directory, 
            val_split_file,
            input_type,
            encoder_input_source,
            encoderDecoder,
            num_epochs,
            specs,
            hand_branch,
            obj_branch,
            model_type=model_type,
            scale=recon_scale,
            cube_dim=256,
            fhb=is_fhb,
            dataset_name=dataset_name)
        print("- Final Reconstruction used {}".format(time.time()-recon_st))
        
        # chamfer_st = time.time()
        # object_type, chamfer_mean_list, chamfer_med_list, = evaluate.evaluate(
        #     experiment_directory,
        #     str(num_epochs),
        #     data_source,
        #     val_split_file,
        # )
        # print("calculate final chamfer dist used {}".format(time.time()-chamfer_st))
        # print(" - Chamfer distance:")
        # for i, obj_type in enumerate(object_type):
        #     print("{}: mean: {:.5f}, med: {:.5f}".format(obj_type, chamfer_mean_list[i], chamfer_med_list[i]))
        #     writer.add_scalar(obj_type+'_val_chamfer_mean_x1000', 
        #         chamfer_mean_list[i] * 1000.0, num_epochs)
        #     writer.add_scalar(obj_type+'_val_chamfer_med_x1000', 
        #         chamfer_med_list[i] * 1000.0, num_epochs)

    writer.close()
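
# Standalone sketch of the penetration and contact terms used in the training
# loop above, on dummy SDF predictions (the function name and the numbers are
# illustrative only, not part of the original code):
def _demo_additional_losses():
    sdf_hand = torch.tensor([[-0.010], [0.020], [0.001]])
    sdf_obj = torch.tensor([[-0.005], [0.030], [-0.001]])
    # penetration: penalize points whose hand and object SDFs sum to a negative
    # value, which indicates interpenetration
    pen_loss = torch.clamp(-(sdf_hand + sdf_obj), min=0).mean()
    # contact: a truncated quadratic that is small only where both SDFs are near
    # zero (surfaces in contact) and saturates at 1 beyond sigma
    sigma = 0.005
    contact_loss = torch.clamp((sdf_hand ** 2 + sdf_obj ** 2) / sigma ** 2, max=1).mean()
    print("penetration:", pen_loss.item(), "contact:", contact_loss.item())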
Example #4
def run_visualization(experiment_directory, checkpoint):

    specs = utils.misc.load_experiment_specifications(experiment_directory)

    data_source = Path(specs["DataSource"])
    train_split_file = specs["TrainSplit"]
    val_split_file = specs["ValSplit"]
    subsample = specs["SamplesPerScene"]
    scene_per_batch = specs["ScenesPerBatch"]
    latent_size = specs["LatentSize"]
    num_epochs = specs["NumEpochs"]
    num_data_loader_threads = get_spec_with_default(specs, "DataLoaderThreads", 8)
    nb_classes = get_spec_with_default(specs["NetworkSpecs"], "num_class", 6)
    log_frequency = get_spec_with_default(specs, "LogFrequency", 5)
    log_frequency_step = get_spec_with_default(specs, "LogFrequencyStep", 100)

    with open(train_split_file, "r") as f:
        train_split = json.load(f)

    with open(val_split_file, "r") as f:
        val_split = json.load(f)

    sdf_dataset = utils.data.SDFAssemblySamples(data_source, train_split, subsample)
    sdf_val_dataset = utils.data.SDFAssemblySamples(data_source, val_split, subsample)

    sdf_loader = data_utils.DataLoader(
        sdf_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True
    )
    sdf_val_loader = data_utils.DataLoader(
        sdf_val_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True
    )

    half_latent_size = int(latent_size/2)

    encoder_part1 = arch.ResnetPointnet(c_dim=half_latent_size, hidden_dim=256).to(device)
    encoder_part2 = arch.ResnetPointnet(c_dim=half_latent_size, hidden_dim=256).to(device)
    print("Point cloud encoder, each branch has latent size", half_latent_size)

    decoder = arch.ShapeAssemblyDecoder(
        latent_size, 
        **specs["NetworkSpecs"]
    )

    encoder_decoder = arch.ModelShapeAssemblyEncoderDecoderVAE(
        encoder_part1,
        encoder_part2,
        decoder,
        nb_classes,
        subsample
    )
    encoder_decoder = encoder_decoder.to(device)
    encoder_decoder.share_memory()

    model_epoch = utils.misc.load_model_parameters(
        experiment_directory, checkpoint, encoder_decoder
    )

    logging.info("Loaded weights!")

    total_validation_error = 0
    # Run visualization over the validation set
    for i, (
        part1,
        part2,
        gt_transformed_part1_points,
        gt_transform) in enumerate(sdf_val_loader):

        shape_assembly_visualization_function(part1, part2, gt_transform, encoder_decoder, subsample, scene_per_batch)
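
# Minimal usage sketch (the experiment path is a placeholder; "latest" resolves
# to latest.pth, as saved by the training loop in Example #2):
#
#     run_visualization("experiments/shape_assembly", "latest")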