Ejemplo n.º 1
0
    def run(self, configuration: Configuration) -> None:
        """Impute a feature matrix with a trained architecture and save the result.

        The architecture described by ``configuration.architecture`` is built,
        its states are restored from ``configuration.checkpoint`` (using the
        ``"best_architecture"`` entry when the checkpoint has one, otherwise
        ``"architecture"``), the model output replaces the positions flagged by
        the missing mask, and the post-processed result is written with
        ``np.save`` to ``configuration.output``.

        :param configuration: task configuration; reads ``seed`` (through
            ``get``), ``metadata``, ``architecture``, ``checkpoint``,
            ``imputation``, ``features``, ``missing_mask``, ``output`` and the
            optional ``scale_transform``.
        """
        seed_all(configuration.get("seed"))

        metadata = load_metadata(configuration.metadata)

        # build the model described by the architecture configuration
        architecture_configuration = load_configuration(configuration.architecture)
        self.validate_architecture_configuration(architecture_configuration)
        architecture = create_architecture(metadata, architecture_configuration)
        architecture.to_gpu_if_available()

        # restore the trained states, preferring the best ones when they were kept
        checkpoints = Checkpoints()
        checkpoint = checkpoints.load(configuration.checkpoint)
        states_key = "best_architecture" if "best_architecture" in checkpoint else "architecture"
        checkpoints.load_states(checkpoint[states_key], architecture)

        # pre-processing: the imputation component is mandatory for this task
        pre_processing = PreProcessing(create_component(architecture, metadata, configuration.imputation))

        # post-processing: the scale transform is optional
        scale_transform = (load_scale_transform(configuration.scale_transform)
                           if "scale_transform" in configuration else None)
        post_processing = PostProcessing(metadata, scale_transform)

        # load the inputs as float tensors
        features = to_gpu_if_available(torch.from_numpy(np.load(configuration.features)).float())
        missing_mask = to_gpu_if_available(torch.from_numpy(np.load(configuration.missing_mask)).float())

        # initial imputation followed by the model outputs
        batch = pre_processing.transform({"features": features, "missing_mask": missing_mask})
        generated = self.impute(configuration, metadata, architecture, batch)

        # take the model output where the mask is one and the original features elsewhere
        imputed = compose_with_mask(mask=missing_mask, differentiable=False, where_one=generated, where_zero=features)

        # post-process, bring back to numpy and save
        imputed = post_processing.transform(imputed)
        imputed = to_cpu_if_was_in_gpu(imputed)
        np.save(configuration.output, imputed.numpy())
Ejemplo n.º 2
0
 def train_discriminator_steps(self, configuration: Configuration, metadata: Metadata, architecture: Architecture,
                               batch_iterator: Iterator[Batch], pre_processing: PreProcessing) -> List[float]:
     """Run ``configuration.discriminator_steps`` discriminator training steps.

     Every step takes the next batch from ``batch_iterator``, passes it through
     ``pre_processing.transform`` and hands it to ``self.train_discriminator_step``.

     :return: the values returned by the individual steps, in order.
     :raises StopIteration: propagated from ``next`` when the iterator is
         exhausted; results of steps already finished in this call are lost.
     """
     # Keep this a list comprehension: inside a generator expression the
     # StopIteration raised by next() would be turned into a RuntimeError.
     return [
         self.train_discriminator_step(configuration, metadata, architecture,
                                       pre_processing.transform(next(batch_iterator)))
         for _ in range(configuration.discriminator_steps)
     ]
Ejemplo n.º 3
0
 def train_autoencoder_steps(self, configuration: Configuration,
                             architecture: Architecture,
                             batch_iterator: Iterator[Batch],
                             pre_processing: PreProcessing) -> List[float]:
     """Run ``configuration.autoencoder_steps`` autoencoder training steps.

     Every step takes the next batch from ``batch_iterator``, passes it through
     ``pre_processing.transform`` and trains on it with
     ``self.autoencoder_train_task.train_batch``.

     :return: the values returned by ``train_batch``, in order.
     :raises StopIteration: propagated from ``next`` when the iterator is
         exhausted; results of steps already finished in this call are lost.
     """
     # Keep this a list comprehension: inside a generator expression the
     # StopIteration raised by next() would be turned into a RuntimeError.
     return [
         self.autoencoder_train_task.train_batch(
             architecture, pre_processing.transform(next(batch_iterator)))
         for _ in range(configuration.autoencoder_steps)
     ]
