コード例 #1
0
def get_AMLRun():
    """Return the run object from ``Run.get_submitted_run()``, or ``None``.

    Any exception raised while obtaining the run is printed and swallowed,
    so callers must check the result for ``None`` before using it.
    """
    try:
        return Run.get_submitted_run()
    except Exception as e:
        # Python 3 exceptions have no ``.message`` attribute; reading it here
        # raised AttributeError and hid the original error. Format ``e`` itself.
        print("Caught = {}".format(e))
        return None
コード例 #2
0
ファイル: train_cnn.py プロジェクト: kjaanson/kaggle_rec_bio
    os.makedirs("./outputs", exist_ok=True)

    test_data = pd.read_csv(f"{data_path}/test.csv")
    print("Shape of test_data:", test_data.shape)
    test_data.head()

    train_data = pd.read_csv(f"{data_path}/train.csv")
    print("Shape of train_data:", train_data.shape)
    train_data.head()

    sirna_label_encoder = LabelEncoder().fit(train_data.sirna)

    joblib.dump(sirna_label_encoder, "./outputs/sirna_label_encoder.joblib")

    run = Run.get_submitted_run()

    model = models.create_cnn_model()

    test_size = 0.025
    batch_size = args.batch

    run.log("Batch Size", batch_size)
    run.log("Test fraction", test_size)
    run.log("Training samples", len(train_data))
    run.log("Learning rate", learning_rate)

    aml_callback = CheckpointCallback(run)

    # resampling entire training dataset
    train_data = train_data.sample(frac=training_fraction).reset_index(
コード例 #3
0
def main(argv=None):
    """Fetch the submitted run and hand it to ``train_evaluate``.

    ``argv`` is accepted for call compatibility but is not used here.
    """
    train_evaluate(Run.get_submitted_run())