Example 1
def _prep_datasets(ids, tc):
    """Build batched training and validation tf.data pipelines from the imagery dataset and training config."""
    if tc.max_tile_offset:
        # with filtering nodata, number of tiles changes
        assert tc.steps, 'max_tile_offset only supported with steps set.'
    ds = ids.dataset(config.dataset.classes.weights(), config_augmentation())

    validation = None
    if tc.validation:
        if tc.validation.from_training:
            validation = ds.take(tc.validation.steps)
            ds = ds.skip(tc.validation.steps)
        else:
            vimg   = tc.validation.images
            vlabel = tc.validation.labels
            if not vimg:
                validation = None
            else:
                if vlabel:
                    vimagery = ImageryDataset(vimg, vlabel, ids.output_shape(), ids.chunk_shape(),
                                              tile_shape=ids.tile_shape(), stride=ids.stride(),
                                              tile_overlap=ids.tile_overlap())
                else:
                    vimagery = AutoencoderDataset(vimg, ids.chunk_shape(), tile_shape=ids.tile_shape(),
                                                  stride=ids.stride(), tile_overlap=ids.tile_overlap())
                validation = vimagery.dataset(config.dataset.classes.weights())
        if validation:
            validation = validation.batch(tc.batch_size, drop_remainder=True)
    else:
        validation = None

    ds = ds.batch(tc.batch_size, drop_remainder=True)
    return (ds, validation)
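
The from_training branch above splits validation data straight off the training pipeline with take/skip before batching. A minimal, self-contained sketch of that tf.data pattern, using a synthetic dataset in place of DELTA's ImageryDataset pipeline (shapes and counts are illustrative):

import tensorflow as tf

# Synthetic stand-in for the imagery pipeline: 100 (chunk, label) pairs.
ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform((100, 8, 8, 3)), tf.zeros((100,), dtype=tf.int32)))

validation_count = 10
validation = ds.take(validation_count)   # first N elements become validation
ds = ds.skip(validation_count)           # training continues after element N

validation = validation.batch(4, drop_remainder=True)
ds = ds.batch(4, drop_remainder=True)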
Example 2
def _prep_datasets(ids, tc):
    ds = ids.dataset(config.dataset.classes.weights())

    validation = None
    if tc.validation:
        if tc.validation.from_training:
            validation = ds.take(tc.validation.steps)
            ds = ds.skip(tc.validation.steps)
        else:
            vimg   = tc.validation.images
            vlabel = tc.validation.labels
            if not vimg:
                validation = None
            else:
                if vlabel:
                    vimagery = ImageryDataset(vimg, vlabel, ids.output_shape(), ids.chunk_shape(),
                                              tile_shape=ids.tile_shape(), stride=ids.stride(),
                                              tile_overlap=ids.tile_overlap())
                else:
                    vimagery = AutoencoderDataset(vimg, ids.chunk_shape(), tile_shape=ids.tile_shape(),
                                                  stride=ids.stride(), tile_overlap=ids.tile_overlap())
                validation = vimagery.dataset(config.dataset.classes.weights())
                if tc.validation.steps:
                    validation = validation.take(tc.validation.steps)
        if validation:
            validation = validation.batch(tc.batch_size, drop_remainder=True).prefetch(1)
    else:
        validation = None

    ds = ds.batch(tc.batch_size, drop_remainder=True)
    ds = ds.prefetch(1)
    if tc.steps:
        ds = ds.take(tc.steps)
    return (ds, validation)
Example 3
def _prep_datasets(ids, tc, chunk_size, output_size):
    ds = ids.dataset(config.dataset.classes.weights())
    ds = ds.batch(tc.batch_size)
    #ds = ds.cache()
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    if tc.validation:
        if tc.validation.from_training:
            validation = ds.take(tc.validation.steps)
            ds = ds.skip(tc.validation.steps)
        else:
            vimg = tc.validation.images
            vlabel = tc.validation.labels
            if not vimg:
                validation = None
            else:
                if vlabel:
                    vimagery = ImageryDataset(vimg,
                                              vlabel,
                                              chunk_size,
                                              output_size,
                                              tc.chunk_stride,
                                              resume_mode=False)
                else:
                    vimagery = AutoencoderDataset(vimg,
                                                  chunk_size,
                                                  tc.chunk_stride,
                                                  resume_mode=False)
                validation = vimagery.dataset().batch(tc.batch_size)
                if tc.validation.steps:
                    validation = validation.take(tc.validation.steps)
        #validation = validation.prefetch(4)#tf.data.experimental.AUTOTUNE)
    else:
        validation = None
    if tc.steps:
        ds = ds.take(tc.steps)
    #ds = ds.prefetch(4)#tf.data.experimental.AUTOTUNE)
    ds = ds.repeat(tc.epochs)
    return (ds, validation)
Example 4
def _prep_datasets(ids, tc, chunk_size, output_size):
    ds = ids.dataset()
    ds = ds.batch(tc.batch_size)
    if tc.validation:
        if tc.validation.from_training:
            validation = ds.take(tc.validation.steps)
            ds = ds.skip(tc.validation.steps)
        else:
            vimg = tc.validation.images
            vlabel = tc.validation.labels
            if not vimg or not vlabel:
                validation = None
            else:
                vimagery = ImageryDataset(vimg, vlabel, chunk_size,
                                          output_size, tc.chunk_stride)
                validation = vimagery.dataset().batch(tc.batch_size).take(
                    tc.validation.steps)
    else:
        validation = None
    if tc.steps:
        ds = ds.take(tc.steps)
    ds = ds.repeat(tc.epochs)
    return (ds, validation)
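
In Examples 3 and 4 the order of the final two transformations matters: ds.take(tc.steps) caps a single pass at tc.steps batches, and the following ds.repeat(tc.epochs) replays that capped pipeline once per epoch. A small standalone illustration of the ordering (counts are arbitrary):

import tensorflow as tf

ds = tf.data.Dataset.range(10).batch(2)   # 5 batches per full pass
capped = ds.take(3).repeat(2)             # cap at 3 batches, then replay twice

print(sum(1 for _ in capped))             # prints 6: 3 batches per epoch, 2 epochs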
Example 5
def train(model_fn, dataset: ImageryDataset, training_spec):
    """
    Trains the specified model on a dataset according to a training
    specification.
    """
    if isinstance(model_fn, tf.keras.Model):
        model = model_fn
    else:
        with _strategy(_devices(config.general.gpus())).scope():
            model = model_fn()
            assert isinstance(model, tf.keras.models.Model),\
                   "Model is not a Tensorflow Keras model"
            loss = training_spec.loss_function
            # TODO: specify learning rate and optimizer parameters, change learning rate over time
            model.compile(optimizer=training_spec.optimizer,
                          loss=loss,
                          metrics=training_spec.metrics)

    input_shape = model.input_shape
    output_shape = model.output_shape
    chunk_size = input_shape[1]

    assert len(input_shape) == 4, 'Input to network is wrong shape.'
    assert input_shape[0] is None, 'Input is not batched.'
    # The below may no longer be valid if we move to convolutional architectures.
    assert input_shape[1] == input_shape[2], 'Input to network is not chunked'
    assert len(output_shape) == 2 or output_shape[1] == output_shape[2], \
           'Output from network is not chunked'
    assert input_shape[3] == dataset.num_bands(), \
           'Number of bands in model does not match data.'
    # last element differs for the sparse metrics
    assert output_shape[1:-1] == dataset.output_shape()[:-1], \
            'Network output shape %s does not match label shape %s.' % (output_shape[1:], dataset.output_shape())

    (ds, validation) = _prep_datasets(dataset, training_spec, chunk_size,
                                      output_shape[1])

    callbacks = [tf.keras.callbacks.TerminateOnNaN()]
    # add callbacks from DeltaLayers
    for l in model.layers:
        if isinstance(l, DeltaLayer):
            c = l.callback()
            if c:
                callbacks.append(c)
    if config.tensorboard.enabled():
        tcb = tf.keras.callbacks.TensorBoard(log_dir=config.tensorboard.dir(),
                                             update_freq='epoch',
                                             histogram_freq=1,
                                             write_images=True,
                                             embeddings_freq=1)
        callbacks.append(tcb)

    if config.mlflow.enabled():
        mcb = _mlflow_train_setup(model, dataset, training_spec)
        callbacks.append(mcb)
        #print('Using mlflow folder: ' + mlflow.get_artifact_uri())

    try:
        history = model.fit(ds,
                            epochs=training_spec.epochs,
                            callbacks=callbacks,
                            validation_data=validation,
                            validation_steps=training_spec.validation.steps
                            if training_spec.validation else None,
                            steps_per_epoch=training_spec.steps,
                            verbose=1)

        if config.mlflow.enabled():
            model_path = os.path.join(mcb.temp_dir, 'final_model.h5')
            print('\nFinished, saving model to %s.' %
                  (mlflow.get_artifact_uri() + '/final_model.h5'))
            save_model(model, model_path)
            mlflow.log_artifact(model_path)
            os.remove(model_path)
            mlflow.log_param('Status', 'Completed')
    except:
        if config.mlflow.enabled():
            mlflow.log_param('Status', 'Aborted')
            mlflow.end_run('FAILED')
            model_path = os.path.join(mcb.temp_dir, 'aborted_model.h5')
            print('\nAborting, saving current model to %s.' %
                  (mlflow.get_artifact_uri() + '/aborted_model.h5'))
            save_model(model, model_path)
            mlflow.log_artifact(model_path)
            os.remove(model_path)
        raise
    finally:
        if config.mlflow.enabled():
            mlflow.log_param('Epoch', mcb.epoch)
            mlflow.log_param('Batch', mcb.batch)
            if mcb and mcb.temp_dir:
                shutil.rmtree(mcb.temp_dir)

    if config.mlflow.enabled():
        mlflow.end_run()

    return model, history
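
train() in Example 5 accepts either an already-built tf.keras.Model or a zero-argument model_fn that it constructs and compiles inside a distribution strategy scope. A minimal sketch of a model_fn satisfying that contract and the assertions above (the layer sizes, band count, and use of MirroredStrategy are illustrative assumptions, not DELTA code):

import tensorflow as tf

def model_fn():
    # Small fully-convolutional network: chunked 16x16 input with 3 bands,
    # same-size 2-class output so the chunk-shape assertions hold.
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu',
                               input_shape=(16, 16, 3)),
        tf.keras.layers.Conv2D(2, 1, activation='softmax'),
    ])

# Mirrors the strategy-scope construction and compilation done inside train():
with tf.distribute.MirroredStrategy().scope():
    model = model_fn()
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])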
Example 6
def train(model_fn, dataset : ImageryDataset, training_spec, resume_path=None, internal_model_extension='.h5'):
    """
    Trains the specified model on a dataset according to a training
    specification.

    Parameters
    ----------
    model_fn: Callable[[], tensorflow.keras.model.Model]
        Function that constructs a model.
    dataset: delta.imagery.imagery_dataset.ImageryDataset
        Dataset to train on.
    training_spec: delta.ml.ml_config.TrainingSpec
        Training parameters.
    resume_path: str
        Optional file to load initial model weights from.
    internal_model_extension: str
        File extension used when saving model files (e.g. '.h5').

    Returns
    -------
    (tensorflow.keras.models.Model, History):
        The trained model and the training history.
    """
    model = compile_model(model_fn, training_spec, resume_path)
    assert model.input_shape[3] == dataset.num_bands(), 'Number of bands in model does not match data.'
    # last element differs for the sparse metrics
    assert model.output_shape[1:-1] == dataset.output_shape()[:-1] or (model.output_shape[1] is None), \
            'Network output shape %s does not match label shape %s.' % \
            (model.output_shape[1:], dataset.output_shape()[:-1])

    (ds, validation) = _prep_datasets(dataset, training_spec)

    (callbacks, mcb) = _build_callbacks(model, dataset, training_spec, internal_model_extension)

    try:
        if (training_spec.steps is None) or (training_spec.steps > 0):
            if training_spec.steps is not None:
                ds = ds.repeat() # repeat forever; steps and epochs control when to stop
            done = False
            epochs = training_spec.epochs
            initial_epoch = 0
            while not done:
                try:
                    history = model.fit(ds,
                                        epochs=epochs,
                                        initial_epoch=initial_epoch,
                                        callbacks=callbacks,
                                        validation_data=validation,
                                        validation_steps=None, # Steps are controlled in the dataset setup
                                        steps_per_epoch=training_spec.steps,
                                        verbose=1) # Set to 2 when logging
                    done = True
                except ContinueTrainingException as cte:
                    print('Recompiling model and resuming training.')
                    initial_epoch += cte.completed_epochs
                    if cte.recompile_model:
                        model = compile_model(model, training_spec)
                    if cte.learning_rate:
                        K.set_value(model.optimizer.lr, cte.learning_rate)
        else: # Skip training
            print('Skipping straight to validation')
            history = model.evaluate(validation, steps=training_spec.validation.steps,
                                     callbacks=callbacks, verbose=1)

        if config.mlflow.enabled():
            model_path = os.path.join(mcb.temp_dir, 'final_model' + internal_model_extension)
            print('\nFinished, saving model to %s.'
                  % (mlflow.get_artifact_uri() + '/final_model' + internal_model_extension))
            save_model(model, model_path)
            mlflow.log_artifact(model_path)
            if os.path.isdir(model_path):
                shutil.rmtree(model_path)
            else:
                os.remove(model_path)
            mlflow.log_param('Status', 'Completed')
    except:
        if config.mlflow.enabled():
            mlflow.log_param('Status', 'Aborted')
            mlflow.log_param('Epoch', mcb.epoch)
            mlflow.log_param('Batch', mcb.batch)
            mlflow.end_run('FAILED')
            model_path = os.path.join(mcb.temp_dir, 'aborted_model' + internal_model_extension)
            print('\nAborting, saving current model to %s.'
                  % (mlflow.get_artifact_uri() + '/aborted_model' + internal_model_extension))
            save_model(model, model_path)
            mlflow.log_artifact(model_path)
            if os.path.isdir(model_path):
                shutil.rmtree(model_path)
            else:
                os.remove(model_path)
        raise
    finally:
        if config.mlflow.enabled():
            if mcb and mcb.temp_dir:
                shutil.rmtree(mcb.temp_dir)
            mlflow.end_run()

    return model, history
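
The while-loop in Example 6 lets a callback interrupt model.fit mid-run, adjust the optimizer, and resume from the epochs already completed. A stripped-down sketch of that resume pattern; the exception class here is a stand-in for DELTA's ContinueTrainingException, and model recompilation is omitted:

import tensorflow as tf
import tensorflow.keras.backend as K

class ContinueTrainingException(Exception):
    # Stand-in signal a callback can raise to request that training be re-entered.
    def __init__(self, completed_epochs=0, learning_rate=None):
        super().__init__()
        self.completed_epochs = completed_epochs
        self.learning_rate = learning_rate

def fit_with_resume(model, ds, epochs, callbacks=None):
    initial_epoch = 0
    while True:
        try:
            # initial_epoch keeps the epoch numbering continuous across restarts.
            return model.fit(ds, epochs=epochs, initial_epoch=initial_epoch,
                             callbacks=callbacks, verbose=1)
        except ContinueTrainingException as cte:
            initial_epoch += cte.completed_epochs
            if cte.learning_rate:
                K.set_value(model.optimizer.lr, cte.learning_rate)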