예제 #1
0
    def run(self, configuration: Configuration) -> None:
        """Impute the missing entries of a feature file with a trained architecture.

        Builds the architecture described by ``configuration.architecture`` and
        restores its weights from ``configuration.checkpoint``, preferring the
        ``"best_architecture"`` states when the checkpoint has them. The features
        and missing mask are loaded from ``.npy`` files and pre-processed with the
        configured initial imputation, then passed to ``self.impute``. The model
        output is combined with the original features through ``compose_with_mask``
        (model output where the mask is one, original features where it is zero).
        The post-processed result is written with ``np.save`` to
        ``configuration.output``.

        Args:
            configuration: task configuration. Reads ``metadata``,
                ``architecture``, ``checkpoint``, ``imputation``, ``features``,
                ``missing_mask`` and ``output``; optionally ``seed`` and
                ``scale_transform``.
        """
        seed_all(configuration.get("seed"))

        metadata = load_metadata(configuration.metadata)

        architecture_configuration = load_configuration(configuration.architecture)
        self.validate_architecture_configuration(architecture_configuration)
        architecture = create_architecture(metadata, architecture_configuration)
        architecture.to_gpu_if_available()

        # restore the trained weights, preferring the best states when the checkpoint kept them
        checkpoints = Checkpoints()
        checkpoint = checkpoints.load(configuration.checkpoint)
        if "best_architecture" in checkpoint:
            checkpoints.load_states(checkpoint["best_architecture"], architecture)
        else:
            checkpoints.load_states(checkpoint["architecture"], architecture)

        # pre-processing: applies the configured initial imputation component
        imputation = create_component(architecture, metadata, configuration.imputation)

        pre_processing = PreProcessing(imputation)

        # post-processing, with an optional scale transform
        if "scale_transform" in configuration:
            scale_transform = load_scale_transform(configuration.scale_transform)
        else:
            scale_transform = None

        post_processing = PostProcessing(metadata, scale_transform)

        # load the features and the missing mask as float tensors
        features = to_gpu_if_available(torch.from_numpy(np.load(configuration.features)).float())
        missing_mask = to_gpu_if_available(torch.from_numpy(np.load(configuration.missing_mask)).float())

        # initial imputation
        batch = pre_processing.transform({"features": features, "missing_mask": missing_mask})

        # generate the model outputs
        output = self.impute(configuration, metadata, architecture, batch)

        # keep the model output where the mask is one and the original features where it is zero
        output = compose_with_mask(mask=missing_mask, differentiable=False, where_one=output, where_zero=features)

        # post-process
        output = post_processing.transform(output)

        # save the imputation; detach first because Tensor.numpy() raises on tensors that
        # require grad, and self.impute is not wrapped in torch.no_grad() here
        output = to_cpu_if_was_in_gpu(output)
        output = output.detach().numpy()
        np.save(configuration.output, output)
예제 #2
0
    def run(self, configuration: Configuration) -> None:
        """Encode the configured features with the trained autoencoder.

        Restores the architecture from ``configuration.checkpoint`` (best
        states first, when present) and runs ``architecture.autoencoder.encode``
        on the features from ``configuration.features``. Labels from
        ``configuration.labels`` are passed as the condition when that key is
        configured. Gradient tracking is disabled during encoding. The
        ``"code"`` entry of the result is saved with ``np.save`` to
        ``configuration.output``.

        Args:
            configuration: task configuration. Reads ``metadata``,
                ``architecture``, ``checkpoint``, ``features`` and ``output``;
                optionally ``seed`` and ``labels``.
        """
        seed_all(configuration.get("seed"))

        metadata = load_metadata(configuration.metadata)

        arch_config = load_configuration(configuration.architecture)
        self.validate_architecture_configuration(arch_config)
        architecture = create_architecture(metadata, arch_config)
        architecture.to_gpu_if_available()

        # restore the weights, favouring the best states kept during training
        checkpoints = Checkpoints()
        checkpoint = checkpoints.load(configuration.checkpoint)
        states_key = ("best_architecture" if "best_architecture" in checkpoint
                      else "architecture")
        checkpoints.load_states(checkpoint[states_key], architecture)

        def load_float_tensor(path):
            # .npy file -> float tensor, moved to the GPU when one is available
            return to_gpu_if_available(torch.from_numpy(np.load(path)).float())

        features = load_float_tensor(configuration.features)
        # the labels act as an optional condition for the encoder
        condition = (load_float_tensor(configuration.labels)
                     if "labels" in configuration else None)

        with torch.no_grad():
            encoded = architecture.autoencoder.encode(
                features, condition=condition)["code"]

        np.save(configuration.output, to_cpu_if_was_in_gpu(encoded).numpy())
예제 #3
0
 def run(self, configuration: Configuration) -> None:
     """Log the parameter size of the configured architecture.

     Builds the architecture from ``configuration.metadata`` and
     ``configuration.architecture`` without loading any checkpoint. It then
     logs ``"<configuration.name>: <size>"`` through ``self.logger.info``,
     where size is the value returned by ``compute_parameter_size``.
     """
     meta = load_metadata(configuration.metadata)
     arch_config = load_configuration(configuration.architecture)
     model = create_architecture(meta, arch_config)
     parameter_size = compute_parameter_size(model)
     message = "{}: {:d}".format(configuration.name, parameter_size)
     self.logger.info(message)
