Example 1
def play_with_policy(policy,
                     policy2,
                     env,
                     seed,
                     nsteps=5,
                     nstack=4,
                     total_timesteps=int(80e6),
                     vf_coef=0.5,
                     ent_coef=0.01,
                     max_grad_norm=0.5,
                     lr=7e-4,
                     lrschedule='linear',
                     epsilon=1e-5,
                     alpha=0.99,
                     gamma=0.99,
                     log_interval=20,
                     model_path='',
                     model_path2=''):
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    num_procs = len(env.remotes)  # HACK
    statistics_path = './stadistics_random'

    runner = Runner(env, None, None, nsteps=nsteps, nstack=nstack, gamma=gamma)
    runner.mcts.self_play_with_simple_policy()
    env.close()
Example 2
def learn(policy, env, seed, nsteps, nstack=4, total_timesteps=int(80e6), vf_coef=0.5, ent_coef=0.01,
          max_grad_norm=0.5, lr=7e-4, lrschedule='linear', epsilon=1e-5, alpha=0.99, gamma=0.99, log_interval=1000,
          load_model=False, model_path='', data_augmentation=True, BATCH_SIZE=100, NUMBER_OF_MODELS=4):
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    num_procs = len(env.remotes)  # HACK
    now = datetime.datetime.now()

    CHANGE_PLAYER = 4000
    NUMBER_TEST = 1000
    TEMP_CTE = 10000
    counter_stadistics = 0
    temp = np.ones(1)

    parameters = now.strftime("%d-%m-%Y_%H-%M-%S") + "_seed_" + str(
        seed) + "_BATCH_" + str(BATCH_SIZE) + "_TEMP_" + str(TEMP_CTE) + "_DA_" + str(data_augmentation) + str(np.sqrt(nsteps))+ 'x'+ str(np.sqrt(nsteps)) +'_num_players_' +str(NUMBER_OF_MODELS) + str(policy)
    statistics_path = ('../statistics/AI_vs_AI/' + parameters )

    models_path= statistics_path + '/model/'
    statistics_csv = statistics_path + "/csv/"


    summary_writer = tf.summary.FileWriter(statistics_path)

    models = create_models(NUMBER_OF_MODELS, policy, ob_space, ac_space, nenvs, nsteps, nstack, num_procs, ent_coef, vf_coef,
                           max_grad_norm, lr, alpha, epsilon, total_timesteps, lrschedule)

    BATCH_SIZE = np.sqrt(nsteps) * BATCH_SIZE


    if load_model:
        # NOTE: the load calls below are commented out, so no weights are actually restored here.
        # model_A.load('./models/model_A.cpkt')
        # model_B.load('./models/model_B.cpkt')
        print('Model loaded')

    runner, model, model_2 = change_player_keep_one(models, env)
    print('Loaded players', model.scope, 'A', model_2.scope, 'B')

    nbatch = nenvs * nsteps
    tstart = time.time()
    try:
        os.stat(statistics_path)
    except OSError:
        os.mkdir(statistics_path)
    try:
        os.stat(models_path)
    except OSError:
        os.mkdir(models_path)
    for update in range(0, total_timesteps // nbatch + 1):


        # if update % TEMP_COUNTER == 0:
        #     temp_count += 1
        #     temp = temp * (TEMP_CTE / (temp_count + TEMP_CTE)) + 0.2
        #     print('temp:', temp)

        if update % CHANGE_PLAYER == 0 and update != 0:
            env.print_stadistics_vs()
            print_tensorboard_training_score(summary_writer, update, env)
            temp = (0.9 * np.exp(-(update / TEMP_CTE)) + 0.1) * np.ones(1)
            print('Testing players, update:', update)
            runner.test(temp, model, NUMBER_TEST, summary_writer, env, update)
            runner.test(temp, model_2, NUMBER_TEST, summary_writer, env, update)

            runner, model, model_2 = change_player_keep_one(models, env)
            print('Change players, new players', model.scope, 'A', model_2.scope, 'B')


        else:
            train(temp, runner, model, model_2, data_augmentation, BATCH_SIZE, env, summary_writer, update,
                  counter_stadistics, tstart, nsteps=nsteps)


        if (update % (log_interval * 10)) == 0 and update != 0:
            print('Saving checkpoint')
            for mod in models:
                mod.save(models_path + parameters + '_' + str(mod.scope))
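The temperature expression above, temp = (0.9 * np.exp(-(update / TEMP_CTE)) + 0.1), anneals the sampling temperature from roughly 1.0 toward 0.1 as training progresses (Example 3 below uses the same shape with 0.8/0.2 and a different TEMP_CTE). A minimal standalone sketch, not part of the original code, that simply evaluates the formula at a few update counts:

import numpy as np

TEMP_CTE = 10000  # value used in this example; Example 3 defaults to 30000

for update in (0, 4000, 10000, 30000, 100000):
    # Same expression as in the training loop above.
    temp = (0.9 * np.exp(-(update / TEMP_CTE)) + 0.1) * np.ones(1)
    print(update, temp.item())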
Example 3
def learn(policy,
          env,
          seed,
          nsteps,
          nstack=4,
          total_timesteps=int(80e6),
          vf_coef=0.5,
          ent_coef=0.01,
          max_grad_norm=0.5,
          lr=7e-4,
          lrschedule='linear',
          epsilon=1e-5,
          alpha=0.99,
          gamma=0.99,
          log_interval=1000,
          load_model=False,
          model_path='',
          data_augmentation=True,
          BATCH_SIZE=10,
          TEMP_CTE=30000,
          RUN_TEST=5000):

    tf.reset_default_graph()
    set_global_seeds(seed)
    print('Data augmentation', data_augmentation)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    num_procs = len(env.remotes)  # HACK
    now = datetime.datetime.now()
    temp = np.ones(1)
    counter_stadistics = 0
    parameters = now.strftime("%d-%m-%Y_%H-%M-%S") + "_seed_" + str(
        seed) + "_BATCH_" + str(BATCH_SIZE) + "_TEMP_" + str(
            TEMP_CTE) + "_DA_" + str(data_augmentation) + "_VF_" + str(
                vf_coef) + str(policy) + str(np.sqrt(nsteps)) + 'x' + str(
                    np.sqrt(nsteps)) + 'random'
    statistics_path = "../statistics/random/"
    BATCH_SIZE = np.sqrt(nsteps) * BATCH_SIZE
    model_path_load = model_path
    try:
        os.stat("../statistics/")
    except OSError:
        os.mkdir("../statistics/")
    try:
        os.stat(statistics_path)
    except OSError:
        os.mkdir(statistics_path)
    statistics_path = "../statistics/random/" + parameters
    model_path = statistics_path + "/model/"
    statistics_csv = statistics_path + "/csv/"

    games_wonAI_test_saver, games_finish_in_draw_test_saver, games_wonRandom_test_saver, illegal_test_games_test_saver = [], [], [], []
    games_wonAI_train_saver, games_finish_in_draw_train_saver, games_wonRandom_train_saver, illegal_test_games_train_saver = [], [], [], []
    update_test, update_train = [], []
    policy_entropy_saver, policy_loss_saver, explained_variance_saver, value_loss_saver, ev_saver = [], [], [], [], []

    try:
        os.stat(statistics_path)
    except OSError:
        os.mkdir(statistics_path)

    try:
        os.stat(statistics_csv)
    except OSError:
        os.mkdir(statistics_csv)

    try:
        os.stat(model_path)
    except OSError:
        os.mkdir(model_path)

    summary_writer = tf.summary.FileWriter(statistics_path)
    temp = np.ones(1)

    model = Model(policy=policy,
                  ob_space=ob_space,
                  ac_space=ac_space,
                  nenvs=nenvs,
                  nsteps=np.sqrt(nsteps),
                  nstack=nstack,
                  num_procs=num_procs,
                  ent_coef=ent_coef,
                  vf_coef=vf_coef,
                  max_grad_norm=max_grad_norm,
                  lr=lr,
                  alpha=alpha,
                  epsilon=epsilon,
                  total_timesteps=total_timesteps,
                  lrschedule=lrschedule,
                  summary_writter=summary_writer)

    if load_model:
        model.load(model_path_load)
    runner = Runner(env, model, nsteps=nsteps, nstack=nstack, gamma=gamma)

    nbatch = nenvs * nsteps
    tstart = time.time()

    for update in range(0, total_timesteps // nbatch + 1):
        if update % 1000 == 0:
            print('update: ', update)
            import threading
            env.print_stadistics(threading.get_ident())
        if (update % RUN_TEST < 1000) and (update % RUN_TEST > 0) and (update != 0):
            runner.test(np.ones(1))
            temp = (0.8 * np.exp(-(update / TEMP_CTE)) + 0.2) * np.ones(1)

            if (update % RUN_TEST) == 999:
                games_wonAI, games_wonRandom, games_finish_in_draw, illegal_games = env.get_stadistics()
                summary = tf.Summary()
                summary.value.add(tag='test/games_wonAI',
                                  simple_value=float(games_wonAI))
                summary.value.add(tag='test/games_wonRandom',
                                  simple_value=float(games_wonRandom))
                summary.value.add(tag='test/games_finish_in_draw',
                                  simple_value=float(games_finish_in_draw))
                summary.value.add(tag='test/illegal_games',
                                  simple_value=float(illegal_games))
                summary_writer.add_summary(summary, update)

                games_wonAI_test_saver.append(games_wonAI)
                games_wonRandom_test_saver.append(games_wonRandom)
                games_finish_in_draw_test_saver.append(games_finish_in_draw)
                illegal_test_games_test_saver.append(illegal_games)
                update_test.append(update)

                save_csv(statistics_csv + 'games_wonAI_test.csv',
                         games_wonAI_test_saver)
                save_csv(statistics_csv + 'games_wonRandom_test.csv',
                         games_wonRandom_test_saver)
                save_csv(statistics_csv + 'games_finish_in_draw_test.csv',
                         games_finish_in_draw_test_saver)
                save_csv(statistics_csv + 'illegal_games_test.csv',
                         illegal_test_games_test_saver)
                save_csv(statistics_csv + 'update_test.csv', update_test)

                summary_writer.flush()
        else:
            obs, states, rewards, masks, actions, values = runner.run(temp)
            # print('obs',obs,'actions',actions)
            # print('values',values,'rewards',rewards,)

            obs, states, rewards, masks, actions, values = redimension_results(
                obs, states, rewards, masks, actions, values, env, nsteps)

            size_batch = runner.put_in_batch(obs, states, rewards, masks,
                                             actions, values)
            if size_batch >= BATCH_SIZE:
                # print('Training batch')
                batch = runner.get_batch()
                policy_loss_sv, value_loss_sv, policy_entropy_sv = [], [], []

                for i in range(len(batch)):
                    obs, states, rewards, masks, actions, values = batch.get(i)
                    if data_augmentation:
                        pl, vl, pe = train_data_augmentation(
                            obs, states, rewards, masks, actions, values,
                            model, temp)
                        policy_loss_sv.append(pl)
                        value_loss_sv.append(vl)
                        policy_entropy_sv.append(pe)
                    else:

                        pl, vl, pe = train_without_data_augmentation(
                            obs, states, rewards, masks, actions, values,
                            model, temp)
                        policy_loss_sv.append(pl)
                        value_loss_sv.append(vl)
                        policy_entropy_sv.append(pe)

                runner.empty_batch()
                policy_loss = np.mean(policy_loss_sv)
                value_loss = np.mean(value_loss_sv)
                policy_entropy = np.mean(policy_entropy_sv)
                # print('batch trained')
                nseconds = time.time() - tstart
                fps = int((update * nbatch) / nseconds)
                ev = explained_variance(values, rewards)

                counter_stadistics += 1
                if counter_stadistics % 10 == 0:
                    counter_stadistics = 0
                    logger.record_tabular("nupdates", update)
                    logger.record_tabular("total_timesteps", update * nbatch)
                    logger.record_tabular("fps", fps)
                    logger.record_tabular("policy_entropy",
                                          float(policy_entropy))
                    logger.record_tabular("policy_loss", float(policy_loss))
                    logger.record_tabular("value_loss", float(value_loss))
                    logger.record_tabular("explained_variance", float(ev))
                    #logger.dump_tabular()

                    games_wonAI, games_wonRandom, games_finish_in_draw, illegal_games = env.get_stadistics()

                    summary = tf.Summary()
                    summary.value.add(tag='train/policy_entropy',
                                      simple_value=float(policy_entropy))
                    summary.value.add(tag='train/policy_loss',
                                      simple_value=float(policy_loss))
                    summary.value.add(tag='train/explained_variance',
                                      simple_value=float(ev))
                    summary.value.add(tag='train/value_loss',
                                      simple_value=float(value_loss))
                    summary.value.add(tag='train/games_wonAI',
                                      simple_value=float(games_wonAI))
                    summary.value.add(tag='train/games_wonRandom',
                                      simple_value=float(games_wonRandom))
                    summary.value.add(tag='train/games_finish_in_draw',
                                      simple_value=float(games_finish_in_draw))
                    summary.value.add(tag='train/illegal_games',
                                      simple_value=float(illegal_games))
                    summary.value.add(tag='train/temp',
                                      simple_value=float(temp))
                    summary_writer.add_summary(summary, update)
                    summary_writer.flush()

                    games_wonAI_train_saver.append(games_wonAI)
                    games_wonRandom_train_saver.append(games_wonRandom)
                    games_finish_in_draw_train_saver.append(
                        games_finish_in_draw)
                    illegal_test_games_train_saver.append(illegal_games)
                    update_train.append(update)

                    save_csv(statistics_csv + 'games_wonAI_train.csv',
                             games_wonAI_train_saver)
                    save_csv(statistics_csv + 'games_wonRandom_train.csv',
                             games_wonRandom_train_saver)
                    save_csv(statistics_csv + 'games_finish_in_draw_train.csv',
                             games_finish_in_draw_train_saver)
                    save_csv(statistics_csv + 'illegal_games_train.csv',
                             illegal_test_games_train_saver)
                    save_csv(statistics_csv + 'update_train.csv', update_train)

                    policy_entropy_saver.append(policy_entropy)
                    policy_loss_saver.append(policy_loss)
                    ev_saver.append(ev)
                    value_loss_saver.append(value_loss)

                    save_csv(statistics_csv + 'policy_entropy.csv',
                             policy_entropy_saver)
                    save_csv(statistics_csv + 'policy_loss.csv',
                             policy_loss_saver)
                    save_csv(statistics_csv + 'ev.csv', ev_saver)
                    save_csv(statistics_csv + 'value_loss.csv',
                             value_loss_saver)

            if (update % (log_interval * 10)) == 0:
                print('Saving checkpoint')
                model.save(model_path + parameters + '.cpkt')

    env.close()
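save_csv is called throughout this example but its definition is not shown here. The following is a hypothetical reconstruction (an assumption, not the project's actual helper) that mirrors the single-row csv.writer pattern Example 5 uses inline:

import csv

def save_csv(path, values):
    # Overwrite the file with one comma-separated row of floats,
    # matching the writer settings used inline in Example 5.
    with open(path, 'w') as f:
        writer = csv.writer(f, lineterminator='\n', delimiter=',')
        writer.writerow(float(val) for val in values)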
Example 4
def play(policy,
         policy2,
         env,
         seed,
         nsteps=5,
         nstack=4,
         total_timesteps=int(80e6),
         vf_coef=0.5,
         ent_coef=0.01,
         max_grad_norm=0.5,
         lr=7e-4,
         lrschedule='linear',
         epsilon=1e-5,
         alpha=0.99,
         gamma=0.99,
         log_interval=20,
         model_path='',
         model_path2=''):
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    num_procs = len(env.remotes)  # HACK
    statistics_path = './stadistics_random'
    summary_writer = tf.summary.FileWriter(statistics_path)

    # model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps, nstack=nstack,
    #                num_procs=num_procs, ent_coef=ent_coef, vf_coef=vf_coef, size=env.get_board_size(),
    #                max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps,
    #                lrschedule=lrschedule)
    model = Model2(policy=policy,
                   ob_space=ob_space,
                   ac_space=ac_space,
                   nenvs=nenvs,
                   nsteps=nsteps,
                   nstack=nstack,
                   num_procs=num_procs,
                   ent_coef=ent_coef,
                   vf_coef=vf_coef,
                   max_grad_norm=max_grad_norm,
                   lr=lr,
                   alpha=alpha,
                   epsilon=epsilon,
                   total_timesteps=total_timesteps,
                   lrschedule=lrschedule,
                   summary_writter=summary_writer)

    # model2 = Model(policy=policy2, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps, nstack=nstack,
    #                num_procs=num_procs, ent_coef=ent_coef, vf_coef=vf_coef, size=env.get_board_size(),
    #                max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps,
    #                lrschedule=lrschedule)
    model2 = Model2(policy=policy2,
                    ob_space=ob_space,
                    ac_space=ac_space,
                    nenvs=nenvs,
                    nsteps=nsteps,
                    nstack=nstack,
                    num_procs=num_procs,
                    ent_coef=ent_coef,
                    vf_coef=vf_coef,
                    max_grad_norm=max_grad_norm,
                    lr=lr,
                    alpha=alpha,
                    epsilon=epsilon,
                    total_timesteps=total_timesteps,
                    lrschedule=lrschedule,
                    summary_writter=summary_writer)

    model.load(model_path)
    model2.load(model_path2)
    runner = Runner(env,
                    model,
                    model2,
                    nsteps=nsteps,
                    nstack=nstack,
                    gamma=gamma)
    runner.mcts.start_play()
    env.close()
Example 5
def learn(policy,
          policy2,
          env,
          seed,
          nsteps=5,
          nstack=4,
          total_timesteps=int(80e6),
          vf_coef=0.5,
          ent_coef=0.01,
          max_grad_norm=0.5,
          lr=7e-4,
          lrschedule='linear',
          epsilon=1e-5,
          alpha=0.99,
          gamma=0.99,
          log_interval=20,
          load_model=False,
          model_path='',
          model_path2=''):
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    num_procs = len(env.remotes)  # HACK

    model = Model(policy=policy,
                  ob_space=ob_space,
                  ac_space=ac_space,
                  nenvs=nenvs,
                  nsteps=nsteps,
                  nstack=nstack,
                  num_procs=num_procs,
                  ent_coef=ent_coef,
                  vf_coef=vf_coef,
                  size=env.get_board_size(),
                  max_grad_norm=max_grad_norm,
                  lr=lr,
                  alpha=alpha,
                  epsilon=epsilon,
                  total_timesteps=total_timesteps,
                  lrschedule=lrschedule)

    model2 = Model(policy=policy2,
                   ob_space=ob_space,
                   ac_space=ac_space,
                   nenvs=nenvs,
                   nsteps=nsteps,
                   nstack=nstack,
                   num_procs=num_procs,
                   ent_coef=ent_coef,
                   vf_coef=vf_coef,
                   size=env.get_board_size(),
                   max_grad_norm=max_grad_norm,
                   lr=lr,
                   alpha=alpha,
                   epsilon=epsilon,
                   total_timesteps=total_timesteps,
                   lrschedule=lrschedule)

    if load_model:
        model.load(model_path)
        model2.load(model_path2)
    runner = Runner(env,
                    model,
                    model2,
                    nsteps=nsteps,
                    nstack=nstack,
                    gamma=gamma)

    nbatch = nenvs * nsteps
    tstart = time.time()
    policy_loss_saver, value_loss_saver, policy_entropy_saver = [], [], []
    policy_loss_saver_2, value_loss_saver_2, policy_entropy_saver_2 = [], [], []
    for update in range(1, total_timesteps // nbatch + 1):
        policy_loss, value_loss, policy_entropy, policy_loss_2, value_loss_2, policy_entropy_2 = runner.run()

        policy_loss_saver.append(policy_loss)
        value_loss_saver.append(value_loss)
        policy_entropy_saver.append(policy_entropy)
        policy_loss_saver_2.append(policy_loss_2)
        value_loss_saver_2.append(value_loss_2)
        policy_entropy_saver_2.append(policy_entropy_2)

        nseconds = time.time() - tstart
        fps = float((update * nbatch) / nseconds)

        if update % log_interval == 0 or update == 1:
            runner.mcts.print_statistic()
            logger.record_tabular("nupdates", update)
            logger.record_tabular("total_timesteps", update * nbatch)
            logger.record_tabular("fps", fps)
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("value_loss", float(value_loss))
            logger.record_tabular("policy_entropy_2", float(policy_entropy_2))
            logger.record_tabular("value_loss_2", float(value_loss_2))
            logger.dump_tabular()
        if (update % (log_interval * 10)) == 0:
            logger.warn('Saving checkpoint files.')
            model.save(model_path)
            model2.save(model_path2)

            PolicyLossFile = "../statistics/policy_loss.csv"
            with open(PolicyLossFile, 'w') as f:
                writer = csv.writer(f, lineterminator='\n', delimiter=',')
                writer.writerow(float(val) for val in policy_loss_saver)

            ValueLossFile = "../statistics/value_loss.csv"
            with open(ValueLossFile, 'w') as f:
                writer = csv.writer(f, lineterminator='\n', delimiter=',')
                writer.writerow(float(val) for val in value_loss_saver)

            PolicyEntropyFile = "../statistics/policy_entropy_loss.csv"
            with open(PolicyEntropyFile, 'w') as f:
                writer = csv.writer(f, lineterminator='\n', delimiter=',')
                writer.writerow(float(val) for val in policy_entropy_saver)

            # Second model
            PolicyLossFile = "../statistics/policy_loss_2.csv"
            with open(PolicyLossFile, 'w') as f:
                writer = csv.writer(f, lineterminator='\n', delimiter=',')
                writer.writerow(float(val) for val in policy_loss_saver_2)

            ValueLossFile = "../statistics/value_loss_2.csv"
            with open(ValueLossFile, 'w') as f:
                writer = csv.writer(f, lineterminator='\n', delimiter=',')
                writer.writerow(float(val) for val in value_loss_saver_2)

            PolicyEntropyFile = "../statistics/policy_entropy_loss_2.csv"
            with open(PolicyEntropyFile, 'w') as f:
                writer = csv.writer(f, lineterminator='\n', delimiter=',')
                writer.writerow(float(val) for val in policy_entropy_saver_2)

    runner.mcts.visualization()
    env.close()
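The CSV writes in this example target paths under ../statistics/ and assume that directory already exists; Examples 2 and 3 guard their output paths with try: os.stat(...) / except: os.mkdir(...). A more compact equivalent, offered as a sketch rather than code from the project, is os.makedirs with exist_ok=True, which also creates any missing parent directories:

import os

# Create the output directory (and any missing parents) if it does not exist yet.
os.makedirs('../statistics/', exist_ok=True)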
Example 6
def learn(policy,
          policy_2,
          env,
          seed,
          nsteps=5,
          nstack=4,
          total_timesteps=int(80e6),
          vf_coef=0.5,
          ent_coef=0.01,
          max_grad_norm=0.5,
          lr=7e-4,
          lrschedule='linear',
          epsilon=1e-5,
          alpha=0.99,
          gamma=0.99,
          log_interval=1000,
          load_model=False,
          model_path='',
          data_augmentation=True,
          TRAINING_BATCH=10):
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    num_procs = len(env.remotes)  # HACK
    statistics_path = './stadistics'
    summary_writer = tf.summary.FileWriter(statistics_path)
    run_test = 2000
    temp = np.ones(1)
    CHANGE_PLAYER = 100

    model = Model(policy=policy,
                  ob_space=ob_space,
                  ac_space=ac_space,
                  nenvs=nenvs,
                  nsteps=nsteps,
                  nstack=nstack,
                  num_procs=num_procs,
                  ent_coef=ent_coef,
                  vf_coef=vf_coef,
                  max_grad_norm=max_grad_norm,
                  lr=lr,
                  alpha=alpha,
                  epsilon=epsilon,
                  total_timesteps=total_timesteps,
                  lrschedule=lrschedule)

    model_2 = Model_2(policy=policy_2,
                      ob_space=ob_space,
                      ac_space=ac_space,
                      nenvs=nenvs,
                      nsteps=nsteps,
                      nstack=nstack,
                      num_procs=num_procs,
                      ent_coef=ent_coef,
                      vf_coef=vf_coef,
                      max_grad_norm=max_grad_norm,
                      lr=lr,
                      alpha=alpha,
                      epsilon=epsilon,
                      total_timesteps=total_timesteps,
                      lrschedule=lrschedule)

    if load_model:
        # NOTE: the load calls below are commented out, so no weights are actually restored here.
        # model_A.load('./models/model_A.cpkt')
        # model_B.load('./models/model_B.cpkt')
        print('Model loaded')

    runner = Runner(env,
                    model,
                    model_2,
                    nsteps=nsteps,
                    nstack=nstack,
                    gamma=gamma)

    nbatch = nenvs * nsteps
    tstart = time.time()
    for update in range(0, total_timesteps // nbatch + 1):
        if update % 999 == 0:
            print('update', update)
            env.print_stadistics_vs()
            games_A, games_B, games_finish_in_draw, illegal_games = env.get_stadistics_vs()

        if (update % run_test < 1000) and (update % run_test > 0):
            runner.test(temp)

            if ((update % run_test) == 999):
                summary = tf.Summary()
                summary.value.add(tag='test/games_A',
                                  simple_value=float(games_A))
                summary.value.add(tag='test/games_B',
                                  simple_value=float(games_B))
                summary.value.add(tag='test/games_finish_in_draw',
                                  simple_value=float(games_finish_in_draw))
                summary.value.add(tag='test/illegal_games',
                                  simple_value=float(illegal_games))
                summary_writer.add_summary(summary, update)

                summary_writer.flush()
        else:

            if update % CHANGE_PLAYER == 0 and update != 0:
                # Swap which model plays as A and which as B, then rebuild the runner.
                model, model_2 = model_2, model
                runner = Runner(env,
                                model,
                                model_2,
                                nsteps=nsteps,
                                nstack=nstack,
                                gamma=gamma)

            obs, states, rewards, masks, actions, values, obs_B, states_B, rewards_B, masks_B, actions_B, values_B = runner.run(temp)

            obs, states, rewards, masks, actions, values = redimension_results(
                obs, states, rewards, masks, actions, values, env, nsteps)
            obs_B, states_B, rewards_B, masks_B, actions_B, values_B = redimension_results(
                obs_B, states_B, rewards_B, masks_B, actions_B, values_B, env,
                nsteps)

            size_batch = runner.put_in_batch(obs, states, rewards, masks,
                                             actions, values, obs_B, states_B,
                                             rewards_B, masks_B, actions_B,
                                             values_B)

            if size_batch == TRAINING_BATCH:

                batch = runner.get_batch()
                for i in range(len(batch)):
                    obs, states, rewards, masks, actions, values, obs_B, states_B, rewards_B, masks_B, actions_B, values_B = batch.get(
                        i)
                    if data_augmentation:
                        policy_loss, value_loss, policy_entropy = train_data_augmentation(
                            obs, states, rewards, masks, actions, values,
                            model, temp)
                        policy_loss_B, value_loss_B, policy_entropy_B = train_data_augmentation(
                            obs_B, states_B, rewards_B, masks_B, actions_B,
                            values_B, model_2, temp)
                    else:
                        policy_loss, value_loss, policy_entropy = train_without_data_augmentation(
                            obs, states, rewards, masks, actions, values,
                            model, temp)
                        policy_loss_B, value_loss_B, policy_entropy_B = train_without_data_augmentation(
                            obs_B, states_B, rewards_B, masks_B, actions_B,
                            values_B, model_2, temp)
                runner.empty_batch()

                nseconds = time.time() - tstart
                fps = int((update * nbatch) / nseconds)
                ev = explained_variance(values, rewards)
                ev_B = explained_variance(values_B, rewards_B)
                if update % (TRAINING_BATCH * 100) == 0:
                    print('update:', update)
                    logger.record_tabular("nupdates", update)
                    logger.record_tabular("total_timesteps", update * nbatch)
                    logger.record_tabular("fps", fps)

                    logger.record_tabular("policy_entropy",
                                          float(policy_entropy))
                    logger.record_tabular("policy_loss", float(policy_loss))
                    logger.record_tabular("value_loss", float(value_loss))
                    logger.record_tabular("explained_variance", float(ev))

                    logger.record_tabular("policy_entropy_B",
                                          float(policy_entropy_B))
                    logger.record_tabular("policy_loss_B",
                                          float(policy_loss_B))
                    logger.record_tabular("value_loss_B", float(value_loss_B))
                    logger.record_tabular("explained_variance_B", float(ev_B))
                    logger.dump_tabular()

                summary = tf.Summary()
                summary.value.add(tag='train_A/policy_entropy',
                                  simple_value=float(policy_entropy))
                summary.value.add(tag='train_loss/policy_loss_A',
                                  simple_value=float(policy_loss))
                summary.value.add(tag='train_A/explained_variance',
                                  simple_value=float(ev))
                summary.value.add(tag='train_loss/value_loss_A',
                                  simple_value=float(value_loss))

                summary.value.add(tag='train_B/policy_entropy',
                                  simple_value=float(policy_entropy_B))
                summary.value.add(tag='train_loss/policy_loss_B',
                                  simple_value=float(policy_loss_B))
                summary.value.add(tag='train_B/explained_variance_B',
                                  simple_value=float(ev_B))
                summary.value.add(tag='train_loss/value_loss_B',
                                  simple_value=float(value_loss_B))

                summary.value.add(tag='train/won_A',
                                  simple_value=float(games_A))
                summary.value.add(tag='train/won_B',
                                  simple_value=float(games_B))
                summary.value.add(tag='train/games_finish_in_draw',
                                  simple_value=float(games_finish_in_draw))
                summary_writer.add_summary(summary, update)

                summary_writer.flush()

            if update % log_interval == 0:
                model.save('../models/tic_tac_toe.cpkt')
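explained_variance (used in Examples 3 and 6) comes from the surrounding A2C utilities; the definition in OpenAI baselines' common utilities is 1 - Var(returns - values) / Var(returns). A small reference-only sketch of that definition, not the project's own code:

import numpy as np

def explained_variance_sketch(ypred, y):
    # 1 means the value function predicts the returns perfectly,
    # 0 means it does no better than predicting the mean, < 0 is worse.
    vary = np.var(y)
    return np.nan if vary == 0 else 1.0 - np.var(y - ypred) / vary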