def _prep_datasets(ids, tc):
    """
    Builds training and validation datasets from the imagery dataset `ids` and the
    training configuration `tc`. Returns (training, validation); validation may be None.
    """
    if tc.max_tile_offset: # with filtering nodata, the number of tiles changes
        assert tc.steps, 'max_tile_offset only supported with steps set.'
    ds = ids.dataset(config.dataset.classes.weights(), config_augmentation())
    validation = None
    if tc.validation:
        if tc.validation.from_training:
            # Split the validation data off the front of the training set.
            validation = ds.take(tc.validation.steps)
            ds = ds.skip(tc.validation.steps)
        else:
            vimg = tc.validation.images
            vlabel = tc.validation.labels
            if not vimg:
                validation = None
            else:
                # Build a separate validation dataset from dedicated imagery.
                if vlabel:
                    vimagery = ImageryDataset(vimg, vlabel, ids.output_shape(), ids.chunk_shape(),
                                              tile_shape=ids.tile_shape(), stride=ids.stride(),
                                              tile_overlap=ids.tile_overlap())
                else:
                    vimagery = AutoencoderDataset(vimg, ids.chunk_shape(), tile_shape=ids.tile_shape(),
                                                  stride=ids.stride(), tile_overlap=ids.tile_overlap())
                validation = vimagery.dataset(config.dataset.classes.weights())
        if validation:
            validation = validation.batch(tc.batch_size, drop_remainder=True)
    else:
        validation = None
    ds = ds.batch(tc.batch_size, drop_remainder=True)
    return (ds, validation)
def _prep_datasets(ids, tc):
    ds = ids.dataset(config.dataset.classes.weights())
    validation = None
    if tc.validation:
        if tc.validation.from_training:
            validation = ds.take(tc.validation.steps)
            ds = ds.skip(tc.validation.steps)
        else:
            vimg = tc.validation.images
            vlabel = tc.validation.labels
            if not vimg:
                validation = None
            else:
                if vlabel:
                    vimagery = ImageryDataset(vimg, vlabel, ids.output_shape(), ids.chunk_shape(),
                                              tile_shape=ids.tile_shape(), stride=ids.stride(),
                                              tile_overlap=ids.tile_overlap())
                else:
                    vimagery = AutoencoderDataset(vimg, ids.chunk_shape(), tile_shape=ids.tile_shape(),
                                                  stride=ids.stride(), tile_overlap=ids.tile_overlap())
                validation = vimagery.dataset(config.dataset.classes.weights())
                if tc.validation.steps:
                    validation = validation.take(tc.validation.steps)
        if validation:
            validation = validation.batch(tc.batch_size, drop_remainder=True).prefetch(1)
    else:
        validation = None
    ds = ds.batch(tc.batch_size, drop_remainder=True)
    ds = ds.prefetch(1)
    if tc.steps:
        ds = ds.take(tc.steps)
    return (ds, validation)
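# A minimal, self-contained sketch of the take/skip split and batch/prefetch
# pipeline the version of _prep_datasets above applies when
# validation.from_training is set. Illustrative only, not DELTA code:
# make_pipeline and its argument names are hypothetical.
import tensorflow as tf

def make_pipeline(ds, batch_size, validation_items):
    # Carve the first validation_items elements off the front for validation and
    # train on the remainder; as above, the split happens before batching.
    validation = ds.take(validation_items)
    train = ds.skip(validation_items)
    # drop_remainder=True keeps batch shapes static; prefetch(1) overlaps
    # preparing the next batch with training on the current one.
    validation = validation.batch(batch_size, drop_remainder=True).prefetch(1)
    train = train.batch(batch_size, drop_remainder=True).prefetch(1)
    return train, validation

# Usage with synthetic data:
samples = tf.data.Dataset.from_tensor_slices(tf.random.uniform((100, 4)))
train_ds, val_ds = make_pipeline(samples, batch_size=8, validation_items=16)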
def _prep_datasets(ids, tc, chunk_size, output_size):
    ds = ids.dataset(config.dataset.classes.weights())
    ds = ds.batch(tc.batch_size)
    #ds = ds.cache()
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    if tc.validation:
        if tc.validation.from_training:
            validation = ds.take(tc.validation.steps)
            ds = ds.skip(tc.validation.steps)
        else:
            vimg = tc.validation.images
            vlabel = tc.validation.labels
            if not vimg:
                validation = None
            else:
                if vlabel:
                    vimagery = ImageryDataset(vimg, vlabel, chunk_size, output_size,
                                              tc.chunk_stride, resume_mode=False)
                else:
                    vimagery = AutoencoderDataset(vimg, chunk_size, tc.chunk_stride, resume_mode=False)
                validation = vimagery.dataset().batch(tc.batch_size)
                if tc.validation.steps:
                    validation = validation.take(tc.validation.steps)
        #validation = validation.prefetch(4)#tf.data.experimental.AUTOTUNE)
    else:
        validation = None
    if tc.steps:
        ds = ds.take(tc.steps)
    #ds = ds.prefetch(4)#tf.data.experimental.AUTOTUNE)
    ds = ds.repeat(tc.epochs)
    return (ds, validation)
def _prep_datasets(ids, tc, chunk_size, output_size):
    ds = ids.dataset()
    ds = ds.batch(tc.batch_size)
    if tc.validation:
        if tc.validation.from_training:
            validation = ds.take(tc.validation.steps)
            ds = ds.skip(tc.validation.steps)
        else:
            vimg = tc.validation.images
            vlabel = tc.validation.labels
            if not vimg or not vlabel:
                validation = None
            else:
                vimagery = ImageryDataset(vimg, vlabel, chunk_size, output_size, tc.chunk_stride)
                validation = vimagery.dataset().batch(tc.batch_size).take(tc.validation.steps)
    else:
        validation = None
    if tc.steps:
        ds = ds.take(tc.steps)
    ds = ds.repeat(tc.epochs)
    return (ds, validation)
def train(model_fn, dataset: ImageryDataset, training_spec):
    """
    Trains the specified model on a dataset according to a training specification.
    """
    if isinstance(model_fn, tf.keras.Model):
        model = model_fn
    else:
        with _strategy(_devices(config.general.gpus())).scope():
            model = model_fn()
            assert isinstance(model, tf.keras.models.Model), \
                   "Model is not a Tensorflow Keras model"
            loss = training_spec.loss_function
            # TODO: specify learning rate and optimizer parameters, change learning rate over time
            model.compile(optimizer=training_spec.optimizer, loss=loss, metrics=training_spec.metrics)
    input_shape = model.input_shape
    output_shape = model.output_shape
    chunk_size = input_shape[1]

    assert len(input_shape) == 4, 'Input to network is wrong shape.'
    assert input_shape[0] is None, 'Input is not batched.'
    # The below may no longer be valid if we move to convolutional architectures.
    assert input_shape[1] == input_shape[2], 'Input to network is not chunked'
    assert len(output_shape) == 2 or output_shape[1] == output_shape[2], 'Output from network is not chunked'
    assert input_shape[3] == dataset.num_bands(), 'Number of bands in model does not match data.'
    # last element differs for the sparse metrics
    assert output_shape[1:-1] == dataset.output_shape()[:-1], \
           'Network output shape %s does not match label shape %s.' % (output_shape[1:], dataset.output_shape())

    (ds, validation) = _prep_datasets(dataset, training_spec, chunk_size, output_shape[1])

    callbacks = [tf.keras.callbacks.TerminateOnNaN()]
    # add callbacks from DeltaLayers
    for l in model.layers:
        if isinstance(l, DeltaLayer):
            c = l.callback()
            if c:
                callbacks.append(c)
    if config.tensorboard.enabled():
        tcb = tf.keras.callbacks.TensorBoard(log_dir=config.tensorboard.dir(),
                                             update_freq='epoch',
                                             histogram_freq=1,
                                             write_images=True,
                                             embeddings_freq=1)
        callbacks.append(tcb)
    if config.mlflow.enabled():
        mcb = _mlflow_train_setup(model, dataset, training_spec)
        callbacks.append(mcb)
        #print('Using mlflow folder: ' + mlflow.get_artifact_uri())

    try:
        history = model.fit(ds,
                            epochs=training_spec.epochs,
                            callbacks=callbacks,
                            validation_data=validation,
                            validation_steps=training_spec.validation.steps if training_spec.validation else None,
                            steps_per_epoch=training_spec.steps,
                            verbose=1)
        if config.mlflow.enabled():
            model_path = os.path.join(mcb.temp_dir, 'final_model.h5')
            print('\nFinished, saving model to %s.' % (mlflow.get_artifact_uri() + '/final_model.h5'))
            save_model(model, model_path)
            mlflow.log_artifact(model_path)
            os.remove(model_path)
            mlflow.log_param('Status', 'Completed')
    except:
        if config.mlflow.enabled():
            mlflow.log_param('Status', 'Aborted')
            mlflow.end_run('FAILED')
            model_path = os.path.join(mcb.temp_dir, 'aborted_model.h5')
            print('\nAborting, saving current model to %s.' % (mlflow.get_artifact_uri() + '/aborted_model.h5'))
            save_model(model, model_path)
            mlflow.log_artifact(model_path)
            os.remove(model_path)
        raise
    finally:
        if config.mlflow.enabled():
            mlflow.log_param('Epoch', mcb.epoch)
            mlflow.log_param('Batch', mcb.batch)
            if mcb and mcb.temp_dir:
                shutil.rmtree(mcb.temp_dir)

    if config.mlflow.enabled():
        mlflow.end_run()

    return model, history
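# A hedged sketch of the callback-gathering pattern in train() above: layers
# derived from DeltaLayer may contribute their own Keras callback, which the
# training loop collects before calling fit(). ScheduledLayer is a hypothetical
# stand-in for a real DeltaLayer; only the duck-typed callback() hook matters.
import tensorflow as tf

class ScheduledLayer(tf.keras.layers.Layer):
    """Identity layer that supplies a callback to tune training."""
    def call(self, inputs):
        return inputs
    def callback(self):
        # For example, stop early if the validation loss stalls.
        return tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

def gather_callbacks(model):
    callbacks = [tf.keras.callbacks.TerminateOnNaN()]
    for layer in model.layers:
        if hasattr(layer, 'callback'):
            c = layer.callback()
            if c: # a layer may decline by returning None
                callbacks.append(c)
    return callbacks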
def train(model_fn, dataset : ImageryDataset, training_spec, resume_path=None, internal_model_extension='.h5'):
    """
    Trains the specified model on a dataset according to a training specification.

    Parameters
    ----------
    model_fn: Callable[[], tensorflow.keras.model.Model]
        Function that constructs a model.
    dataset: delta.imagery.imagery_dataset.ImageryDataset
        Dataset to train on.
    training_spec: delta.ml.ml_config.TrainingSpec
        Training parameters.
    resume_path: str
        Optional file to load initial model weights from.

    Returns
    -------
    (tensorflow.keras.models.Model, History):
        The trained model and the training history.
    """
    model = compile_model(model_fn, training_spec, resume_path)

    assert model.input_shape[3] == dataset.num_bands(), 'Number of bands in model does not match data.'
    # last element differs for the sparse metrics
    assert model.output_shape[1:-1] == dataset.output_shape()[:-1] or (model.output_shape[1] is None), \
           'Network output shape %s does not match label shape %s.' % \
           (model.output_shape[1:], dataset.output_shape()[:-1])

    (ds, validation) = _prep_datasets(dataset, training_spec)
    (callbacks, mcb) = _build_callbacks(model, dataset, training_spec, internal_model_extension)

    try:
        if (training_spec.steps is None) or (training_spec.steps > 0):
            if training_spec.steps is not None:
                ds = ds.repeat() # repeat forever, use steps and epochs to stop
            done = False
            epochs = training_spec.epochs
            initial_epoch = 0
            while not done:
                try:
                    history = model.fit(ds,
                                        epochs=epochs,
                                        initial_epoch=initial_epoch,
                                        callbacks=callbacks,
                                        validation_data=validation,
                                        validation_steps=None, # Steps are controlled in the dataset setup
                                        steps_per_epoch=training_spec.steps,
                                        verbose=1) # Set to 2 when logging
                    done = True
                except ContinueTrainingException as cte:
                    print('Recompiling model and resuming training.')
                    initial_epoch += cte.completed_epochs
                    if cte.recompile_model:
                        model = compile_model(model, training_spec)
                    if cte.learning_rate:
                        K.set_value(model.optimizer.lr, cte.learning_rate)
        else: # Skip training
            print('Skipping straight to validation')
            history = model.evaluate(validation, steps=training_spec.validation.steps,
                                     callbacks=callbacks, verbose=1)

        if config.mlflow.enabled():
            model_path = os.path.join(mcb.temp_dir, 'final_model' + internal_model_extension)
            print('\nFinished, saving model to %s.' %
                  (mlflow.get_artifact_uri() + '/final_model' + internal_model_extension))
            save_model(model, model_path)
            mlflow.log_artifact(model_path)
            if os.path.isdir(model_path):
                shutil.rmtree(model_path)
            else:
                os.remove(model_path)
            mlflow.log_param('Status', 'Completed')
    except:
        if config.mlflow.enabled():
            mlflow.log_param('Status', 'Aborted')
            mlflow.log_param('Epoch', mcb.epoch)
            mlflow.log_param('Batch', mcb.batch)
            mlflow.end_run('FAILED')
            model_path = os.path.join(mcb.temp_dir, 'aborted_model' + internal_model_extension)
            print('\nAborting, saving current model to %s.' %
                  (mlflow.get_artifact_uri() + '/aborted_model' + internal_model_extension))
            save_model(model, model_path)
            mlflow.log_artifact(model_path)
            if os.path.isdir(model_path):
                shutil.rmtree(model_path)
            else:
                os.remove(model_path)
        raise
    finally:
        if config.mlflow.enabled():
            if mcb and mcb.temp_dir:
                shutil.rmtree(mcb.temp_dir)
            mlflow.end_run()

    return model, history
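# A self-contained sketch of the resume loop in the final train() above. A
# callback (or other code running inside fit()) raises an exception carrying
# how many epochs completed and an optional new learning rate; the loop then
# advances initial_epoch and calls fit() again. RestartSignal and
# fit_with_resume are hypothetical names mirroring ContinueTrainingException.
import tensorflow as tf
from tensorflow.keras import backend as K

class RestartSignal(Exception):
    def __init__(self, completed_epochs, learning_rate=None):
        super().__init__()
        self.completed_epochs = completed_epochs
        self.learning_rate = learning_rate

def fit_with_resume(model, ds, epochs):
    initial_epoch = 0
    while True:
        try:
            # initial_epoch keeps epoch numbering continuous across restarts
            return model.fit(ds, epochs=epochs, initial_epoch=initial_epoch, verbose=1)
        except RestartSignal as rs:
            initial_epoch += rs.completed_epochs
            if rs.learning_rate:
                # optimizer.lr mirrors the source; learning_rate is the modern alias
                K.set_value(model.optimizer.lr, rs.learning_rate)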