Ejemplo n.º 4
0
    def train_epoch(self, configuration: Configuration, metadata: Metadata,
                    architecture: Architecture, datasets: Datasets,
                    pre_processing: PreProcessing,
                    post_processing: PostProcessing) -> Dict[str, float]:
        """Train the autoencoder, discriminator and generator for one epoch.

        The three models are switched to train mode and trained in rounds
        (autoencoder steps, then discriminator steps, then generator steps)
        until the training data iterator is exhausted. Afterwards the
        autoencoder is switched to eval mode and evaluated on the validation
        data.

        :param configuration: reads ``autoencoder_steps``,
            ``discriminator_steps`` and ``generator_steps``; also forwarded to
            ``iterate_datasets`` and the step methods.
        :param metadata: forwarded to the discriminator and generator steps.
        :param architecture: must expose ``autoencoder``, ``generator``,
            ``discriminator`` and ``arguments``.
        :param datasets: must contain ``train_features`` and ``val_features``;
            ``train_labels``/``val_labels`` are used when the architecture is
            conditional, and the missing masks are used when present.
        :param pre_processing: applied to every batch before it is used.
        :param post_processing: forwarded to the autoencoder validation.
        :return: mean losses keyed by ``"<model>_train_mean_loss"`` (only for
            models with a positive number of steps) plus
            ``"autoencoder_val_mean_loss"``.
        :raises ValueError: if neither the autoencoder nor the discriminator
            is configured to take at least one step.
        """
        # The training loop below only ends when a data-consuming step raises
        # StopIteration. The generator steps receive no data iterator, so
        # without autoencoder or discriminator steps the epoch would never end.
        if configuration.autoencoder_steps <= 0 and configuration.discriminator_steps <= 0:
            raise ValueError(
                "At least one of autoencoder_steps and discriminator_steps must be positive, "
                "otherwise no batches are consumed and the epoch never ends.")

        # train
        architecture.autoencoder.train()
        architecture.generator.train()
        architecture.discriminator.train()

        # prepare to accumulate losses per batch
        losses_by_batch = {
            "autoencoder": [],
            "generator": [],
            "discriminator": []
        }

        # basic data
        train_datasets = Datasets({"features": datasets.train_features})
        val_datasets = Datasets({"features": datasets.val_features})

        # conditional
        if "conditional" in architecture.arguments:
            train_datasets["labels"] = datasets.train_labels
            val_datasets["labels"] = datasets.val_labels

        # missing mask
        if "train_missing_mask" in datasets:
            train_datasets["missing_mask"] = datasets.train_missing_mask
        if "val_missing_mask" in datasets:
            val_datasets["missing_mask"] = datasets.val_missing_mask

        # an epoch will stop at any point if there are no more batches
        # it does not matter if there are models with remaining steps
        data_iterator = self.iterate_datasets(configuration, train_datasets)

        while True:
            try:
                losses_by_batch["autoencoder"].extend(
                    self.train_autoencoder_steps(configuration, architecture,
                                                 data_iterator,
                                                 pre_processing))

                losses_by_batch["discriminator"].extend(
                    self.train_discriminator_steps(configuration, metadata,
                                                   architecture, data_iterator,
                                                   pre_processing))

                losses_by_batch["generator"].extend(
                    self.train_generator_steps(configuration, metadata,
                                               architecture))
            except StopIteration:
                # the data ran out in the middle of a round: the epoch is over
                break

        # loss aggregation
        # NOTE: a model whose list is still empty here (the data ran out before
        # its first step) gets nan, because that is what np.mean returns for an
        # empty list.
        losses = {}

        if configuration.autoencoder_steps > 0:
            losses["autoencoder_train_mean_loss"] = np.mean(
                losses_by_batch["autoencoder"]).item()

        if configuration.discriminator_steps > 0:
            losses["discriminator_train_mean_loss"] = np.mean(
                losses_by_batch["discriminator"]).item()

        if configuration.generator_steps > 0:
            losses["generator_train_mean_loss"] = np.mean(
                losses_by_batch["generator"]).item()

        # validation (only the autoencoder is evaluated)
        architecture.autoencoder.eval()

        autoencoder_val_losses_by_batch = []

        for batch in self.iterate_datasets(configuration, val_datasets):
            batch = pre_processing.transform(batch)
            autoencoder_val_losses_by_batch.append(
                self.autoencoder_train_task.val_batch(architecture, batch,
                                                      post_processing))

        losses["autoencoder_val_mean_loss"] = np.mean(
            autoencoder_val_losses_by_batch).item()

        return losses
