Example #1
0
def test_all(config: config_reader.CustomConfigParser):
    """Check that ``config.all()`` returns exactly the expected settings mapping.

    The comparison is plain ``dict`` equality, so every key must be present
    with an equal value. Several entries are expected as *strings*
    (``'num_classes': '2'``, ``'batch_norm': 'True'``, ``'layer_size': '26'``,
    ...). Because ``'2' != 2``, this also pins that those settings come back
    unconverted.

    NOTE(review): ``config`` is presumably a pytest fixture loaded from a test
    configuration file; confirm where these expected values originate.
    """
    # Named ``expected`` rather than ``all`` so the builtin ``all()`` is not shadowed.
    expected = {
        'num_epochs': 2000,
        'learning_rate': 0.001,
        'batch_size': 32,
        'validation_batch_size': 344,
        'optimizer': 'Adagrad',
        'l1_regularization': 0.002,
        'l2_regularization': 0.002,
        'dropout': 0.4,
        'experiment_id': 'L1_H26_DO0.4_L10.002_L20.002_B32_LR0.001',
        'save_checkpoints_steps': 5000,
        'validation_interval': 10,
        'initialize_with_checkpoint': '',
        'save_summary_steps': 10,
        'keep_checkpoint_max': 5,
        'throttle': 50,
        'type': 'classification',
        'ground_truth_column': '-1',
        'num_classes': '2',
        'weight': '1',
        'num_layers': '1',
        'layer_size': '26',
        'hidden_layers': [[32, 16, 16], [16, 8, 4]],
        'batch_norm': 'True',
        'residual': 'False',
        'training_file': 'data/iris.csv',
        'validation_file': 'data/iris.csv',
        'checkpoint_dir': 'checkpoints/enigma',
        'log_folder': 'log/enigma_Diag',
        'model_name': 'DNNClassifier',
    }
    assert config.all() == expected
Example #2
0
def test_update(config: config_reader.CustomConfigParser):
    """Verify that the settings returned by ``config.all()`` can be updated.

    Overrides ``num_epochs`` on the returned mapping via ``update`` and checks
    that the new value is visible on that same object.

    NOTE(review): only the returned object is inspected; whether the change
    propagates back into ``config`` is not tested here.
    """
    settings = config.all()
    new_epochs = 4000
    settings.update({'num_epochs': new_epochs})
    assert settings['num_epochs'] == new_epochs