예제 #4
0
    def run(self, configuration: Configuration) -> None:
        """Generate ``configuration.sample_size`` samples and save them.

        Restores the architecture from ``configuration.checkpoint`` (best
        states first, when present). Samples are drawn batch by batch through a
        sample strategy: the one named by ``configuration.strategy.factory``
        (built with ``configuration.strategy.arguments``), or
        ``DefaultSampleStrategy`` when no strategy is configured. Batches are
        truncated so exactly ``sample_size`` rows are kept. They are saved with
        ``np.save`` to ``configuration.output``. An empty list is saved when
        no samples were requested.

        Args:
            configuration: task configuration. Reads ``metadata``,
                ``architecture``, ``checkpoint``, ``sample_size`` and
                ``output``; optionally ``seed``, ``scale_transform`` and
                ``strategy``.

        Raises:
            Exception: if ``strategy`` is configured without a ``factory``
                entry, or with a factory name missing from
                ``strategy_class_by_name``.
        """
        seed_all(configuration.get("seed"))

        metadata = load_metadata(configuration.metadata)

        if "scale_transform" in configuration:
            scale_transform = load_scale_transform(
                configuration.scale_transform)
        else:
            scale_transform = None

        post_processing = PostProcessing(metadata, scale_transform)

        architecture_configuration = load_configuration(
            configuration.architecture)
        self.validate_architecture_configuration(architecture_configuration)
        architecture = create_architecture(metadata,
                                           architecture_configuration)
        architecture.to_gpu_if_available()

        # restore the weights, preferring the best states when present
        checkpoints = Checkpoints()
        checkpoint = checkpoints.load(configuration.checkpoint)
        if "best_architecture" in checkpoint:
            checkpoints.load_states(checkpoint["best_architecture"],
                                    architecture)
        else:
            checkpoints.load_states(checkpoint["architecture"], architecture)

        # create the strategy if defined
        if "strategy" in configuration:
            # validate strategy name is present
            if "factory" not in configuration.strategy:
                raise Exception(
                    "Missing factory name while creating sample strategy.")

            # validate strategy name
            strategy_name = configuration.strategy.factory
            if strategy_name not in strategy_class_by_name:
                raise Exception(
                    "Invalid factory name '{}' while creating sample strategy."
                    .format(strategy_name))

            # create the strategy
            strategy_class = strategy_class_by_name[strategy_name]
            strategy = strategy_class(**configuration.strategy.get(
                "arguments", default={}, transform_default=False))

        # use the default strategy
        else:
            strategy = DefaultSampleStrategy()

        # this is only to pass less parameters back and forth
        sampler = Sampler(self, configuration, metadata, architecture,
                          post_processing)

        # Collect the batches and concatenate once at the end. Concatenating
        # inside the loop would re-copy every previously collected sample for
        # each new batch.
        batches = []
        collected = 0
        # NOTE(review): this loops forever if the strategy keeps returning
        # empty batches (e.g. a rejection strategy that rejects everything)
        while collected < configuration.sample_size:
            # do not calculate gradients
            with torch.no_grad():
                # sample:
                # the task delegates to the strategy and passes the sampler object to avoid passing even more parameters
                #   the strategy may prepare additional sampling arguments (e.g. condition)
                #   the strategy delegates to the sampler object
                #     the sampler object delegates back to the task adding parameters that it was keeping
                #       the task child class does the actual sampling depending on the model
                #     the sampler object applies post-processing
                #   the strategy may apply filtering to the samples (e.g. rejection)
                # the task finally gets the sample
                batch_samples = strategy.generate_sample(
                    sampler, configuration, metadata)

            # transform back the samples
            batch_samples = to_cpu_if_was_in_gpu(batch_samples)
            batch_samples = batch_samples.numpy()

            # if the batch is not empty
            if len(batch_samples) > 0:
                # keep only as many rows as are still missing
                remaining = configuration.sample_size - collected
                batch_samples = batch_samples[:remaining, :]
                batches.append(batch_samples)
                collected += len(batch_samples)

        if batches:
            samples = np.concatenate(batches, axis=0)
        else:
            # nothing was requested: save an empty list, as before
            samples = []

        # save the samples
        np.save(configuration.output, samples)
