Пример #1
0
def retrain(args):
    """Retrain a network for one game, driven by command line values.

    Three positional values are read from ``sys.argv``:

    1. the game name, which is also the name of the ``Configs`` attribute
       called to build the train config,
    2. the current generation prefix, passed to that config factory,
    3. the prefix of the next generation, passed to ``update_config``.

    NOTE: the ``args`` parameter is accepted but never read; all inputs come
    from ``sys.argv`` directly.
    """
    game_name = sys.argv[1]
    current_prefix = sys.argv[2]
    next_prefix = sys.argv[3]

    # resolve the per-game config factory by name, then build the config
    config_factory = getattr(Configs(), game_name)
    train_config = config_factory(current_prefix)

    descr_options = dict(multiple_policy_heads=True, num_previous_states=0)
    generation_descr = templates.default_generation_desc(train_config.game,
                                                         **descr_options)

    # the transformer comes from the manager, keyed by game and description
    transformer = get_manager().get_transformer(train_config.game,
                                                generation_descr)

    # trainer is created from the config, then pointed at the next generation
    trainer = train.TrainManager(train_config, transformer)
    trainer.update_config(train_config, next_generation_prefix=next_prefix)

    model_config = get_nn_model(train_config.game, transformer)
    trainer.get_network(model_config, generation_descr)

    samples = trainer.gather_data()

    trainer.do_epochs(samples)
    trainer.save()
Пример #2
0
def do_training(game,
                gen_prefix,
                next_step,
                starting_step,
                num_previous_states,
                gen_prefix_next,
                do_data_augmentation=False):
    """Train and save a network for ``game``.

    ``game``, ``gen_prefix``, ``next_step`` and ``starting_step`` are handed
    to ``get_train_config``.  ``num_previous_states`` goes into the
    generation description (built with multiple policy heads enabled),
    ``gen_prefix_next`` is passed to ``update_config`` as the next generation
    prefix, and ``do_data_augmentation`` is forwarded to the ``TrainManager``.

    NOTE: there is no explicit data gathering step here; ``do_epochs`` is
    called without arguments.
    """
    manager = get_manager()

    # the generation description drives which transformer the manager returns
    descr = templates.default_generation_desc(game,
                                              multiple_policy_heads=True,
                                              num_previous_states=num_previous_states)
    transformer = manager.get_transformer(game, descr)

    # build the train config and a trainer targeting the next generation
    conf = get_train_config(game, gen_prefix, next_step, starting_step)
    trainer = train.TrainManager(conf, transformer,
                                 do_data_augmentation=do_data_augmentation)
    trainer.update_config(conf, next_generation_prefix=gen_prefix_next)

    # attach the network model to the trainer
    model_config = get_nn_model(conf.game, transformer)
    trainer.get_network(model_config, descr)

    trainer.do_epochs()
    trainer.save()
Пример #3
0
def speed_test():
    """Time ``predict`` of a "small" network over the gathered training data.

    Builds a trainer for the game named in ``config()``, gathers its data and
    runs the keras model over it in batches of 4096 samples.  After a short
    warm up, the full pass is repeated ``ITERATIONS`` times, printing the time
    taken per batch, the total time and the predictions per second.

    Only non-empty batches are scheduled: the batch count is a ceiling
    division, so a sample count that is an exact multiple of the batch size
    does not produce a trailing empty batch, and the warm up never runs more
    batches than exist.
    """
    ITERATIONS = 3

    man = get_manager()

    # get data
    train_config = config()

    # get nn to test speed on
    transformer = man.get_transformer(train_config.game)
    trainer = train.TrainManager(train_config, transformer)

    nn_model_config = templates.nn_model_config_template(train_config.game, "small", transformer)
    generation_descr = templates.default_generation_desc(train_config.game)
    trainer.get_network(nn_model_config, generation_descr)

    data = trainer.gather_data()

    batch_size = 4096
    sample_count = len(data.inputs)
    keras_model = trainer.nn.get_model()

    if sample_count == 0:
        # nothing to predict on; avoids empty batches and a zero division
        print("no samples gathered, nothing to time")
        return

    # ceiling division; '//' keeps this an int under true division as well
    num_batches = (sample_count + batch_size - 1) // batch_size

    # warm up (at most two batches, never more than are available)
    res = []
    for i in range(min(2, num_batches)):
        idx, end_idx = i * batch_size, (i + 1) * batch_size
        print("%s %s %s" % (i, idx, end_idx))
        inputs = np.array(data.inputs[idx:end_idx])
        res.append(keras_model.predict(inputs, batch_size=batch_size))
        print(res[0])

    for _ in range(ITERATIONS):
        times = []
        # collect garbage up front, before the timed section starts
        gc.collect()

        print('Starting speed run')
        print("batches %s, batch_size %s, inputs: %s" % (num_batches,
                                                         batch_size,
                                                         len(data.inputs)))
        for i in range(num_batches):
            idx, end_idx = i * batch_size, (i + 1) * batch_size
            inputs = np.array(data.inputs[idx:end_idx])
            print("inputs %s" % len(inputs))

            # only the predict call itself is timed
            s = time.time()
            Y = keras_model.predict(inputs, batch_size=batch_size)
            times.append(time.time() - s)
            print("outputs %s" % len(Y[0]))

        total_time = sum(times)
        print("times taken %s" % (times,))
        print("total_time taken %s" % total_time)
        if total_time > 0:
            print("predictions per second %s" % (sample_count / float(total_time)))
Пример #4
0
def go():
    """Build a "small" network for the configured game and warm up a Runner.

    The trainer's data is gathered once and handed, together with the keras
    model, to ``Runner``; only ``warmup()`` is invoked on it.
    """
    man = get_manager()

    # get data
    train_config = config()

    # get nn to test speed on
    transformer = man.get_transformer(train_config.game)
    trainer = train.TrainManager(train_config, transformer)

    nn_model_config = templates.nn_model_config_template(train_config.game, "small", transformer)
    generation_descr = templates.default_generation_desc(train_config.game)
    trainer.get_network(nn_model_config, generation_descr)

    # gather once and reuse: previously the data was gathered a second time
    # for the Runner while the first result was left unused
    data = trainer.gather_data()

    r = Runner(data, trainer.nn.get_model())
    r.warmup()
Пример #5
0
def test_trainer_update_config():
    """Run a full gather / train / save cycle with a "tiny" network.

    Uses the config from ``get_conf_reversi()`` with a generation description
    set to two previous states and multiple policy heads, and a trainer
    created with the next generation prefix "x2test".
    """
    # we need the data for this test
    conf = get_conf_reversi()

    # transformer for a customised generation description
    manager = get_manager()
    descr = templates.default_generation_desc(conf.game)
    for attr_name, attr_value in (("num_previous_states", 2),
                                  ("multiple_policy_heads", True)):
        setattr(descr, attr_name, attr_value)
    transformer = manager.get_transformer(conf.game, descr)

    # trainer writing to the "x2test" generation
    trainer = train.TrainManager(conf, transformer,
                                 next_generation_prefix="x2test")

    model_config = templates.nn_model_config_template(conf.game, "tiny", transformer)
    trainer.get_network(model_config, descr)

    samples = trainer.gather_data()

    print(samples)

    trainer.do_epochs(samples)
    trainer.save()