Beispiel #1
0
                                        arch.nn_prob_move_unit_valid_mvs
                                    ],
                                                              feed_dict=d)[0]
                                else:
                                    to_coords = arch.sess.run([
                                        arch.nn_max_prob_to_coords_valid_mvs,
                                        arch.nn_max_prob_move_unit_valid_mvs
                                    ],
                                                              feed_dict=d)[0]

                            board_tmp2 = arch.sess.run(arch.gm_vars['board'])
                            n_mvs += board_tmp.sum() - board_tmp2.sum()

                            # move opposing player
                            if gnu:
                                gt.move_nn(to_coords)

                                # mv gnugo
                                ai_to_coords = gt.move_ai()
                                arch.sess.run(
                                    arch.imgs,
                                    feed_dict={arch.moving_player: 1})
                                arch.sess.run(arch.nn_max_move_unit,
                                              feed_dict={
                                                  arch.moving_player:
                                                  1,
                                                  arch.nn_max_to_coords:
                                                  ai_to_coords
                                              })
                            else:
                                arch.sess.run(arch.imgs, feed_dict=ret_d(1))
Beispiel #2
0
def worker(i_WORKER_ID):
    global WORKER_ID, weights_current, weights_eval_current, weights_eval32_current, val_mean_sq_err, pol_cross_entrop_err, val_pearsonr
    global board, winner, tree_probs, save_d, bp_eval_nodes, t_start, run_time, save_nm
    WORKER_ID = i_WORKER_ID

    err_denom = 0
    val_pearsonr = 0
    val_mean_sq_err = 0
    pol_cross_entrop_err = 0
    t_start = datetime.now()
    run_time = datetime.now() - datetime.now()

    #### restore
    save_d = np.load(sdir + save_nm, allow_pickle=True).item()

    for key in save_vars + state_vars + training_ex_vars:
        if (key == 'save_nm') or (key in shared_nms):
            continue
        exec('global ' + key)
        exec('%s = save_d["%s"]' % (key, key))

    EPS_ORIG = EPS
    #EPS = 2e-3 ###################################################### < overrides previous backprop step sizes

    ############# init / load model
    DEVICE = '/gpu:%i' % WORKER_ID
    arch.init_model(DEVICE, N_FILTERS, FILTER_SZS, STRIDES, N_FC1, EPS,
                    MOMENTUM, LSQ_LAMBDA, LSQ_REG_LAMBDA,
                    POL_CROSS_ENTROP_LAMBDA, VAL_LAMBDA, VALR_LAMBDA,
                    L2_LAMBDA)

    bp_eval_nodes = [
        arch.train_step, arch.val_mean_sq_err, arch.pol_cross_entrop_err,
        arch.val_pearsonr
    ]

    # ops for trainable weights
    weights_current = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                        scope='main')
    weights_eval_current = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope='eval/')
    weights_eval32_current = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope='eval32')

    if new_model == False:
        print 'restore nm %s' % save_nm
        arch.saver.restore(arch.sess, sdir + save_nm)
        if WORKER_ID == MASTER_WORKER:
            set_all_shared_to_loaded()
    else:  #### sync model weights
        if WORKER_ID == MASTER_WORKER:
            set_all_to_eval32_and_get()
        else:
            while set_weights() == False:  # wait for weights to be set
                continue
    ###### shared variables
    board = np.frombuffer(s_board.get_obj(), 'float16').reshape(
        (BUFFER_SZ, gv.n_rows, gv.n_cols, gv.n_input_channels))
    winner = np.frombuffer(s_winner.get_obj(), 'int8').reshape(
        (N_BATCH_SETS_TOTAL, N_TURNS, 2, gv.BATCH_SZ))
    tree_probs = np.frombuffer(s_tree_probs.get_obj(), 'float32').reshape(
        (BUFFER_SZ, gv.map_szt))

    ######## local variables
    # BUFFER_SZ = N_BATCH_SETS * N_TURNS * 2 * gv.BATCH_SZ
    L_BUFFER_SZ = N_TURNS * 2 * gv.BATCH_SZ
    board_local = np.zeros(
        (L_BUFFER_SZ, gv.n_rows, gv.n_cols, gv.n_input_channels),
        dtype='float16')
    winner_local = np.zeros((N_TURNS, 2, gv.BATCH_SZ), dtype='int8')
    tree_probs_local = np.zeros((L_BUFFER_SZ, gv.map_szt), dtype='float32')

    if EPS_ORIG != EPS:
        #save_nm += 'EPS_%2.4f.npy' % EPS
        save_d['EPS'] = EPS
        print 'saving to', save_nm

    ### sound
    if WORKER_ID == MASTER_WORKER:
        pygame.init()
        pygame.mixer.music.load('/home/tapa/gtr-nylon22.mp3')

    ######
    while True:
        #### generate training batches with `main` model
        arch.sess.run(arch.init_state)
        pu.init_tree()
        turn_start_t = time.time()
        buffer_loc_local = 0
        for turn in range(N_TURNS):
            ### make move
            for player in [0, 1]:
                set_weights()
                run_sim(turn, player)  # using `main` model

                inds = buffer_loc_local + np.arange(
                    gv.BATCH_SZ)  # inds to save training vars at
                board_local[inds], valid_mv_map, pol = arch.sess.run(
                    [arch.imgs, arch.valid_mv_map, arch.pol['main']],
                    feed_dict=ret_d(player))  # generate batch and valid moves

                #########
                pu.add_valid_mvs(player,
                                 valid_mv_map)  # register valid moves in tree
                visit_count_map = pu.choose_moves(
                    player, np.array(pol, dtype='single'),
                    CPUCT)[-1]  # get number of times each node was visited

                tree_probs_local[inds] = visit_count_map / visit_count_map.sum(
                    1)[:, np.newaxis]

                to_coords = arch.sess.run(
                    [arch.tree_prob_visit_coord, arch.tree_prob_move_unit],
                    feed_dict={
                        arch.moving_player: player,
                        arch.visit_count_map: visit_count_map
                    })[0]  # make move in proportion to visit counts

                pu.register_mv(player, np.array(
                    to_coords, dtype='int32'))  # register move in tree

                ###############

                buffer_loc_local += gv.BATCH_SZ

            pu.prune_tree(0)

            if (turn + 1) % 2 == 0:
                print 'finished turn %i (%i sec) GPU %i batch_sets_created %i (total %i)' % (
                    turn, time.time() - turn_start_t, WORKER_ID,
                    batch_sets_created.value, batch_sets_created_total.value)

        ##### create prob maps
        for player in [0, 1]:
            winner_local[:, player] = arch.sess.run(
                arch.winner, feed_dict={arch.moving_player: player})

        #### set shared buffers with training variables we just generated from self-play
        with buffer_lock:
            board[buffer_loc.value:buffer_loc.value +
                  buffer_loc_local] = board_local
            tree_probs[buffer_loc.value:buffer_loc.value +
                       buffer_loc_local] = tree_probs_local
            winner[batch_set.value] = winner_local

            buffer_loc.value += buffer_loc_local
            batch_sets_created.value += 1
            batch_sets_created_total.value += 1
            batch_set.value += 1

            # save checkpoint
            if buffer_loc.value >= BUFFER_SZ or batch_set.value >= N_BATCH_SETS_TOTAL:
                buffer_loc.value = 0
                batch_set.value = 0

                # save batch only
                batch_d = {}
                for key in ['tree_probs', 'winner', 'board']:
                    exec(
                        'batch_d["%s"] = copy.deepcopy(np.array(s_%s.get_obj()))'
                        % (key, key))
                batch_save_nm = sdir + save_nm + '_batches' + str(
                    batch_sets_created_total.value)
                np.save(batch_save_nm, batch_d)
                print 'saved', batch_save_nm
                batch_d = {}

        ################ train/eval/test
        if WORKER_ID == MASTER_WORKER and batch_sets_created.value >= N_BATCH_SETS_BLOCK and batch_sets_created_total.value >= N_BATCH_SETS_MIN:
            ########### train
            with buffer_lock:
                if batch_sets_created_total.value < (
                        N_BATCH_SETS_MIN + N_BATCH_SETS_BLOCK
                ):  # don't overtrain on the initial set
                    batch_sets_created.value = N_BATCH_SETS_BLOCK

                if batch_sets_created.value >= N_BATCH_SETS_TOTAL:  # if for some reason master worker gets delayed
                    batch_sets_created.value = N_BATCH_SETS_BLOCK

                board_c = np.array(board, dtype='single')
                winner_rc = np.array(winner.ravel(), dtype='single')

                valid_entries = np.prod(
                    np.isnan(tree_probs) == False, 1) * np.nansum(
                        tree_probs,
                        1)  # remove examples with nans or no probabilties
                inds_valid = np.nonzero(valid_entries)[0]
                print len(
                    inds_valid), 'out of', BUFFER_SZ, 'valid training examples'

                for rep in range(N_REP_TRAIN):
                    random.shuffle(inds_valid)
                    for batch in range(N_TURNS * batch_sets_created.value):
                        inds = inds_valid[batch * gv.BATCH_SZ +
                                          np.arange(gv.BATCH_SZ)]

                        board2, tree_probs2 = pu.rotate_reflect_imgs(
                            board_c[inds], tree_probs[inds]
                        )  # rotate and reflect board randomly

                        train_dict = {
                            arch.imgs32: board2,
                            arch.pol_target: tree_probs2,
                            arch.val_target: winner_rc[inds]
                        }

                        val_mean_sq_err_tmp, pol_cross_entrop_err_tmp, val_pearsonr_tmp = \
                                arch.sess.run(bp_eval_nodes, feed_dict=train_dict)[1:]

                        # update logs
                        val_mean_sq_err += val_mean_sq_err_tmp
                        pol_cross_entrop_err += pol_cross_entrop_err_tmp
                        val_pearsonr += val_pearsonr_tmp
                        global_batch += 1
                        err_denom += 1

                batch_sets_created.value = 0

            ############### `eval` against prior version of self (`main`)
            set_eval16_to_eval32_start_eval(
            )  # update `eval` tf and shared copies to follow backprop (`eval32`)
            eval_model()  # run match(es)
            with eval_stats_lock:
                print '-------------------'
                model_outperforms, self_eval_perc = print_eval_stats()
                print '------------------'
            if model_outperforms:  # update `eval` AND `main` both tf and shared copies to follow backprop
                set_all_to_eval32_and_get()

            ##### network evaluation against random player and GNU Go
            global_batch_evald = global_batch
            global_batch_saved = global_batch
            t_eval = time.time()
            print 'evaluating nn'

            d = ret_d(0)

            ################## monitor training progress:
            # test `eval` against GNU Go and a player that makes only random moves
            for nm, N_GMS_L in zip(['nn', 'tree'],
                                   [[N_EVAL_NN_GNU_GMS, N_EVAL_NN_GMS],
                                    [N_EVAL_TREE_GMS, N_EVAL_TREE_GNU_GMS]]):
                for gnu, N_GMS in zip([True, False], N_GMS_L):
                    if N_GMS == 0:
                        continue
                    key = '%s%s' % (nm, '' + gnu * '_gnu')
                    t_key = time.time()
                    boards[key] = np.zeros((N_TURNS, ) + gv.INPUTS_SHAPE[:-1],
                                           dtype='int8')
                    n_mvs = 0.
                    win_eval = 0.
                    score_eval = 0.
                    n_captures_eval = np.zeros(2, dtype='single')
                    for gm in range(N_GMS):
                        arch.sess.run(arch.init_state)
                        pu.init_tree()
                        # init gnu state
                        if gnu:
                            gt.init_board(arch.sess.run(arch.gm_vars['board']))

                        for turn in range(N_TURNS):
                            board_tmp = arch.sess.run(arch.gm_vars['board'])

                            #### search / make move
                            if nm == 'tree':
                                run_sim(turn)
                                assert False
                            else:
                                # prob choose first move, deterministically choose remainder
                                if turn == 0:
                                    to_coords = arch.sess.run([
                                        arch.
                                        nn_prob_to_coords_valid_mvs['eval'],
                                        arch.
                                        nn_prob_move_unit_valid_mvs['eval']
                                    ],
                                                              feed_dict=d)[0]
                                else:
                                    to_coords = arch.sess.run([
                                        arch.nn_max_prob_to_coords_valid_mvs[
                                            'eval'], arch.
                                        nn_max_prob_move_unit_valid_mvs['eval']
                                    ],
                                                              feed_dict=d)[0]

                            board_tmp2 = arch.sess.run(arch.gm_vars['board'])
                            n_mvs += board_tmp.sum() - board_tmp2.sum()

                            # move opposing player
                            if gnu:
                                gt.move_nn(to_coords)

                                # mv gnugo
                                ai_to_coords = gt.move_ai()
                                arch.sess.run(
                                    arch.imgs,
                                    feed_dict={arch.moving_player: 1})
                                arch.sess.run(
                                    arch.nn_max_move_unit['eval'],
                                    feed_dict={
                                        arch.moving_player: 1,
                                        arch.nn_max_to_coords['eval']:
                                        ai_to_coords
                                    })
                            else:
                                arch.sess.run(arch.imgs, feed_dict=ret_d(1))
                                arch.sess.run(arch.move_random_ai,
                                              feed_dict=ret_d(1))

                            boards[key][turn] = arch.sess.run(
                                arch.gm_vars['board'])

                            if nm == 'tree':
                                pu.prune_tree(0)
                            # turn

                        # save stats
                        win_tmp, score_tmp, n_captures_tmp = arch.sess.run(
                            [arch.winner, arch.score, arch.n_captures],
                            feed_dict={arch.moving_player: 0})
                        scores[key] = copy.deepcopy(score_tmp)

                        win_eval += win_tmp.mean()
                        score_eval += score_tmp.mean()
                        n_captures_eval += n_captures_tmp.mean(1)
                        # gm

                    # log
                    log['win_' + key].append((win_eval /
                                              (2 * np.single(N_GMS))) + .5)
                    log['n_captures_' + key].append(n_captures_eval[0] /
                                                    np.single(N_GMS))
                    log['n_captures_opp_' + key].append(n_captures_eval[1] /
                                                        np.single(N_GMS))
                    log['score_' + key].append(score_eval / np.single(N_GMS))
                    log['n_mvs_' + key].append(
                        n_mvs / np.single(N_GMS * N_TURNS * gv.BATCH_SZ))

                    log['boards_' + key].append(boards[key][-1])
                    print key, 'eval time', time.time() - t_key
                    # gnu
                # nm
            log['eval_batch'].append(global_batch)
            print 'eval time', time.time() - t_eval
            # eval
            ####################### end network evaluation

            pol, pol_pre = arch.sess.run(
                [arch.pol['eval'], arch.pol_pre['eval']],
                feed_dict={arch.moving_player: 0})

            ##### log
            log['val_mean_sq_err'].append(val_mean_sq_err / err_denom)
            log['pol_cross_entrop'].append(pol_cross_entrop_err / err_denom)
            log['val_pearsonr'].append(val_pearsonr / err_denom)
            log['opt_batch'].append(global_batch)

            log['pol_max_pre'].append(np.median(pol_pre.max(1)))
            log['pol_max'].append(np.median(pol.max(1)))

            log['self_eval_win_rate'].append(
                np.single(eval_games_won.value) /
                (eval_batch_sets_played.value * gv.BATCH_SZ))
            log['model_promoted'].append(model_outperforms)

            log['self_eval_perc'].append(self_eval_perc)

            val_mean_sq_err = 0
            pol_cross_entrop_err = 0
            val_pearsonr = 0
            err_denom = 0

            ########## print
            run_time += datetime.now() - t_start

            if (save_counter % 20) == 0:
                print
                print Style.BRIGHT + Fore.GREEN + save_nm, Fore.WHITE + 'EPS', EPS, 'start', str(start_time).split('.')[0], 'run time', \
                  str(run_time).split('.')[0]
                print
            save_counter += 1

            print_str = '%i' % global_batch
            for key in print_logs:
                print_str += ' %s ' % key
                if isinstance(log[key], int):
                    print_str += str(log[key][-1])
                else:
                    print_str += '%1.4f' % log[key][-1]

            print_str += ' %4.1f' % (datetime.now() - t_start).total_seconds()
            print print_str

            t_start = datetime.now()

            # play sound
            if os.path.isfile('/home/tapa/play_sound.txt'):
                pygame.mixer.music.play()

        ############# save
        if WORKER_ID == MASTER_WORKER:
            with buffer_lock:
                # update state vars
                #shared_nms = ['buffer_loc', 'batch_sets_created', 'batch_set', 's_board', 's_winner', 's_tree_probs', 'weights_changed', 'buffer_lock', 'weights_lock', 'save_nm', 'new_model', 'weights']
                for key in state_vars + training_ex_vars:
                    if key in [
                            'buffer_loc', 'batch_sets_created',
                            'batch_sets_created_total', 'batch_set',
                            'eval_games_won', 'eval_batch_sets_played'
                    ]:
                        exec('save_d["%s"] = %s.value' % (key, key))
                    elif key in ['tree_probs', 'winner', 'board']:
                        exec(
                            'save_d["%s"] = copy.deepcopy(np.array(s_%s.get_obj()))'
                            % (key, key))
                    else:
                        exec('save_d["%s"] = %s' % (key, key))

            save_nms = [save_nm]
            if (datetime.now() - save_t).seconds > CHKP_FREQ:
                save_nms += [save_nm + str(datetime.now())]
                save_t = datetime.now()

            for nm in save_nms:
                np.save(sdir + nm, save_d)
                arch.saver.save(arch.sess, sdir + nm)

            print sdir + nm, 'saved'