Ejemplo n.º 1
0
def train(player, receipt):
    """Train the 'conceded' model end-to-end, then store the train receipt.

    NOTE(review): `history_file` is read from enclosing/module scope — it is
    not defined in this function; confirm it exists at module level.

    :param player: not used in this variant; kept for caller compatibility.
    :param receipt: receipt id forwarded to training and the receipt store.
    """
    # log start of run, consistent with the 'goals' train() variant
    logger.info('started train')

    learning_cfg = get_learning_cfg("conceded")
    history = train_history_utils.init_history('in progress', learning_cfg)

    training_utils.train(data_range=training_utils.create_data_range(
        learning_cfg=learning_cfg, history_file=history_file),
                         label='conceded',
                         label_values=match_dataset.CONCEDED,
                         model_dir="conceded",
                         train_path=training_utils.create_train_path(),
                         receipt=receipt,
                         history=history,
                         history_file=history_file)

    # acknowledge completion
    receipt_utils.put_receipt(receipt_utils.TRAIN_RECEIPT_URL, receipt, None)
def train(player, receipt):
    """Run a full training cycle for the 'goals' model, then acknowledge
    completion by storing the training receipt.

    NOTE(review): `history_file` comes from enclosing/module scope — verify
    it is defined at module level.
    """
    logger.info('started train')

    cfg = get_learning_cfg("goals")
    progress = train_history_utils.init_history('in progress', cfg)

    # resolve the date windows to train over before kicking off training
    windows = training_utils.create_data_range(learning_cfg=cfg,
                                               history_file=history_file)

    training_utils.train(data_range=windows,
                         label='goals',
                         label_values=match_dataset.SCORE,
                         model_dir="goals",
                         train_path=training_utils.create_train_path(),
                         receipt=receipt,
                         history=progress,
                         history_file=history_file)

    receipt_utils.put_receipt(receipt_utils.TRAIN_RECEIPT_URL, receipt, None)
