Code example #1
def plot3d():
    """Create a GIF from a 3D image in the training set, one per channel."""
    for channel in ["ch1", "ch2", "ch3", "ch4", "ch6"]:
        config = default_config()
        config.features = [Feature.image]
        config.channels = [channel]
        config.crop_size = (300, 300)
        config.time_interval_min = 30
        config.num_images = 10

        train_dataset, _, _ = load_data(config=config)
        images = _first_image(train_dataset)

        image_names = []
        for i, image in enumerate(images[0]):
            plt.cla()
            plt.clf()
            plt.axis("off")
            plt.tight_layout()

            image2d = image[:, :, 0]
            plt.imshow(image2d, cmap="gray")

            name = f"assets/{i}.png"
            plt.savefig(name)

            image_names.append(name)
        make_gif(f"assets/image-3d-{channel}.gif", image_names)
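The make_gif helper is not defined in this excerpt. A minimal sketch of what it might look like, assuming the imageio package is available and that the helper only stitches the saved PNG frames into an animated GIF (the 0.5 s frame duration is an assumption):

import imageio.v2 as imageio


def make_gif(gif_name, image_names):
    # Read back every frame saved by plot3d and write them out as one GIF.
    frames = [imageio.imread(name) for name in image_names]
    imageio.mimsave(gif_name, frames, duration=0.5)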
Code example #2
    def train(
        self,
        optimizer: tf.keras.optimizers.Optimizer,
        epochs=25,
        enable_checkpoint=True,
        cache_file=None,
        checkpoint=None,
    ):
        """Performs minibatch training of the model."""
        config = self.model.config()
        logger.info(
            f"Starting training\n"
            f" - Model: {self.model.title}\n"
            f" - Config: {config}"
        )

        train_set, valid_set, _ = load_data(
            config=config,
            skip_non_cached=self.skip_non_cached,
        )

        logger.info("Apply Preprocessing")
        train_set = self.model.preprocess(train_set)
        valid_set = self.model.preprocess(valid_set)

        if cache_file is not None:
            train_set = train_set.cache(f"{cache_file}-train")
            valid_set = valid_set.cache(f"{cache_file}-valid")

        logger.info("Fitting model.")

        init_epoch = 0
        if checkpoint is not None:
            # Checkpoints are numbered by epoch, so resume from the next one.
            init_epoch = int(checkpoint) + 1
            # Load the model weights saved at that checkpoint.
            self.model.load(checkpoint)
            # Restore the training history recorded up to that checkpoint.
            self.history = History.load(f"{self.model.title}-{checkpoint}")

        for epoch in range(init_epoch, epochs):
            logger.info("Training...")

            for i, data in enumerate(train_set.batch(self.batch_size)):
                inputs = data[:-1]
                targets = data[-1]

                self._train_step(optimizer, inputs, targets, i + 1)

            logger.info("Evaluating validation loss")
            self._evaluate("valid", valid_set, self.batch_size)

            logger.info("Checkpointing...")
            self.model.save(str(epoch))

            self._update_progress(epoch)
            self.history.save(f"{self.model.title}-{epoch}")

        logger.info("Done.")
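Checkpoints are written under the epoch number as a string, so resuming a run only needs that number. A hypothetical usage sketch; the Trainer wrapper class, its constructor arguments, the model instance, and the optimizer settings are assumptions, not taken from the source:

import tensorflow as tf

# Hypothetical wrapper exposing the train()/test() methods shown here;
# "model" is assumed to be a previously built model instance.
trainer = Trainer(model=model, batch_size=64)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Fresh run for 25 epochs.
trainer.train(optimizer, epochs=25)

# Resume from the checkpoint saved at the end of epoch 9; training continues
# at epoch 10 because init_epoch = int(checkpoint) + 1.
trainer.train(optimizer, epochs=25, checkpoint="9")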
Code example #3
    def test(self, checkpoint: str):
        """Test a trained model on the test set.

        A checkpointed epoch must be specified.
        """
        config = self.model.config()
        self.model.load(checkpoint)

        _, _, test_set = load_data(config=config,
                                   skip_non_cached=self.skip_non_cached)
        test_set = self.model.preprocess(test_set)

        self._evaluate("test", test_set, self.batch_size)
        self.history.save(f"{self.model.title}-{checkpoint}-test-set")
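Evaluating on the held-out test set reuses the same epoch-based checkpoint string. Continuing the hypothetical trainer instance from the previous sketch:

# Load the weights saved at epoch 9 and report metrics on the test set.
trainer.test("9")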
Code example #4
def run(enable_tf_caching=False, skip_non_cached=False):
    """Performs a dry run with the data generators."""
    logger.info("Dry Run.")
    # Only test the generators, for debugging weird behavior and corner cases.
    train_generator, valid_generator, test_generator = load_data(
        enable_tf_caching=enable_tf_caching, skip_non_cached=skip_non_cached
    )
    # Just make sure that we can get a single sample out of the dry run.
    for sample in train_generator:
        print(sample)
        break
Code example #5
def plot_comparison(encoder_instance: str, seq2seq_instance: str):
    """Show original and generated future images in a grid."""
    encoder = Encoder()
    encoder.load(encoder_instance)
    model = Gru(encoder)
    model.load(seq2seq_instance)
    config = model.config()
    config.num_images = 6
    config.skip_missing_past_images = True

    decoder = Decoder(len(config.channels))
    decoder.load(encoder_instance)

    valid_dataset, _, _ = load_data(config=config)
    images_originals = _first_images(valid_dataset, 6)

    image_pred = model.predict_next_images(images_originals[0, :3], 3)

    before = model.scaling_image.normalize(images_originals[0, :3])
    before = encoder(before, training=False)
    before = decoder(before, training=False)

    images_originals = model.scaling_image.normalize(images_originals[0, 3:])
    images_originals = encoder(images_originals, training=False)
    images_originals = decoder(images_originals, training=False)

    before = model.scaling_image.original(before)
    images_originals = model.scaling_image.original(images_originals)
    image_pred = model.scaling_image.original(np.array(image_pred))

    generateds = []
    originals = []
    befores = []

    for i in range(3):
        generated = image_pred[i]
        original = images_originals[i]
        bef = before[i]
        channel = 0

        originals.append(original[:, :, channel])
        generateds.append(generated[:, :, channel])
        befores.append(bef[:, :, channel])

    _plt_images(originals, generateds, befores, config.crop_size)
    plt.savefig("assets/seq2seq.png")
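The _plt_images helper used above is not shown. A minimal sketch of one possible implementation, assuming it lays out the decoded past frames, the ground-truth future frames, and the predicted frames as three rows of a matplotlib grid (the row titles and figure size are assumptions):

import matplotlib.pyplot as plt


def _plt_images(originals, generateds, befores, crop_size):
    # One row per image kind, one column per time step.
    rows = [("Before", befores), ("Original", originals), ("Generated", generateds)]
    num_cols = len(originals)
    fig, axes = plt.subplots(len(rows), num_cols, figsize=(3 * num_cols, 9))
    fig.suptitle(f"Crop size: {crop_size}")
    for r, (title, images) in enumerate(rows):
        for c, image in enumerate(images):
            axes[r, c].imshow(image, cmap="gray")
            axes[r, c].set_title(f"{title} {c}")
            axes[r, c].axis("off")
    fig.tight_layout()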
Code example #6
def find_target_ghi_minmax_value(dataset=None):
    """Find the minimum and maximum values of the target GHI.

    The values are found based on the training dataset.

    Returns:
        Tuple with (max_value, min_value)
    """
    if dataset is None:
        config = default_config()
        config.features = [dataloader.Feature.target_ghi]
        dataset, _, _ = load_data(config=config)

    max_value = dataset.reduce(0.0, _reduce_max)
    min_value = dataset.reduce(max_value, _reduce_min)

    return max_value, min_value
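The _reduce_max and _reduce_min callbacks are not shown. A minimal sketch of what they could look like, assuming each dataset element is a single target_ghi tensor (tf.data.Dataset.reduce passes the running state first and the element second):

import tensorflow as tf


def _reduce_max(state, target_ghi):
    # Running maximum over every target GHI value seen so far.
    return tf.math.maximum(state, tf.reduce_max(target_ghi))


def _reduce_min(state, target_ghi):
    # Running minimum, seeded with the maximum found in the first pass.
    return tf.math.minimum(state, tf.reduce_min(target_ghi))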
Code example #7
def plot_comparison(instance: str):
    """Show original and generated images in a grid."""
    autoencoder = Autoencoder()
    autoencoder.load(instance)

    config = autoencoder.config()

    _, valid_dataset, _ = load_data(config=config)
    image = _first_image(valid_dataset)
    image_pred = _predict_image(autoencoder, image)

    generateds = []
    originals = []

    for original, generated in zip(image, image_pred):
        num_channels = original.shape[-1]
        for n in range(num_channels):
            originals.append(original[:, :, n])
            generateds.append(generated[:, :, n])

    _plt_images(originals, generateds, config.crop_size, config.channels)
    plt.savefig("assets/autoencoder.png")
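The _predict_image helper is not shown. A minimal sketch under the assumption that the Autoencoder exposes the same scaling_image normalizer used by the seq2seq model above and that a plain forward pass reconstructs the input:

def _predict_image(autoencoder, image):
    # Normalize, run the forward pass in inference mode, then scale the
    # reconstruction back to the original pixel range.
    scaled = autoencoder.scaling_image.normalize(image)
    predicted = autoencoder(scaled, training=False)
    return autoencoder.scaling_image.original(predicted)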
Code example #8
def cache(size, cache_dir):
    """Pre-compute and archive the image cache for the given crop size."""
    config = default_config()
    config.features = [Feature.image]
    config.crop_size = size

    # Build the cache directory name from the crop size tuple,
    # e.g. (300, 300) becomes ".../image_cache_300-300".
    config.image_cache_dir = cache_dir + f"/image_cache_{size}"
    config.image_cache_dir = config.image_cache_dir.replace("(", "")
    config.image_cache_dir = config.image_cache_dir.replace(",", "")
    config.image_cache_dir = config.image_cache_dir.replace(")", "")
    config.image_cache_dir = config.image_cache_dir.replace(" ", "-")

    logger.info(
        f"Caching images with size {size} in dir {config.image_cache_dir}")

    dataset_train, dataset_valid, dataset_test = load_data(
        enable_tf_caching=False, config=config)

    _create_cache("train", dataset_train)
    _create_cache("valid", dataset_valid)
    _create_cache("test", dataset_test)

    # Archive the cache directory so it can be shipped as a single tar file.
    os.system(f"tar -cf {config.image_cache_dir}.tar {config.image_cache_dir}")