Example #1
0
def test_train_main(identity_config, tmp_path):
    """Run ``delta train`` end to end with a tiny one-layer Conv2D network.

    Writes an additional YAML config into ``tmp_path``. It sets 5 steps,
    3 epochs, batch size 1, 2 validation steps, and an
    ``ExponentialLRScheduler`` callback with ``start_epoch: 2``. That file
    is passed on the command line after ``identity_config``.
    """
    config_reset()
    train_config = tmp_path / 'config.yaml'
    with open(train_config, 'w') as f:
        f.write('''
        train:
          steps: 5
          epochs: 3
          network:
            layers:
              - Input:
                  shape: [1, 1, num_bands]
              - Conv2D:
                  filters: 2
                  kernel_size: [1, 1]
                  activation: relu
                  padding: same
          batch_size: 1
          validation:
            steps: 2
          callbacks:
            - ExponentialLRScheduler:
                start_epoch: 2
        ''')
    # Build argv as a list instead of splitting a formatted string, so that
    # paths containing whitespace (e.g. a temp dir under C:\Users\John Doe)
    # still reach main() as single arguments.
    args = ['delta', 'train',
            '--config', str(identity_config),
            '--config', str(train_config)]
    main(args)
Example #2
0
def test_train_validate(identity_config, binary_identity_tiff_filenames, tmp_path):
    """Run ``delta train`` with validation data taken from separate files.

    The extra YAML config sets ``validation.from_training: false``. It uses
    the first file of ``binary_identity_tiff_filenames[0]`` as the validation
    image and the first file of ``binary_identity_tiff_filenames[1]`` as the
    validation label, with ``nodata_value`` left null for both. The network
    input shape is ``[~, ~, num_bands]``.
    """
    config_reset()
    train_config = tmp_path / 'config.yaml'
    with open(train_config, 'w') as f:
        f.write('''
        train:
          steps: 5
          epochs: 3
          network:
            layers:
              - Input:
                  shape: [~, ~, num_bands]
              - Conv2D:
                  filters: 2
                  kernel_size: [1, 1]
                  activation: relu
                  padding: same
          batch_size: 1
          validation:
            from_training: false
            images:
              nodata_value: ~
              files: [%s]
            labels:
              nodata_value: ~
              files: [%s]
            steps: 2
          callbacks:
            - ExponentialLRScheduler:
                start_epoch: 2
        ''' % (binary_identity_tiff_filenames[0][0], binary_identity_tiff_filenames[1][0]))
    # Pass argv as a list so config paths containing whitespace are not
    # broken into multiple arguments by str.split().
    args = ['delta', 'train',
            '--config', str(identity_config),
            '--config', str(train_config)]
    main(args)
Example #3
0
def test_predict_main(identity_config, tmp_path):
    """Run ``delta classify`` using a trivial saved Keras model.

    The model returns its input unchanged, takes inputs of shape
    ``(32, 32, 2)``, and is saved to ``tmp_path / 'model.h5'``. The working
    directory is switched to ``tmp_path`` while the command runs, so files
    written to the current directory land there. The original directory is
    always restored afterwards.
    """
    config_reset()
    model_path = tmp_path / 'model.h5'
    inputs = tf.keras.layers.Input((32, 32, 2))
    tf.keras.Model(inputs, inputs).save(model_path)
    # Build argv as a list so a tmp_path containing whitespace stays a single
    # argument (the previous "...".split() approach would break it apart).
    args = ['delta', 'classify',
            '--config', str(identity_config),
            str(model_path)]
    old = os.getcwd()
    os.chdir(tmp_path)  # put temporary outputs here
    try:
        main(args)
    finally:
        # Restore cwd even if main() raises, so a failing run does not leave
        # the process in tmp_path for every test that follows.
        os.chdir(old)
Example #4
0
def test_validate_main(identity_config):
    """Run ``delta validate --config <identity_config>`` through ``main``."""
    config_reset()
    # argv as a list: a config path with whitespace must remain one argument.
    args = ['delta', 'validate', '--config', str(identity_config)]
    main(args)