Example #1
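DDFFEval.evaluate: builds the test-time transforms (tensor conversion, ground-truth clipping to [0.0202, 0.2825], padding to a multiple of 32, ImageNet normalization), wraps the H5 test set in a single-sample DataLoader, and delegates to the parent class's evaluate.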
 def evaluate(self,
              filename_testset,
              stack_key="stack_val",
              disp_key="disp_val",
              image_size=(383, 552)):
     #Pad height and width up to the next multiple of 32 (32 = 2**num_poolings, i.e. 5 pooling stages)
     test_pad_size = (np.ceil(image_size[0] / 32) * 32,
                      np.ceil(image_size[1] / 32) * 32)
     #Create test set transforms
     transform_test = [
         FocalStackDDFFH5Reader.FocalStackDDFFH5Reader.ToTensor(),
         FocalStackDDFFH5Reader.FocalStackDDFFH5Reader.ClipGroundTruth(
             0.0202, 0.2825),
         FocalStackDDFFH5Reader.FocalStackDDFFH5Reader.PadSamples(
             test_pad_size),
         FocalStackDDFFH5Reader.FocalStackDDFFH5Reader.Normalize(
             mean_input=[0.485, 0.456, 0.406],
             std_input=[0.229, 0.224, 0.225])
     ]
     transform_test = torchvision.transforms.Compose(transform_test)
     #Create dataloader
     datareader = FocalStackDDFFH5Reader.FocalStackDDFFH5Reader(
         filename_testset,
         transform=transform_test,
         stack_key=stack_key,
         disp_key=disp_key)
     dataloader = DataLoader(datareader,
                             batch_size=1,
                             shuffle=False,
                             num_workers=0)
     return super(DDFFEval, self).evaluate(dataloader)
Example #2
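DDFFTFLearnEval.evaluate: the same pipeline without the ground-truth clipping step; normalization is applied only when self.norm_mean and self.norm_std are set.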
 def evaluate(self,
              filename_testset,
              stack_key="stack_val",
              disp_key="disp_val",
              image_size=(383, 552)):
     #Pad height and width up to the next multiple of 32 (32 = 2**num_poolings, i.e. 5 pooling stages)
     test_pad_size = (np.ceil(image_size[0] / 32) * 32,
                      np.ceil(image_size[1] / 32) * 32)
     #Create test set transforms
     transform_test = [
         FocalStackDDFFH5Reader.FocalStackDDFFH5Reader.ToTensor(),
         FocalStackDDFFH5Reader.FocalStackDDFFH5Reader.PadSamples(
             test_pad_size)
     ]
     if self.norm_mean is not None and self.norm_std is not None:
         transform_test += [
             FocalStackDDFFH5Reader.FocalStackDDFFH5Reader.Normalize(
                 mean_input=self.norm_mean, std_input=self.norm_std)
         ]
     transform_test = torchvision.transforms.Compose(transform_test)
     #Create dataloader
     datareader = FocalStackDDFFH5Reader.FocalStackDDFFH5Reader(
         filename_testset,
         transform=transform_test,
         stack_key=stack_key,
         disp_key=disp_key)
     dataloader = DataLoader(datareader,
                             batch_size=1,
                             shuffle=False,
                             num_workers=0)
     return super(DDFFTFLearnEval, self).evaluate(dataloader)
Example #3
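A driver script that evaluates a PyTorch checkpoint: it opens the trainval H5 file once to read the focal-stack size, builds a DDFFEval from the checkpoint, and prints the returned metrics. The DDFFEval import path below is inferred from the package layout of the other examples.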
#Imports inferred from the package layout used in the other examples
import ddff.dataproviders.datareaders.FocalStackDDFFH5Reader as FocalStackDDFFH5Reader
import ddff.metricseval.DDFFEval as DDFFEval

def main():
    #Set parameters
    image_size = (383, 552)
    filename_testset = "../../ddff-dataset-trainval.h5"
    checkpoint_file = "checkpoints/ddff_cc3_checkpoint_1.pt"

    #Create validation reader
    tmp_datareader = FocalStackDDFFH5Reader.FocalStackDDFFH5Reader(filename_testset, transform=None, stack_key="stack_val", disp_key="disp_val")

    #Create PSPDDFF evaluator
    evaluator = DDFFEval.DDFFEval(checkpoint_file, focal_stack_size=tmp_datareader.get_stack_size())
    #Evaluate
    metrics = evaluator.evaluate(filename_testset, image_size=image_size)
    print(metrics)
Example #4
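The equivalent driver for a TFLearn .npz snapshot, using DDFFTFLearnEval with input normalization disabled.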
import ddff.dataproviders.datareaders.FocalStackDDFFH5Reader as FocalStackDDFFH5Reader
import ddff.metricseval.DDFFTFLearnEval as DDFFTFLearnEval

if __name__ == "__main__":
    #Set parameters
    image_size = (383, 552)
    filename_testset = "ddff-dataset-trainval.h5"
    checkpoint_file = "ddffnet-cc3-snapshot-121256.npz"
    stack_key = "stack_val"
    disp_key = "disp_val"

    #Create validation reader
    tmp_datareader = FocalStackDDFFH5Reader.FocalStackDDFFH5Reader(filename_testset, transform=None, stack_key=stack_key, disp_key=disp_key)

    #Create PSPDDFF evaluator
    evaluator = DDFFTFLearnEval.DDFFTFLearnEval(checkpoint_file, focal_stack_size=tmp_datareader.get_stack_size(), norm_mean=None, norm_std=None)
    #Evaluate
    metrics = evaluator.evaluate(filename_testset, stack_key=stack_key, disp_key=disp_key, image_size=image_size)
    print(metrics)
Example #5
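The from_h5_data factory: it builds training and validation transforms, H5 readers, and data loaders from a single file, constructs a trainer instance, runs training with optional checkpointing and gradient clipping, and returns the fitted instance.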
    def from_h5_data(cls,
                     root_dir,
                     learning_rate=0.001,
                     cc1_enabled=False,
                     cc2_enabled=False,
                     cc3_enabled=True,
                     cc4_enabled=False,
                     cc5_enabled=False,
                     training_crop_size=None,
                     validation_crop_size=None,
                     pretrained='no_bn',
                     normalize_mean=[0.485, 0.456, 0.406],
                     normalize_std=[0.229, 0.224, 0.225],
                     scheduler_step_size=4,
                     scheduler_gama=0.9,
                     max_gradient=5.0,
                     deterministic=False,
                     optimizer='sgd',
                     normalize_loss=False,
                     epochs=20,
                     batch_size=2,
                     num_workers=4,
                     checkpoint_file=None,
                     checkpoint_frequency=50):
        #Create data loaders
        transform_train = cls.__create_preprocessing(
            cls,
            crop_size=training_crop_size,
            mean=normalize_mean,
            std=normalize_std)
        transform_validation = cls.__create_preprocessing(
            cls,
            crop_size=validation_crop_size,
            mean=normalize_mean,
            std=normalize_std)
        #Create h5 reader
        dataset_train = FocalStackDDFFH5Reader.FocalStackDDFFH5Reader(
            root_dir,
            transform=transform_train,
            stack_key="stack_train",
            disp_key="disp_train")
        dataset_validation = FocalStackDDFFH5Reader.FocalStackDDFFH5Reader(
            root_dir,
            transform=transform_validation,
            stack_key="stack_val",
            disp_key="disp_val")
        #Create data loader
        dataloader_train = DataLoader(dataset_train,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers)
        dataloader_validation = DataLoader(dataset_validation,
                                           batch_size=1,
                                           shuffle=True,
                                           num_workers=0)
        #Call constructor
        instance = cls(dataset_train.get_stack_size(),
                       learning_rate=learning_rate,
                       cc1_enabled=cc1_enabled,
                       cc2_enabled=cc2_enabled,
                       cc3_enabled=cc3_enabled,
                       cc4_enabled=cc4_enabled,
                       cc5_enabled=cc5_enabled,
                       pretrained=pretrained,
                       scheduler_step_size=scheduler_step_size,
                       scheduler_gama=scheduler_gama,
                       deterministic=deterministic,
                       optimizer=optimizer,
                       normalize_loss=normalize_loss)

        #Save instance variables
        instance.dataloader_validation = dataloader_validation

        #Fit instance
        epoch_losses = instance.train(
            dataloader_train,
            epochs,
            checkpoint_file=checkpoint_file,
            checkpoint_frequency=checkpoint_frequency,
            max_gradient=max_gradient)
        print("Losses per epoch: " + str(epoch_losses))

        return instance
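
For context, a minimal sketch of how this factory might be called, in the style of the driver scripts above. The ddff.trainers.DDFFTrainer module path, the DDFFTrainer class name, and the file paths are assumptions, not taken from the examples; the keyword arguments match the from_h5_data signature shown above.

import ddff.trainers.DDFFTrainer as DDFFTrainer  #module path and class name are assumptions

if __name__ == "__main__":
    #Assumed H5 file containing the stack_train/disp_train and stack_val/disp_val keys used above
    filename_trainval = "ddff-dataset-trainval.h5"

    #Build readers and loaders, construct the trainer, and run training in one call
    trainer = DDFFTrainer.DDFFTrainer.from_h5_data(filename_trainval,
                                                   epochs=20,
                                                   batch_size=2,
                                                   checkpoint_file="checkpoints/ddff_cc3_checkpoint.pt",
                                                   checkpoint_frequency=50)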