예제 #1
0
def main():
    """Run p1b1_runner twice against this script's directory.

    After the first run the runner module must expose non-None ``X_train``
    and ``X_test`` attributes; the run is then repeated with the same
    arguments.
    """
    script_path = os.path.realpath(__file__)
    data_directory = os.path.dirname(script_path)

    # First run: expected to load the data.
    p1b1_runner.run(data_directory, "2")

    # The loaded data should now be available on the runner module.
    for attribute in ("X_train", "X_test"):
        assert getattr(p1b1_runner, attribute) is not None

    # Second run with identical arguments.
    # NOTE(review): presumably exercises the already-loaded path -- confirm.
    p1b1_runner.run(data_directory, "2")
예제 #2
0
def main():
    """Run p1b1_runner twice with a two-epoch Keras configuration.

    The first run passes "val_corr" as the second argument and includes a
    'timeout' entry; the second passes "val_loss" and has no 'timeout'.
    """
    with_timeout = {
        'epochs': 2,
        'framework': 'keras',
        'model_name': 'p1b1',
        'timeout': 3600,
        'save': './p1b1_output',
    }
    p1b1_runner.run(with_timeout, "val_corr")

    # Same settings as above, minus the 'timeout' key.
    without_timeout = {
        'epochs': 2,
        'framework': 'keras',
        'model_name': 'p1b1',
        'save': './p1b1_output',
    }
    p1b1_runner.run(without_timeout, "val_loss")
예제 #3
0
def main():
    """Run p1b1_runner once with a fixed configuration and print the result."""
    config = {
        'epochs': 1,
        'batch_size': 40,
        'dense': [1900, 500],
        'framework': 'keras',
        # NOTE(review): 'p1bl1' may be a typo for 'p1b1' -- confirm before
        # changing, since this is the path the run is told to save to.
        'save': './p1bl1_output',
    }

    validation_loss = p1b1_runner.run(config)
    print("Validation Loss: ", validation_loss)
예제 #4
0
# Hyperparameters arrive as a single comma-separated string in
# `parameterString` (defined outside this snippet). Edit the unpacking and
# the map below to add or remove a parameter.
epochs, batch_size, d1, d2, ld, lr = parameterString.split(',')

hyper_parameter_map = {
    'epochs': int(epochs),
    'framework': 'keras',
    'batch_size': int(batch_size),
    'dense': [int(d1), int(d2)],
    'latent_dim': int(ld),
    'learning_rate': float(lr),
    'run_id': parameterString,
}
# hyper_parameter_map['instance_directory'] = os.environ['TURBINE_OUTPUT']

# Per-rank output directory: $TURBINE_OUTPUT/output-$PMI_RANK
output_dir = "{}/output-{}".format(os.environ['TURBINE_OUTPUT'],
                                   os.environ['PMI_RANK'])
hyper_parameter_map['save'] = output_dir

# NOTE(review): argv is replaced before the run, presumably so the runner
# does not see the launcher's own command-line arguments -- confirm.
sys.argv = ['p1b1_runner']
val_loss = p1b1_runner.run(hyper_parameter_map)
print(val_loss)

# Record "<fully-qualified hostname>-<pid>" in a file named after the
# parameter string, inside the per-rank output directory.
sfn = output_dir + "/procname-" + parameterString
with open(sfn, 'w') as sfile:
    proc_id = "-" + str(os.getpid())
    sfile.write(socket.getfqdn() + proc_id)

# works around this error:
# https://github.com/tensorflow/tensorflow/issues/3388
from keras import backend as K

K.clear_session()