예제 #5
0
    def run(self, configuration: Configuration) -> None:
        """Train the configured architecture, checkpointing after every epoch.

        Loads every dataset listed in ``configuration.data`` as a float tensor
        and builds the architecture. Training resumes from a starting
        checkpoint when one is available:

        - the output checkpoint, when ``checkpoints.continue_from_output`` is
          set and that file exists;
        - otherwise ``checkpoints.input``, when configured.

        Without a starting checkpoint the architecture is initialized from
        scratch. Each epoch calls ``self.train_epoch``, logs the returned
        metrics and hands the new states to ``Checkpoints.delayed_save``. The
        last checkpoint is always written with ``Checkpoints.save``.

        When ``keep_checkpoint_by_metric`` is configured, the states of the
        epoch with the strictly lowest value of that metric are also kept
        under ``best_architecture`` (with ``best_epoch`` and ``best_metric``).

        The train logger is closed even when training or saving raises.

        Args:
            configuration: task configuration. Reads ``data``, ``metadata``,
                ``architecture``, ``checkpoints`` (``output``, ``max_delay`` and
                optionally ``input``, ``continue_from_output``,
                ``ignore_input_epochs``, ``use_best_input``), ``logs`` and
                ``epochs``; optionally ``seed``, ``imputation``,
                ``scale_transform`` and ``keep_checkpoint_by_metric``.

        Raises:
            KeyError: presumably (assuming dict-like checkpoints and metrics)
                if ``use_best_input`` is set but the input checkpoint has no
                best states, or if ``train_epoch`` does not return the metric
                named by ``keep_checkpoint_by_metric``.
        """
        seed_all(configuration.get("seed"))

        datasets = Datasets()
        for dataset_name, dataset_path in configuration.data.items():
            datasets[dataset_name] = to_gpu_if_available(torch.from_numpy(np.load(dataset_path)).float())

        metadata = load_metadata(configuration.metadata)

        architecture_configuration = load_configuration(configuration.architecture)
        self.validate_architecture_configuration(architecture_configuration)
        architecture = create_architecture(metadata, architecture_configuration)
        architecture.to_gpu_if_available()

        create_parent_directories_if_needed(configuration.checkpoints.output)
        checkpoints = Checkpoints()

        # no input checkpoint by default
        checkpoint = None

        # continue from an output checkpoint (has priority over input checkpoint)
        if configuration.checkpoints.get("continue_from_output", default=False) \
                and checkpoints.exists(configuration.checkpoints.output):
            checkpoint = checkpoints.load(configuration.checkpoints.output)
        # continue from an input checkpoint
        elif "input" in configuration.checkpoints:
            checkpoint = checkpoints.load(configuration.checkpoints.input)
            if configuration.checkpoints.get("ignore_input_epochs", default=False):
                checkpoint["epoch"] = 0
            if configuration.checkpoints.get("use_best_input", default=False):
                # train from the best states and forget the input's best-tracking entries
                checkpoint["architecture"] = checkpoint.pop("best_architecture")
                checkpoint.pop("best_epoch")
                checkpoint.pop("best_metric")

        # if there is no starting checkpoint then initialize
        if checkpoint is None:
            architecture.initialize()

            checkpoint = {
                "architecture": checkpoints.extract_states(architecture),
                "epoch": 0
            }
        # if there is a starting checkpoint then load it
        else:
            checkpoints.load_states(checkpoint["architecture"], architecture)

        log_path = create_parent_directories_if_needed(configuration.logs)
        # NOTE(review): the third argument looks like an "append to existing log" flag
        # for resumed runs -- confirm against TrainLogger
        logger = TrainLogger(self.logger, log_path, checkpoint["epoch"] > 0)

        # TrainLogger is given a log path and must be closed; do it even if training fails
        try:
            # pre-processing
            if "imputation" in configuration:
                imputation = create_component(architecture, metadata, configuration.imputation)
            else:
                imputation = None

            pre_processing = PreProcessing(imputation)

            # post-processing
            if "scale_transform" in configuration:
                scale_transform = load_scale_transform(configuration.scale_transform)
            else:
                scale_transform = None

            post_processing = PostProcessing(metadata, scale_transform)

            for epoch in range(checkpoint["epoch"] + 1, configuration.epochs + 1):
                # train discriminator and generator
                logger.start_timer()

                metrics = self.train_epoch(configuration, metadata, architecture, datasets, pre_processing,
                                           post_processing)

                for metric_name, metric_value in metrics.items():
                    logger.log(epoch, configuration.epochs, metric_name, metric_value)

                # update the checkpoint
                checkpoint["architecture"] = checkpoints.extract_states(architecture)
                checkpoint["epoch"] = epoch

                # if the best architecture parameters should be kept
                if "keep_checkpoint_by_metric" in configuration:
                    # get the metric used to compare checkpoints
                    checkpoint_metric = metrics[configuration.keep_checkpoint_by_metric]

                    # lower is better: keep the first epoch or any strict improvement
                    if "best_metric" not in checkpoint or checkpoint_metric < checkpoint["best_metric"]:
                        checkpoint["best_architecture"] = checkpoint["architecture"]
                        checkpoint["best_epoch"] = epoch
                        checkpoint["best_metric"] = checkpoint_metric

                # save checkpoint (may be deferred up to max_delay)
                checkpoints.delayed_save(checkpoint, configuration.checkpoints.output,
                                         configuration.checkpoints.max_delay)

            # force save of last checkpoint
            checkpoints.save(checkpoint, configuration.checkpoints.output)
        finally:
            # finish
            logger.close()