Example #1
    def save_sample_data(self):
        if self.training_in_progress:
            log.warning("skip writing json (gzipped): %s" %
                        self.sample_data_filename)
            return

        gen_samples = datadesc.GenerationSamples()
        gen_samples.game = self.conf.game
        gen_samples.date_created = get_date_string()

        gen_samples.with_generation = self.get_generation_name(
            self.conf.current_step)

        # only save the minimal number for this run
        gen_samples.num_samples = min(len(self.accumulated_samples),
                                      self.conf.num_samples_to_train)
        gen_samples.samples = self.accumulated_samples[:gen_samples.num_samples]

        # write json file
        json.encoder.FLOAT_REPR = lambda f: ("%.5f" % f)

        log.info("writing json (gzipped): %s" % self.sample_data_filename)
        with gzip.open(self.sample_data_filename, 'w') as f:
            f.write(attrutil.attr_to_json(gen_samples, pretty=False))

        return gen_samples
Example #2
File: server.py Project: vipmath/ggplib
    def handle_stop(self, symbols):
        assert len(symbols) == 3
        match_id = symbols[1]
        if self.current_match is None:
            log.warning("rx'd 'stop' for non-current match %s" % match_id)
            return "busy"
        if self.current_match.match_id != match_id:
            log.error("rx'd 'stop' different from current match (%s != %s)" % (match_id, self.current_match.match_id))
            return "busy"

        move = symbols[2]

        # XXX possible bug with the Stanford 'player checker' - need to find out what is going on here
        if isinstance(move, str) and move.lower() != "nil":
            move = self.symbol_factory.symbolize("( %s )" % move)

        res = self.current_match.do_play(move)
        if res != "done":
            log.error("Game was NOT done %s" % self.current_match.match_id)

        else:
            # cancel any timeout callbacks
            self.update_gameserver_timeout(None)
            self.current_match.do_stop()

        self.current_match = None
        return "done"
Example #3
def main_wrap(main_fn, logfile_name=None, **kwds):
    if logfile_name is None:
        # if logfile_name not set, derive it from main_fn
        fn = main_fn.func_code.co_filename
        logfile_name = os.path.splitext(os.path.basename(fn))[0]

    setup_once(logfile_name)

    try:
        # we might be running under python with no keras/numpy support
        init(**kwds)

    except ImportError as exc:
        log.warning("ImportError: %s" % exc)

    try:
        if main_fn.func_code.co_argcount == 0:
            return main_fn()
        else:
            return main_fn(sys.argv[1:])

    except Exception as exc:
        print exc
        _, _, tb = sys.exc_info()
        traceback.print_exc()
        pdb.post_mortem(tb)
Example #4
def lookup_all_games():
    # ensure things are initialised
    from ggplib.util.init import setup_once
    setup_once()

    failed = []
    known_to_fail = [
        'amazonsTorus_10x10', 'atariGoVariant_7x7', 'gt_two_thirds_4p',
        'gt_two_thirds_6p', 'linesOfAction'
    ]
    for game in lookup.get_all_game_names():
        if game not in known_to_fail:
            try:
                game_info = lookup.by_name(game, build_sm=False)
                assert game_info.game == game
                sm = game_info.get_sm()

                # run some depth charges to ensure we have valid statemachine
                interface.depth_charge(sm, 1)

                log.verbose("DONE GETTING Statemachine FOR GAME %s %s" %
                            (game, sm))
            except lookup.LookupFailed as exc:
                log.warning("Failed to lookup %s: %s" % (game, exc))
                failed.append(game)

    if failed:
        log.error("Failed games %s" % (failed, ))
        assert False, failed
Example #5
    def init_spaces(self):
        base_infos = create_base_infos(self.game_info.model)

        self.board_space = self.create_board_space(base_infos)
        self.raw_channels_per_state = max(b.channel_id for b in self.board_space) + 1

        self.control_space = self.create_control_space(base_infos)

        self.num_of_controls_channels = len(self.game_desc.control_channels)

        # warn about any unhandled states
        self.num_unhandled_states = 0
        for b_info in base_infos:
            if not b_info.used:
                self.num_unhandled_states += 1

        if self.num_unhandled_states:
            log.warning("Number of unhandled states %d" % self.num_unhandled_states)

        # sort by channel
        self.by_channel = {}
        for b in self.board_space:
            self.by_channel.setdefault(b.channel_id, []).append(b)

        for cs in self.control_space:
            self.by_channel.setdefault(cs.channel_id + self.raw_channels_per_state, []).append(cs)

        if self.verbose:
            for channel_id, entries in self.by_channel.items():
                print()
                print("channel_id: %s" % channel_id)

                for x in entries:
                    print("%s -> %s" % (base_infos[x.base_indx].terms, x))
Example #6
    def verify_db(self):
        ' checks the bcolz db against the summary file '
        db_path = os.path.join(self.data_path, "__db__")
        if not os.path.exists(db_path):
            return False

        try:
            self.db = bcolz.open(db_path, mode='a')

            # XXX check columns are correct types

            if self.summary.total_samples != self.db.size:
                msg = "db and summary file different sizes summary: %s != %s" % (
                    self.db.size, self.summary.total_samples)
                log.warning(msg)
                if self.db.size > self.summary.total_samples:
                    log.warning("resizing")
                    self.db.resize(self.summary.total_samples)

                else:
                    raise Check(msg)

        except Exception as exc:
            log.error("error accessing db directory: %s" % exc)
            return False

        return True
Example #7
    def configure_self_play(self):
        assert self.self_play_conf is not None

        if self.nn is None:
            self.nn = get_manager().load_network(self.game_info.game,
                                                 self.latest_generation_name)

        if self.supervisor is None:
            self.supervisor = cppinterface.Supervisor(
                self.sm,
                self.nn,
                batch_size=self.conf.self_play_batch_size,
                sleep_between_poll=self.conf.sleep_between_poll)

            self.supervisor.start_self_play(self.self_play_conf,
                                            self.conf.num_workers)

        else:
            # force exit of the worker if there was an update to the config
            if self.conf.exit_on_update_config:
                os._exit(0)

            log.info("Latest generation: %s" % self.latest_generation_name)
            gen = int(self.latest_generation_name.split("_")[-1])
            if gen % self.conf.replace_network_every_n_gens == 0:
                log.warning("Updating network to: %s" % gen)
                self.supervisor.update_nn(self.nn)

            self.supervisor.clear_unique_states()
Example #8
def load_module(kif_filename):
    ''' attempts to import a python module with the same basename as the kif file.  If it does
        not exist, runs java (ggp-base) to create the module and then imports it. '''

    basename, props_file = kif_filename_to_propfile(kif_filename)
    for cmd in [
            "java -XX:+UseSerialGC -Xmx8G propnet_convert.Convert %s %s" % (kif_filename, props_file),
            "java propnet_convert.Convert %s %s" % (kif_filename, props_file),
            "SOMETHING IS BROKEN in install ..."]:
        try:
            # rather unsafe cache, if kif file changes underneath our feet - tough luck.
            module = importlib.import_module("ggplib.props." + basename)
            break
        except ImportError:
            # run java ggp-base to create a propnet.  The resultant propnet will be in props_dir, which can be imported.
            log.debug("Running: %s" % cmd)
            return_code, out, err = run(cmd, shell=True, timeout=60)
            if return_code != 0:
                log.warning("Error code: %s" % err)
            else:
                for l in out.splitlines():
                    log.info("... %s" % l)

            if "SOMETHING" in cmd:
                raise

    return module
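
A hedged usage sketch (the kif filename is illustrative, not from the source): the return value is the imported ggplib.props.<basename> module, generated via the java ggp-base converter on the first call and served from the cached import afterwards.

# illustrative only: the first call may shell out to java, later calls hit the import cache
props_module = load_module("connectFour.kif")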
Example #9
    def save(self):
        # XXX set generation attributes

        man = get_manager()

        man.save_network(self.nn, generation_name=self.next_generation)
        self.do_callbacks()

        ###############################################################################
        # save a previous model for next time
        if self.controller.retrain_best is None:
            log.warning("No retraining network")
            return

        log.info("Saving retraining network with val_policy_acc: %.4f" %
                 (self.controller.retrain_best_val_policy_acc))

        # there is an undocumented keras clone function, but this is sure to work (albeit slow and evil)
        from ggpzero.util.keras import keras_models

        for_next_generation = "%s_prev" % self.next_generation

        prev_model = keras_models.model_from_json(
            self.nn.keras_model.to_json())
        prev_model.set_weights(self.controller.retrain_best)

        prev_generation_descr = attrutil.clone(self.nn.generation_descr)
        prev_generation_descr.name = for_next_generation
        prev_nn = network.NeuralNetwork(self.nn.gdl_bases_transformer,
                                        prev_model, prev_generation_descr)
        man.save_network(prev_nn, for_next_generation)
        self.do_callbacks()
Example #10
File: runcmd.py Project: vipmath/ggplib
def run(args, verbose=False, cwd=None, shell=False, kill_tree=True, timeout=-1, env=None):
    ' Run a command with a timeout after which it will be forcibly killed. '

    p = Popen(args, shell=shell, cwd=cwd, stdout=PIPE, stderr=PIPE, env=env)

    if timeout != -1:
        signal(SIGALRM, alarm_handler)
        alarm(timeout)

    try:
        stdout, stderr = p.communicate()
        if verbose:
            print stdout, stderr

        if timeout != -1:
            alarm(0)

    except Alarm, e:
        log.warning("Alarm triggered: %s" % e)
        pids = [p.pid]

        if kill_tree:
            pids.extend(get_process_children(p.pid))

        for pid in pids:
            # process might have died before getting to this line
            # so wrap to avoid OSError: no such process
            try:
                log.warning("killing %s" % pid)
                kill(pid, SIGKILL)
            except OSError:
                pass

        return -9, '', ''
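
The Alarm exception, alarm_handler and get_process_children referenced above are module-level helpers that the snippet does not show. Below is a rough, self-contained sketch of the same SIGALRM-based timeout idea; the names run_with_timeout and _alarm_handler are mine, and unlike the original it only kills the direct child rather than the whole process tree.

import signal
from subprocess import Popen, PIPE


class Alarm(Exception):
    ' raised by the SIGALRM handler to abort communicate() '


def _alarm_handler(signum, frame):
    raise Alarm()


def run_with_timeout(args, timeout_secs):
    ' returns (returncode, stdout, stderr), or (-9, "", "") if the timeout fires '
    p = Popen(args, stdout=PIPE, stderr=PIPE)
    signal.signal(signal.SIGALRM, _alarm_handler)
    signal.alarm(int(timeout_secs))

    try:
        out, err = p.communicate()
        signal.alarm(0)
        return p.returncode, out, err

    except Alarm:
        # only the direct child is killed here (no process tree walk)
        p.kill()
        return -9, '', ''

For example, run_with_timeout(["sleep", "10"], 2) would return (-9, '', '') after roughly two seconds.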
Example #11
    def cleanup(self, keep_sm=False):
        try:
            self.player.cleanup()
            if self.verbose:
                log.verbose("done cleanup player: %s" % self.player)
        except Exception as exc:
            log.error("FAILED TO CLEANUP PLAYER: %s" % exc)
            type, value, tb = sys.exc_info()
            log.error(traceback.format_exc())

        # cleanup c++ stuff
        if self.verbose:
            log.warning("cleaning up c++ stuff")

        # all the basestates
        for bs in self.states:
            # cleanup bs
            interface.dealloc_basestate(bs)

        self.states = []

        if self.joint_move:
            interface.dealloc_jointmove(self.joint_move)
            self.joint_move = None

        if self.sm and not keep_sm:
            interface.dealloc_statemachine(self.sm)
            self.sm = None

        if self.verbose:
            log.info("match - done cleaning up")
Example #12
    def on_meta_gaming(self, finish_time):
        if self.conf.verbose:
            log.info("PUCTPlayer, match id: %s" % self.match.match_id)

        if self.sm is None or "*" in self.conf.generation:
            if "*" in self.conf.generation:
                log.warning("Using recent generation %s" %
                            self.conf.generation)

            game_info = self.match.game_info
            self.sm = game_info.get_sm()

            man = get_manager()
            gen = self.conf.generation

            self.nn = man.load_network(game_info.game, gen)
            self.poller = PlayPoller(self.sm, self.nn, attr.asdict(self.conf))

            def get_noop_idx(actions):
                for idx, a in enumerate(actions):
                    if "noop" in a:
                        return idx
                assert False, "did not find noop"

            self.role0_noop_legal, self.role1_noop_legal = map(
                get_noop_idx, game_info.model.actions)

        self.poller.player_reset(self.match.game_depth)
Example #13
File: keras.py Project: vipmath/ggp-zero
def init(data_format='channels_first'):
    assert K.backend() == "tensorflow"

    if K.image_data_format() != data_format:
        was = K.image_data_format()
        K.set_image_data_format(data_format)
        log.warning("Changing image_data_format: %s -> %s" % (was, K.image_data_format()))

    constrain_resources_tf()
Example #14
File: server.py Project: vipmath/ggplib
    def handle(self, request):
        content = request.content.getvalue()

        # Tiltyard seems to ping with empty content...
        if content == "":
            return self.handle_info()

        try:
            symbols = list(self.symbol_factory.symbolize(content))

            # get head
            if len(symbols) == 0:
                log.warning('Empty symbols')
                return self.handle_info()

            head = symbols[0]
            if head.lower() == "info":
                res = self.handle_info()

            elif head.lower() == "start":
                log.debug("HEADERS : %s" % pprint.pformat(request.getAllHeaders()))
                log.debug(str(symbols))
                res = self.handle_start(symbols)

            elif head.lower() == "play":
                log.debug(str(symbols))
                res = self.handle_play(symbols)

            elif head.lower() == "stop":
                log.debug(str(symbols))
                res = self.handle_stop(symbols)

            elif head.lower() == "abort":
                log.debug(str(symbols))
                res = self.handle_abort(symbols)

            else:
                log.error("UNHANDLED REQUEST %s" % symbols)
                # fall back to "busy" so res is always defined before the return below
                res = "busy"

        except Exception as exc:
            log.error("ERROR - aborting: %s" % exc)
            log.error(traceback.format_exc())

            if self.current_match:
                self.abort()

            res = "aborted"

        return res
Example #15
def test_building_desc_variations():
    for game, propnet in get_propnets():
        log.warning("test_building_desc() for: %s" % game)

        desc = builder.build_standard_sm(propnet)
        pprint.pprint(desc)

        desc = builder.build_goals_only_sm(propnet)
        pprint.pprint(desc)

        desc = builder.build_combined_state_machine(propnet)
        pprint.pprint(desc)

        desc = builder.build_goalless_sm(propnet)
        pprint.pprint(desc)
Example #16
File: server.py Project: vipmath/ggplib
    def handle_abort(self, symbols):
        assert len(symbols) == 2
        match_id = symbols[1]

        if self.current_match is None:
            log.warning("rx'd 'abort' for non-current match %s" % match_id)
            return "busy"

        if self.current_match.match_id != match_id:
            log.error("rx'd 'abort' different from current match (%s != %s)" % (match_id, self.current_match.match_id))
            return "busy"

        # cancel any timeout callbacks
        self.abort()
        return "aborted"
Example #17
def get_puct_config(gen, **kwds):
    eval_config = confs.PUCTEvaluatorConfig(verbose=True,
                                            puct_constant=0.85,
                                            puct_constant_root=3.0,

                                            dirichlet_noise_pct=-1,

                                            fpu_prior_discount=0.25,
                                            fpu_prior_discount_root=0.15,

                                            choose="choose_temperature",
                                            temperature=2.0,
                                            depth_temperature_max=10.0,
                                            depth_temperature_start=0,
                                            depth_temperature_increment=0.75,
                                            depth_temperature_stop=1,
                                            random_scale=1.0,

                                            max_dump_depth=2,

                                            top_visits_best_guess_converge_ratio=0.8,

                                            think_time=2.0,
                                            converged_visits=2000,

                                            batch_size=32)

    config = confs.PUCTPlayerConfig(name="puct",
                                    verbose=True,
                                    generation=gen,
                                    playouts_per_iteration=-1,
                                    playouts_per_iteration_noop=0,
                                    evaluator_config=eval_config)

    for k, v in kwds.items():
        updated = False
        if at.has(eval_config, k):
            updated = True
            setattr(eval_config, k, v)

        if at.has(config, k):
            updated = True
            setattr(config, k, v)

        if not updated:
            log.warning("Unused setting %s:%s" % (k, v))

    return config
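
A hedged usage sketch (the generation name "x_42" and the chosen overrides are illustrative): keywords that match a field on either config are applied, and anything else triggers the "Unused setting" warning above.

# illustrative only: bump the think time and silence the evaluator/player output
config = get_puct_config("x_42", think_time=10.0, verbose=False)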
Example #18
    def compile(self, compile_strategy, learning_rate=None, value_weight=1.0):
        value_objective = "mean_squared_error"
        policy_objective = 'categorical_crossentropy'
        if compile_strategy == "SGD":
            if learning_rate:
                optimizer = SGD(lr=learning_rate, momentum=0.9)
            else:
                optimizer = SGD(lr=1e-2, momentum=0.9)

        elif compile_strategy == "adam":
            if learning_rate:
                optimizer = Adam(lr=learning_rate)
            else:
                optimizer = Adam()

        elif compile_strategy == "amsgrad":
            if learning_rate:
                optimizer = Adam(lr=learning_rate, amsgrad=True)
            else:
                optimizer = Adam(amsgrad=True)

        else:
            log.error("UNKNOWN compile strategy %s" % compile_strategy)
            raise Exception("UNKNOWN compile strategy %s" % compile_strategy)

        num_policies = len(self.gdl_bases_transformer.policy_dist_count)

        loss = [policy_objective] * num_policies
        loss.append(value_objective)
        loss_weights = [1.0] * num_policies
        loss_weights.append(value_weight)

        if learning_rate is not None:
            msg = "Compiling with %s (learning_rate=%.4f, value_weight=%.3f)"
            log.warning(msg % (optimizer, learning_rate, value_weight))
        else:
            log.warning("Compiling with %s (value_weight=%.3f)" %
                        (optimizer, value_weight))

        def top_3_acc(y_true, y_pred):
            return keras_metrics.top_k_categorical_accuracy(y_true,
                                                            y_pred,
                                                            k=3)

        self.keras_model.compile(loss=loss,
                                 optimizer=optimizer,
                                 loss_weights=loss_weights,
                                 metrics=["acc", top_3_acc])
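
A hedged usage sketch (the nn variable and the numbers are assumptions, with nn an instance of the class this method belongs to): choose the optimizer by name and optionally down-weight the value head.

# illustrative only: Adam with an explicit learning rate, value head at half the policy weight
nn.compile("adam", learning_rate=1e-4, value_weight=0.5)

# SGD falls back to its default lr of 1e-2 when no learning_rate is given
nn.compile("SGD")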
Example #19
    def on_sample_response(self, worker, msg):
        info = self.workers[worker]
        if len(msg.samples) > 0:
            dupe_count = self.add_new_samples(msg.samples)
            if dupe_count:
                log.warning("dropping %s inflight duplicate state(s)" %
                            dupe_count)

            if msg.duplicates_seen:
                log.info("worker saw %s duplicates" % msg.duplicates_seen)

            log.info("len accumulated_samples: %s" %
                     len(self.accumulated_samples))

        self.free_players.append(info)
        reactor.callLater(0, self.schedule_players)
Example #20
    def add_new_samples(self, samples):
        dupe_count = 0
        for sample in samples:
            state = sample.state

            # need to check it isn't a duplicate and drop it
            if state in self.unique_states_set:
                dupe_count += 1
            else:
                self.unique_states_set.add(state)
                self.unique_states.append(state)

            self.accumulated_samples.append(sample)

        if dupe_count:
            log.warning("seen %s duplicate state(s)" % dupe_count)
Example #21
File: lookup.py Project: vipmath/ggplib
    def load(self, verbose=True):
        if verbose:
            log.info("Building the database")

        filenames = self.rulesheets_store.listdir("*.kif")
        for fn in sorted(filenames):
            # skip tmp files
            if fn.startswith("tmp"):
                continue

            game = fn.replace(".kif", "")

            # get the gdl
            gdl_str = self.rulesheets_store.load_contents(fn)

            info = GameInfo(game, gdl_str)

            # first does the game directory exist?
            the_game_store = self.games_store.get_directory(game, create=True)
            if the_game_store.file_exists("sig.json"):
                info.idx = the_game_store.load_json("sig.json")['idx']

            else:
                if verbose:
                    log.verbose("Creating signature for %s" % game)

                info.get_symbol_map()

                if info.symbol_map is None:
                    log.warning("FAILED to add: %s" % game)
                    raise Exception("FAILED TO add %s" % game)

                # save as json
                assert info.idx is not None
                the_game_store.save_json("sig.json", dict(idx=info.idx))

            assert info.idx is not None
            if info.idx in self.idx_mapping:
                other_info = self.idx_mapping[info.idx]
                log.warning("DUPE GAMES: %s %s!=%s" %
                            (info.idx, game, other_info.game))
                raise Exception("Dupes not allowed in database")

            self.idx_mapping[info.idx] = info
            self.game_mapping[info.game] = info
Example #22
    def create_db(self):
        ' create a fresh bcolz db; if one already exists, warn and exit '
        db_path = os.path.join(self.data_path, "__db__")

        if os.path.exists(db_path):
            log.warning("Please delete old db")
            sys.exit(1)

        # these are columns for bcolz table
        cols = fake_columns(self.transformer)

        # and create a table
        self.db = bcolz.ctable(cols, names=["channels", "policy0", "policy1", "value"], rootdir=db_path)

        # remove the single row
        self.db.resize(0)
        self.db.flush()

        log.info("Created new db")
Example #23
    def schedule_players(self):
        if len(self.accumulated_samples) > self.conf.num_samples_to_train:
            # if we haven't started training yet, lets speed things up...
            if not self.training_in_progress:
                self.checkpoint()

        if not self.free_players:
            return

        new_free_players = []
        for worker_info in self.free_players:
            if not worker_info.valid:
                continue

            if self.training_in_progress and worker_info.conf.do_training:
                # will be added back in at the end of training
                continue

            if not worker_info.self_play_configured:
                worker_info.worker.send_msg(self.configure_selfplay_msg)

            else:
                if self.need_more_samples():
                    updates = worker_info.get_and_update(self.unique_states)
                    m = msgs.RequestSamples(updates)
                    if updates:
                        log.debug("sending request with %s updates" %
                                  len(updates))

                    worker_info.worker.send_msg(m)
                else:
                    log.warning("capacity full! %d" %
                                len(self.accumulated_samples))
                    new_free_players.append(worker_info)
                    if time.time() > worker_info.ping_time_sent:
                        self.do_ping(worker_info)

        self.free_players = new_free_players
        if self.free_players:
            reactor.callLater(10.0, self.schedule_players)

        if self.the_nn_trainer is None:
            log.warning("There is no nn trainer - please start")
Example #24
File: train.py Project: vipmath/ggp-zero
    def files_to_sample_data(self, conf):
        man = get_manager()

        assert isinstance(conf, confs.TrainNNConfig)

        step = conf.next_step - 1

        starting_step = conf.starting_step
        if starting_step < 0:
            starting_step = max(step + starting_step, 0)

        while step >= starting_step:
            store_path = man.samples_path(conf.game, conf.generation_prefix)
            fn = os.path.join(store_path,
                              "gendata_%s_%s.json.gz" % (conf.game, step))
            if fn not in self.sample_data_cache:
                raw_data = attrutil.json_to_attr(gzip.open(fn).read())
                data = SamplesData(raw_data.game, raw_data.with_generation,
                                   raw_data.num_samples)

                total_draws = 0
                for s in raw_data.samples:
                    if abs(s.final_score[0] - 0.5) < 0.01:
                        total_draws += 1

                draws_ratio = total_draws / float(len(raw_data.samples))
                log.info("Draws ratio %.2f" % draws_ratio)

                for s in raw_data.samples:
                    data.add_sample(s)

                if len(data.samples) != data.num_samples:
                    # pretty inconsequential, but we should at least notify
                    msg = "num_samples (%d) versus actual samples (%s) differ... trimming"
                    log.warning(msg % (data.num_samples, len(data.samples)))

                    data.num_samples = min(len(data.samples), data.num_samples)
                    data.samples = data.samples[:data.num_samples]

                self.sample_data_cache[fn] = data

            yield fn, self.sample_data_cache[fn]
            step -= 1
Example #25
    def check_files_exist(self):
        # first check that the directories exist
        man = get_manager()
        for p in (man.model_path(self.conf.game),
                  man.weights_path(self.conf.game),
                  man.generation_path(self.conf.game),
                  man.samples_path(self.conf.game,
                                   self.conf.generation_prefix)):

            if os.path.exists(p):
                if not os.path.isdir(p):
                    critical_error("Path exists and is not a directory: %s" % p)
            else:
                log.warning("Attempting to create path: %s" % p)
                os.makedirs(p)
                if not os.path.exists(p) or not os.path.isdir(p):
                    critical_error("Failed to create directory: %s" % p)

        self.check_nn_files_exist()
Example #26
    def get_network(self, nn_model_config, generation_descr):
        # abbreviate, easier on the eyes
        conf = self.train_config

        attrutil.pprint(nn_model_config)

        man = get_manager()
        if man.can_load(conf.game, self.next_generation):
            msg = "Generation already exists %s / %s" % (conf.game,
                                                         self.next_generation)
            log.error(msg)
            if not conf.overwrite_existing:
                raise TrainException("Generation already exists %s / %s" %
                                     (conf.game, self.next_generation))
        nn = None
        retraining = False
        if conf.use_previous:
            # default to next_generation_prefix, otherwise use conf.generation_descr
            candidates = [self.next_generation_prefix]
            if conf.generation_prefix != self.next_generation_prefix:
                candidates.append(conf.generation_prefix)
            for gen in candidates:
                prev_generation = "%s_%s" % (gen, conf.next_step - 1)

                if man.can_load(conf.game, prev_generation):
                    log.info("Previous generation found: %s" % prev_generation)
                    nn = man.load_network(conf.game, prev_generation)
                    retraining = True
                    break

                else:
                    log.warning("Previous generation %s not found..." %
                                (prev_generation))

        if nn is None:
            nn = man.create_new_network(conf.game, nn_model_config,
                                        generation_descr)

        nn.summary()

        self.nn = nn
        self.retraining = retraining
        log.info("Network %s, retraining: %s" % (self.nn, self.retraining))
Example #27
    def check_summary(self):
        ' checks summary against existing files '
        total_samples = 0

        try:
            if self.summary.game != self.transformer.game:
                raise Check("Game not same %s/%s" %
                            (self.summary.game, self.transformer.game))

            expect = 0
            for step_sum, (step, file_path,
                           md5sum) in zip(self.summary.step_summaries,
                                          self.list_files()):

                # special case exception, this should never happen!
                if step_sum.step != expect:
                    raise Exception(
                        "Weird - step_sum.step != expect, please check %s/%s" %
                        (step_sum.step, expect))

                if step_sum.step != step:
                    raise Check("step_sum(%d) != step(%d)" %
                                (step_sum.step, step))

                if step_sum.md5sum != md5sum:
                    msg = "Summary check: for file %s, md5sum(%s) != md5sum(%s)" % (
                        file_path, step_sum.md5sum, md5sum)
                    log.warning(msg)
                    # raise Check(msg)

                total_samples += step_sum.num_samples
                expect += 1

            if self.summary.total_samples != total_samples:
                raise Check("Total samples mismatch %s != %s" %
                            (self.summary.total_samples, total_samples))

        except Check as exc:
            log.error("Summary check failed: %s" % exc)
            return False

        return True
Example #28
    def add_new_samples(self, samples, dedupe=True):
        dupe_counts = Counter()
        dropped_dupe_count = 0
        dropped_draw_count = 0

        num_matches = 0
        cur_match = None
        cur_match_is_draw = False
        for sample in samples:
            if sample.match_identifier != cur_match:
                cur_match = sample.match_identifier
                num_matches += 1
                cur_match_is_draw = math.fabs(sample.final_score[0] -
                                              0.5) < 0.01

            state = sample.state

            if cur_match_is_draw:
                if random.random() > 0.5:
                    dropped_draw_count += 1
                    if dedupe:
                        continue

            # need to check it isn't a duplicate and drop it
            if state in self.unique_states_set:
                dupe_counts[sample.depth] += 1

                if dupe_counts[sample.depth] > 1:
                    if random.random() > 1.25 / dupe_counts[sample.depth]:
                        dropped_dupe_count += 1
                        if dedupe:
                            continue
            else:
                self.unique_states_set.add(state)
                self.unique_states.append(state)

            self.accumulated_samples.append(sample)

        log.info("Rx'd matches %s" % num_matches)
        if dropped_dupe_count or dropped_draw_count:
            log.warning("duplicate: %s, dropped dupe %s, dropped_draw %s" %
                        (dupe_counts, dropped_dupe_count, dropped_draw_count))
Example #29
    def roll_generation(self):
        # training is done
        self.conf.current_step += 1
        self.check_nn_files_exist()

        # reconfigure player workers
        for _, info in self.workers.items():
            info.reset()

        self.create_self_play_config()

        # rotate these
        self.accumulated_samples = self.accumulated_samples[self.conf.num_samples_to_train:]
        self.unique_states = self.unique_states[self.conf.num_samples_to_train:]
        self.unique_states_set = set(self.unique_states)

        self.checkpoint()

        assert len(self.unique_states) == len(self.unique_states_set)

        if not self.conf.base_training_config.use_previous:
            log.warning("use_previous was False, setting to True")
            self.conf.base_training_config.use_previous = True

        # store the server config
        self.save_our_config(rolled=True)

        self.pending_gen_samples = None
        self.training_in_progress = False

        if self.the_nn_trainer.conf.do_self_play:
            if self.the_nn_trainer not in self.free_players:
                self.free_players.append(self.the_nn_trainer)

        log.warning(
            "roll_generation() complete.  We have %s samples leftover" %
            len(self.accumulated_samples))

        self.schedule_players()
Example #30
File: server.py Project: vipmath/ggplib
    def handle_play(self, symbols):
        assert len(symbols) == 3
        match_id = symbols[1]
        if self.current_match is None:
            log.warning("rx'd play for non-current match %s" % match_id)
            return "busy"
        if self.current_match.match_id != match_id:
            log.error("rx'd play different from current match (%s != %s)" % (match_id, self.current_match.match_id))
            return "busy"

        move = symbols[2]

        if isinstance(move, ListTerm):
            move = list(move)
        else:
            assert move.lower() == 'nil', "Move is %s" % move
            move = None

        # update gameserver timeout
        self.update_gameserver_timeout(self.current_match.move_time)
        return self.current_match.do_play(move)