Ejemplo n.º 3
0
def create(train, label, label_values, model_dir, train_filename,
           test_filename, init):
    """Build (and optionally train and evaluate) a DNN classifier.

    :param train: when True, load data and run training (plus optional eval).
    :param label: name of the label column in the CSV data.
    :param label_values: ordered class values for the label; the class count
        given to the classifier is ``len(label_values)``.
    :param model_dir: model name; selects learning config and storage paths.
    :param train_filename: training CSV path, relative to ``local_dir``.
    :param test_filename: evaluation CSV path, or None to skip evaluation.
    :param init: when True, tidy up model artifacts after training.
    :return: the created classifier.
    """
    aws_model_dir = 'models/' + model_dir
    tf_models_dir = local_dir + '/' + aws_model_dir

    learning_cfg = get_learning_cfg(model_dir)
    logger.info(learning_cfg)

    logger.info('team vocab started...')
    team_file = vocab_utils.create_vocab(url=vocab_utils.ALL_TEAMS_URL,
                                         filename=vocab_utils.TEAMS_FILE,
                                         player='default')
    logger.info('team vocab completed')

    logger.info('player vocab started...')
    player_file = vocab_utils.create_vocab(url=vocab_utils.PLAYERS_URL,
                                           filename=vocab_utils.PLAYERS_FILE,
                                           player='default')
    # fixed: log message had a stray leading '['
    logger.info('player vocab completed')

    # TODO: read the other numerics direct from mongo; review checkpoints so
    # we only train with the newest data, or rebuild from scratch.

    feature_columns = match_featureset.create_feature_columns(
        team_vocab=team_file, player_vocab=player_file)

    # Build the DNN classifier (layer sizes come from learning_cfg).
    classifier = classifier_utils.create(feature_columns=feature_columns,
                                         classes=len(label_values),
                                         model_dir=aws_model_dir,
                                         learning_cfg=learning_cfg,
                                         init=init)

    if train:
        logger.info(label_values)

        # hoisted: the same condition guarded both data loading and eval
        evaluate = learning_cfg['evaluate'] and test_filename is not None

        if evaluate:
            (train_x, train_y), (test_x, test_y) = match_dataset.load_data(
                train_path=local_dir + train_filename,
                test_path=local_dir + test_filename,
                y_name=label,
                convert=label_values)
        else:
            (train_x, train_y) = match_dataset.load_train_data(
                train_path=local_dir + train_filename,
                y_name=label,
                convert=label_values)

        # Train the model.
        classifier.train(input_fn=lambda: dataset_utils.train_input_fn(
            train_x, train_y, learning_cfg['batch_size']),
                         steps=learning_cfg['steps'])

        if evaluate:
            # Evaluate against the held-out window.
            eval_result = classifier.evaluate(
                input_fn=lambda: dataset_utils.eval_input_fn(
                    test_x, test_y, learning_cfg['batch_size']))

            logger.info(
                '\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))

            if learning_cfg['aws_debug']:
                # sanity-check predictions against a canned local sample
                with open(local_dir + 'sample.json') as f:
                    sample = json.load(f)

                predict(classifier=classifier,
                        predict_x=sample,
                        label_values=label_values)

        if init:
            logger.info('tidying up')
            tidy_up(tf_models_dir=tf_models_dir,
                    aws_model_dir=aws_model_dir,
                    team_file=team_file,
                    train_filename=train_filename)

            # allow background cleanup/uploads to settle before returning
            time.sleep(30)

    return classifier
Ejemplo n.º 4
0
def train(data_range, label, label_values, model_dir, train_path, receipt,
          history, history_file):
    """Train the model over each window in ``data_range``, recording partial
    history after every window and storing the receipt when all are done.

    :param data_range: iterable of date-range strings (slashes become dashes
        when building filenames).
    :param label: label column name passed through to ``match_model.create``.
    :param label_values: class values for the label.
    :param model_dir: model name; selects learning config and model storage.
    :param train_path: relative path under which training CSVs are written.
    :param receipt: receipt id acknowledged at the end, or None to skip.
    :param history: history record; marked 'Success - Full' at the end.
    :param history_file: history file appended to after each window.
    """
    for data in data_range:

        learning_cfg = get_learning_cfg(model_dir)

        train_filename = "train-players" + data.replace('/', '-') + ".csv"
        evaluate_filename = "train-players" + get_next_in_range(
            data_range, data).replace('/', '-') + ".csv"
        train_file_path = local_dir + train_path + train_filename
        evaluate_file_path = local_dir + train_path + evaluate_filename

        has_data = model_utils.create_csv(url=model_utils.EVENT_MODEL_URL,
                                          filename=train_file_path,
                                          range=data,
                                          aws_path=train_path)

        if learning_cfg['evaluate']:

            # the next window in the range serves as the evaluation set
            has_test_data = model_utils.create_csv(
                url=model_utils.EVENT_MODEL_URL,
                filename=evaluate_file_path,
                range=get_next_in_range(data_range, data),
                aws_path=train_path)

            # idiomatic boolean test (was `== True` / `== False`)
            if has_data and not has_test_data:
                # no evaluation window available; train without a test file
                evaluate_filename = None
            else:
                logger.info('we can evaluate')

        if has_data:

            train_filename = train_path + train_filename
            if evaluate_filename is not None:
                evaluate_filename = train_path + evaluate_filename

            match_model.create(train=True,
                               label=label,
                               label_values=label_values,
                               model_dir=model_dir,
                               train_filename=train_filename,
                               test_filename=evaluate_filename,
                               init=True)
        else:
            logger.info('no data to train')

        # record partial progress for this window
        start_day, start_month, start_year, end_day, end_month, end_year = get_range_details(
            data)
        history = train_history_utils.create_history('Success - Partial',
                                                     start_day, start_month,
                                                     start_year, end_day,
                                                     end_month, end_year)
        train_history_utils.add_history(history_file, 'default', history)

    if receipt is not None:
        receipt_utils.put_receipt(receipt_utils.TRAIN_RECEIPT_URL, receipt,
                                  None)

    # mark the whole run as fully successful
    history['status'] = "Success - Full"
    train_history_utils.add_history(history_file, 'default', history)
import util.config_utils as config_utils

# Print the match-result learning config for each country, in order.
for country in ('england', 'greece'):
    print(config_utils.get_learning_cfg(country, 'match_result'))