def retrain(args):
    """Retrain a generation's network and save it under a new prefix.

    Reads ``game``, the current generation prefix and the next generation
    prefix from ``sys.argv[1..3]``. Builds the game's train config, a
    transformer and a TrainManager, gathers the training data, runs the
    epochs and saves the result.

    NOTE(review): ``args`` is accepted but never read; the values come
    from ``sys.argv`` directly.
    """
    argv = sys.argv
    game = argv[1]
    gen_prefix = argv[2]
    gen_prefix_next = argv[3]

    config_factory = getattr(Configs(), game)
    train_config = config_factory(gen_prefix)

    generation_descr = templates.default_generation_desc(
        train_config.game,
        multiple_policy_heads=True,
        num_previous_states=0,
    )

    # The transformer is obtained from the shared manager.
    transformer = get_manager().get_transformer(train_config.game, generation_descr)

    trainer = train.TrainManager(train_config, transformer)
    trainer.update_config(train_config, next_generation_prefix=gen_prefix_next)

    trainer.get_network(get_nn_model(train_config.game, transformer), generation_descr)

    training_data = trainer.gather_data()
    trainer.do_epochs(training_data)
    trainer.save()
def do_training(game, gen_prefix, next_step, starting_step, num_previous_states,
                gen_prefix_next, do_data_augmentation=False):
    """Train the network for ``game`` and save it as the next generation.

    :param game: game name used to build the transformer and train config.
    :param gen_prefix: prefix of the generation being trained from.
    :param next_step: forwarded to ``get_train_config``.
    :param starting_step: forwarded to ``get_train_config``.
    :param num_previous_states: number of previous states in the generation
        description.
    :param gen_prefix_next: prefix the trained generation is saved under.
    :param do_data_augmentation: forwarded to ``train.TrainManager``.
    """
    manager = get_manager()

    descr = templates.default_generation_desc(
        game,
        multiple_policy_heads=True,
        num_previous_states=num_previous_states,
    )
    transformer = manager.get_transformer(game, descr)

    conf = get_train_config(game, gen_prefix, next_step, starting_step)

    trainer = train.TrainManager(conf, transformer,
                                 do_data_augmentation=do_data_augmentation)
    trainer.update_config(conf, next_generation_prefix=gen_prefix_next)

    # Build the network model and attach it to the trainer.
    model_conf = get_nn_model(conf.game, transformer)
    trainer.get_network(model_conf, descr)

    # NOTE(review): unlike retrain(), no data is gathered/passed here;
    # confirm do_epochs() gathers its own data when called with no args.
    trainer.do_epochs()
    trainer.save()
def speed_test():
    """Benchmark raw Keras ``predict`` throughput of a "small" network.

    Builds a "small" network for the game in ``config()``, gathers the
    training data, warms the model up on (at most) the first two batches,
    then runs ``ITERATIONS`` timed passes over every sample in batches of
    ``batch_size`` and prints per-batch timings and predictions/second.
    Returns early (printing a message) when there is no data to run on.
    """
    ITERATIONS = 3
    batch_size = 4096

    man = get_manager()

    # get data
    train_config = config()

    # get nn to test speed on
    transformer = man.get_transformer(train_config.game)
    trainer = train.TrainManager(train_config, transformer)

    nn_model_config = templates.nn_model_config_template(train_config.game, "small", transformer)
    generation_descr = templates.default_generation_desc(train_config.game)
    trainer.get_network(nn_model_config, generation_descr)

    data = trainer.gather_data()
    sample_count = len(data.inputs)
    if sample_count == 0:
        # Nothing to predict on: avoids res[0] IndexError and a
        # division by zero in the throughput figure below.
        print("no samples to benchmark")
        return

    keras_model = trainer.nn.get_model()

    # Ceiling division: a trailing partial batch is included, but an empty
    # batch is never predicted (which happened whenever sample_count was a
    # multiple of batch_size). `//` keeps it integral on Python 3 too.
    num_batches = (sample_count + batch_size - 1) // batch_size

    # warm up (first two batches, or fewer if there is less data)
    res = []
    for i in range(min(2, num_batches)):
        idx, end_idx = i * batch_size, (i + 1) * batch_size
        print("%s %s %s" % (i, idx, end_idx))
        inputs = np.array(data.inputs[idx:end_idx])
        res.append(keras_model.predict(inputs, batch_size=batch_size))
    print(res[0])

    for _ in range(ITERATIONS):
        times = []
        # Collect garbage up front so a GC pause is less likely to land
        # inside a timed predict call.
        gc.collect()

        print('Starting speed run')
        print("batches %s, batch_size %s, inputs: %s" % (num_batches, batch_size, len(data.inputs)))
        for i in range(num_batches):
            idx, end_idx = i * batch_size, (i + 1) * batch_size
            inputs = np.array(data.inputs[idx:end_idx])
            print("inputs %s" % len(inputs))

            s = time.time()
            Y = keras_model.predict(inputs, batch_size=batch_size)
            times.append(time.time() - s)

            print("outputs %s" % len(Y[0]))

        total = sum(times)
        print("times taken %s" % (times,))
        print("total_time taken %s" % total)
        print("predictions per second %s" % (sample_count / float(total)))
def go():
    """Build a "small" network for ``config()``'s game and warm up a Runner.

    Gathers the training data once and hands it, together with the trainer's
    Keras model, to ``Runner``, then calls ``Runner.warmup()``.
    """
    man = get_manager()

    # get data
    train_config = config()

    # get nn to test speed on
    transformer = man.get_transformer(train_config.game)
    trainer = train.TrainManager(train_config, transformer)

    nn_model_config = templates.nn_model_config_template(train_config.game, "small", transformer)
    generation_descr = templates.default_generation_desc(train_config.game)
    trainer.get_network(nn_model_config, generation_descr)

    # Gather once and reuse; the data was previously gathered a second time
    # for the Runner and the first result was thrown away.
    data = trainer.gather_data()

    r = Runner(data, trainer.nn.get_model())
    r.warmup()
def test_trainer_update_config():
    """Smoke-test a reversi training run with a "tiny" network.

    Builds a transformer with two previous states and multiple policy heads,
    creates a TrainManager with ``next_generation_prefix="x2test"``, gathers
    the data (printed for inspection), trains and saves.
    """
    # we need the data for this test
    reversi_conf = get_conf_reversi()

    descr = templates.default_generation_desc(reversi_conf.game)
    descr.num_previous_states = 2
    descr.multiple_policy_heads = True

    transformer = get_manager().get_transformer(reversi_conf.game, descr)

    trainer = train.TrainManager(reversi_conf, transformer,
                                 next_generation_prefix="x2test")

    tiny_model = templates.nn_model_config_template(reversi_conf.game, "tiny", transformer)
    trainer.get_network(tiny_model, descr)

    gathered = trainer.gather_data()
    print(gathered)

    trainer.do_epochs(gathered)
    trainer.save()