Ejemplo n.º 5
0
    def run(self, configuration: Configuration) -> None:
        """Train the architecture for ``configuration.epochs`` epochs with checkpointing.

        Loads every dataset listed in ``configuration.data``, builds the
        architecture, optionally resumes from a checkpoint, and then calls
        ``train_epoch`` once per remaining epoch, logging the returned metrics
        and updating the output checkpoint. The logger is closed even when
        set-up or training fails.

        Starting point, in order of priority:

        1. the output checkpoint, when ``checkpoints.continue_from_output`` is
           set and that checkpoint exists;
        2. the input checkpoint, when ``checkpoints.input`` is configured
           (optionally resetting its epoch counter and/or promoting its best
           architecture to be the current one);
        3. a freshly initialized architecture at epoch 0.

        :param configuration: task configuration; reads ``seed`` (through
            ``get``), ``data``, ``metadata``, ``architecture``, ``checkpoints``,
            ``logs``, ``epochs`` and the optional ``imputation``,
            ``scale_transform`` and ``keep_checkpoint_by_metric``.
        """
        seed_all(configuration.get("seed"))

        # load every dataset as a float tensor
        datasets = Datasets()
        for dataset_name, dataset_path in configuration.data.items():
            datasets[dataset_name] = to_gpu_if_available(torch.from_numpy(np.load(dataset_path)).float())

        metadata = load_metadata(configuration.metadata)

        architecture_configuration = load_configuration(configuration.architecture)
        self.validate_architecture_configuration(architecture_configuration)
        architecture = create_architecture(metadata, architecture_configuration)
        architecture.to_gpu_if_available()

        create_parent_directories_if_needed(configuration.checkpoints.output)
        checkpoints = Checkpoints()

        # no input checkpoint by default
        checkpoint = None

        # continue from an output checkpoint (has priority over input checkpoint)
        if configuration.checkpoints.get("continue_from_output", default=False) \
                and checkpoints.exists(configuration.checkpoints.output):
            checkpoint = checkpoints.load(configuration.checkpoints.output)
        # continue from an input checkpoint
        elif "input" in configuration.checkpoints:
            checkpoint = checkpoints.load(configuration.checkpoints.input)
            if configuration.checkpoints.get("ignore_input_epochs", default=False):
                checkpoint["epoch"] = 0
            if configuration.checkpoints.get("use_best_input", default=False):
                # start from the best states and forget the best-checkpoint bookkeeping
                checkpoint["architecture"] = checkpoint.pop("best_architecture")
                checkpoint.pop("best_epoch")
                checkpoint.pop("best_metric")

        # if there is no starting checkpoint then initialize
        if checkpoint is None:
            architecture.initialize()

            checkpoint = {
                "architecture": checkpoints.extract_states(architecture),
                "epoch": 0
            }
        # if there is a starting checkpoint then load it
        else:
            checkpoints.load_states(checkpoint["architecture"], architecture)

        log_path = create_parent_directories_if_needed(configuration.logs)
        # NOTE(review): the third argument is true when resuming from a previous
        # epoch; presumably it makes the logger append to the existing log - confirm.
        logger = TrainLogger(self.logger, log_path, checkpoint["epoch"] > 0)

        # from here on the logger must be closed whatever happens
        try:
            # pre-processing
            if "imputation" in configuration:
                imputation = create_component(architecture, metadata, configuration.imputation)
            else:
                imputation = None

            pre_processing = PreProcessing(imputation)

            # post-processing
            if "scale_transform" in configuration:
                scale_transform = load_scale_transform(configuration.scale_transform)
            else:
                scale_transform = None

            post_processing = PostProcessing(metadata, scale_transform)

            # epochs are 1-based and resume right after the checkpoint epoch
            for epoch in range(checkpoint["epoch"] + 1, configuration.epochs + 1):
                # train discriminator and generator
                logger.start_timer()

                metrics = self.train_epoch(configuration, metadata, architecture, datasets, pre_processing,
                                           post_processing)

                for metric_name, metric_value in metrics.items():
                    logger.log(epoch, configuration.epochs, metric_name, metric_value)

                # update the checkpoint
                checkpoint["architecture"] = checkpoints.extract_states(architecture)
                checkpoint["epoch"] = epoch

                # if the best architecture parameters should be kept
                if "keep_checkpoint_by_metric" in configuration:
                    # get the metric used to compare checkpoints
                    checkpoint_metric = metrics[configuration.keep_checkpoint_by_metric]

                    # check if this is the best checkpoint (or the first); lower is better
                    if "best_metric" not in checkpoint or checkpoint_metric < checkpoint["best_metric"]:
                        # NOTE(review): this stores the same object as "architecture", so it
                        # relies on extract_states returning an independent snapshot - confirm.
                        checkpoint["best_architecture"] = checkpoint["architecture"]
                        checkpoint["best_epoch"] = epoch
                        checkpoint["best_metric"] = checkpoint_metric

                # save checkpoint
                checkpoints.delayed_save(checkpoint, configuration.checkpoints.output,
                                         configuration.checkpoints.max_delay)

            # force save of last checkpoint
            checkpoints.save(checkpoint, configuration.checkpoints.output)
        finally:
            # finish
            logger.close()