def plot3d():
    """Create a GIF from a 3D image in the training set."""
    for channel in ["ch1", "ch2", "ch3", "ch4", "ch6"]:
        config = default_config()
        config.features = [Feature.image]
        config.channels = [channel]
        config.crop_size = (300, 300)
        config.time_interval_min = 30
        config.num_images = 10

        train_dataset, _, _ = load_data(config=config)
        images = _first_image(train_dataset)

        image_names = []
        for i, image in enumerate(images[0]):
            # Reset the figure before drawing each frame.
            plt.cla()
            plt.clf()
            plt.axis("off")
            plt.tight_layout()
            image2d = image[:, :, 0]
            plt.imshow(image2d, cmap="gray")
            name = f"assets/{i}.png"
            plt.savefig(name)
            image_names.append(name)
        make_gif(f"assets/image-3d-{channel}.gif", image_names)
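# `make_gif` is defined elsewhere in the project. A minimal sketch of such a
# helper, assuming Pillow is available (the 200 ms frame duration is an
# arbitrary choice, not the project's actual setting):
from PIL import Image

def make_gif(output_name, image_names):
    # Assemble the saved PNG frames into an animated GIF.
    frames = [Image.open(name) for name in image_names]
    frames[0].save(
        output_name,
        save_all=True,
        append_images=frames[1:],
        duration=200,  # milliseconds per frame
        loop=0,        # loop forever
    )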
def train(
    self,
    optimizer: tf.keras.optimizers.Optimizer,
    epochs=25,
    enable_checkpoint=True,
    cache_file=None,
    checkpoint=None,
):
    """Perform minibatch training of the model."""
    config = self.model.config()
    logger.info(
        "Starting training\n"
        f" - Model: {self.model.title}\n"
        f" - Config: {config}"
    )
    train_set, valid_set, _ = load_data(
        config=config,
        skip_non_cached=self.skip_non_cached,
    )

    logger.info("Applying preprocessing.")
    train_set = self.model.preprocess(train_set)
    valid_set = self.model.preprocess(valid_set)

    if cache_file is not None:
        train_set = train_set.cache(f"{cache_file}-train")
        valid_set = valid_set.cache(f"{cache_file}-valid")

    logger.info("Fitting model.")
    init_epoch = 0
    if checkpoint is not None:
        # Checkpoints are numbered by epoch; resume from the next one.
        init_epoch = int(checkpoint) + 1
        # Load the model weights.
        self.model.load(checkpoint)
        # Load the training history.
        self.history = History.load(f"{self.model.title}-{checkpoint}")

    for epoch in range(init_epoch, epochs):
        logger.info("Training...")
        for i, data in enumerate(train_set.batch(self.batch_size)):
            inputs = data[:-1]
            targets = data[-1]
            self._train_step(optimizer, inputs, targets, i + 1)

        logger.info("Evaluating validation loss.")
        self._evaluate("valid", valid_set, self.batch_size)

        logger.info("Checkpointing...")
        self.model.save(str(epoch))

        self._update_progress(epoch)
        self.history.save(f"{self.model.title}-{epoch}")
    logger.info("Done.")
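# `_train_step` is a method on the same trainer class and is not shown in this
# section. A minimal sketch of what one gradient update typically looks like
# with tf.GradientTape, assuming the trainer exposes a `loss_fn` (both the body
# and `loss_fn` are assumptions, not the project's actual implementation):
import tensorflow as tf

def _train_step(self, optimizer, inputs, targets, step):
    # Hypothetical sketch of one minibatch update.
    with tf.GradientTape() as tape:
        predictions = self.model(inputs, training=True)
        loss = self.loss_fn(targets, predictions)
    gradients = tape.gradient(loss, self.model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
    if step % 100 == 0:
        logger.info(f"Step {step}: loss={float(loss):.4f}")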
def test(self, checkpoint: str):
    """Test a trained model on the test set.

    A checkpointed epoch must be specified.
    """
    config = self.model.config()
    self.model.load(checkpoint)

    _, _, test_set = load_data(config=config, skip_non_cached=self.skip_non_cached)
    test_set = self.model.preprocess(test_set)

    self._evaluate("test", test_set, self.batch_size)
    self.history.save(f"{self.model.title}-{checkpoint}-test-set")
def run(
    enable_tf_caching=False,
    skip_non_cached=False,
):
    """Perform a dry run with the data generators."""
    logger.info("Dry Run.")
    # Only exercise the generators, to debug weird behavior and corner cases.
    train_generator, valid_generator, test_generator = load_data(
        enable_tf_caching=enable_tf_caching, skip_non_cached=skip_non_cached
    )
    for sample in train_generator:
        # Just make sure we can get a single sample out of the dry run.
        print(sample)
        break
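# When a dry run does surface odd behavior, printing the raw sample tuple is
# often too coarse. A small hypothetical helper (`describe_sample` is not part
# of the codebase) that could replace the bare print above:
def describe_sample(sample):
    # Report the shape and dtype of each tensor in a sample tuple.
    for i, tensor in enumerate(sample):
        print(f"feature {i}: shape={tensor.shape}, dtype={tensor.dtype}")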
def plot_comparison(encoder_instance: str, seq2seq_instance: str):
    """Show original and generated future images in a grid."""
    encoder = Encoder()
    encoder.load(encoder_instance)
    model = Gru(encoder)
    model.load(seq2seq_instance)

    config = model.config()
    config.num_images = 6
    config.skip_missing_past_images = True

    decoder = Decoder(len(config.channels))
    decoder.load(encoder_instance)

    valid_dataset, _, _ = load_data(config=config)
    images_originals = _first_images(valid_dataset, 6)
    image_pred = model.predict_next_images(images_originals[0, :3], 3)

    # Encode and decode the three past images so they match the model's output space.
    before = model.scaling_image.normalize(images_originals[0, :3])
    before = encoder(before, training=False)
    before = decoder(before, training=False)

    # Do the same for the three future (target) images.
    images_originals = model.scaling_image.normalize(images_originals[0, 3:])
    images_originals = encoder(images_originals, training=False)
    images_originals = decoder(images_originals, training=False)

    # Map everything back to the original pixel scale.
    before = model.scaling_image.original(before)
    images_originals = model.scaling_image.original(images_originals)
    image_pred = model.scaling_image.original(np.array(image_pred))

    generateds = []
    originals = []
    befores = []
    for i in range(3):
        generated = image_pred[i]
        original = images_originals[i]
        bef = before[i]

        channel = 0
        originals.append(original[:, :, channel])
        generateds.append(generated[:, :, channel])
        befores.append(bef[:, :, channel])

    _plt_images(originals, generateds, befores, config.crop_size)
    plt.savefig("assets/seq2seq.png")
def find_target_ghi_minmax_value(dataset=None):
    """Find the minimum and maximum values of the target GHI.

    The values are found based on the training dataset.

    Return:
        Tuple with (max_value, min_value).
    """
    if dataset is None:
        config = default_config()
        config.features = [dataloader.Feature.target_ghi]
        dataset, _, _ = load_data(config=config)

    max_value = dataset.reduce(0.0, _reduce_max)
    # Seed the minimum search with the maximum, so any observed value lowers it.
    min_value = dataset.reduce(max_value, _reduce_min)
    return max_value, min_value
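# The reducers aren't shown in this section. A minimal sketch consistent with
# the dataset.reduce calls above, assuming each dataset element is the GHI
# tensor itself (tf.data passes reduce_func the running state and one element):
import tensorflow as tf

def _reduce_max(state, data):
    # Running maximum over every target GHI value in the dataset.
    return tf.maximum(state, tf.reduce_max(data))

def _reduce_min(state, data):
    # Running minimum; seeded with the dataset maximum so it can only decrease.
    return tf.minimum(state, tf.reduce_min(data))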
def plot_comparison(instance: str):
    """Show original and generated images in a grid."""
    autoencoder = Autoencoder()
    autoencoder.load(instance)
    config = autoencoder.config()

    _, valid_dataset, _ = load_data(config=config)
    image = _first_image(valid_dataset)
    image_pred = _predict_image(autoencoder, image)

    generateds = []
    originals = []
    for original, generated in zip(image, image_pred):
        num_channels = original.shape[-1]
        for n in range(num_channels):
            originals.append(original[:, :, n])
            generateds.append(generated[:, :, n])

    _plt_images(originals, generateds, config.crop_size, config.channels)
    plt.savefig("assets/autoencoder.png")
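# `_predict_image` is defined elsewhere. One plausible sketch, assuming the
# autoencoder exposes the same `scaling_image` normalizer seen on the seq2seq
# model above (both assumptions, not the project's actual implementation):
def _predict_image(autoencoder, images):
    # Normalize, run the images through the autoencoder, and map the
    # predictions back to the original pixel scale.
    scaled = autoencoder.scaling_image.normalize(images)
    preds = autoencoder(scaled, training=False)
    return autoencoder.scaling_image.original(preds)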
def cache(size, cache_dir):
    """Pre-compute the image cache for a given crop size."""
    config = default_config()
    config.features = [Feature.image]
    config.crop_size = size

    # Build a filesystem-safe cache dir name from the crop size,
    # e.g. (64, 64) becomes image_cache_64-64.
    config.image_cache_dir = cache_dir + f"/image_cache_{size}"
    for old, new in [("(", ""), (",", ""), (")", ""), (" ", "-")]:
        config.image_cache_dir = config.image_cache_dir.replace(old, new)

    logger.info(f"Caching images with size {size} in dir {config.image_cache_dir}")
    dataset_train, dataset_valid, dataset_test = load_data(
        enable_tf_caching=False, config=config
    )

    _create_cache("train", dataset_train)
    _create_cache("valid", dataset_valid)
    _create_cache("test", dataset_test)

    # Bundle the cache directory for easier transfer between machines.
    os.system(f"tar -cf {config.image_cache_dir}.tar {config.image_cache_dir}")
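# `_create_cache` isn't shown here. One plausible reading is that a full pass
# over the dataset is enough to force every image into config.image_cache_dir;
# a sketch under that assumption (the body and log message are illustrative):
def _create_cache(name, dataset):
    num_samples = 0
    for _ in dataset:
        num_samples += 1
    logger.info(f"Cached {num_samples} samples for the {name} set")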