Example #1
0
    def on_configure(self, server, msg):
        """Apply a configure message and (re)load the neural network.

        On the first call the game named in ``msg.game`` is looked up and its
        state machine is fetched.  On later calls the game must be the same
        one; a different game is reported and raised as ``ValueError``.

        The self play configuration and generation name from ``msg`` are
        stored, then the network for that generation is loaded.  If loading
        (or configuring self play) fails, the error is logged and the
        configured post training commands are spawned instead;
        ``self.finished_cmds_running`` is passed as their completion callback.

        Returns ``msgs.Ok("configured")`` in both the success and the
        load-failure case.
        """
        attrutil.pprint(msg)

        if self.game_info is None:
            self.game_info = lookup.by_name(msg.game)
            self.sm = self.game_info.get_sm()

        elif self.game_info.game != msg.game:
            # This used to be a bare `==` expression whose result was thrown
            # away, so a change of game went completely unnoticed.
            err = "Game changed from %s to %s" % (self.game_info.game,
                                                  msg.game)
            log.error(err)
            raise ValueError(err)

        self.self_play_conf = msg.self_play_conf
        self.latest_generation_name = msg.generation_name

        # refresh the neural network.  May have to run some commands to get it.
        self.nn = None
        try:
            self.nn = get_manager().load_network(self.game_info.game,
                                                 self.latest_generation_name)
            self.configure_self_play()

        except Exception as exc:
            log.error("error in on_configure(): %s" % exc)
            for line in traceback.format_exc().splitlines():
                log.error(line)

            # could not get the network directly - run the configured commands
            self.cmds_running = runprocs.RunCmds(
                self.conf.run_post_training_cmds,
                cb_on_completion=self.finished_cmds_running,
                max_time=180.0)
            self.cmds_running.spawn()

        return msgs.Ok("configured")
Example #2
0
def get_database(verbose=True):
    """Return the module level game database, building it on first use.

    The first call creates a ``GameDatabase`` rooted at ``get_root()``, loads
    it (passing ``verbose`` through) and then tries to install the draughts
    and hex games.  A failure while installing either game is logged and
    otherwise ignored.  Later calls return the already created database and
    ignore ``verbose``.
    """
    global the_database

    if the_database is not None:
        return the_database

    from ggplib.db.store import get_root
    the_database = GameDatabase(get_root())
    the_database.load(verbose=verbose)

    def register_game(game, sm, model):
        # callback handed to the installers below
        the_database.game_mapping[game] = GameInfoBypass(game, sm, model)

    try:
        install_draughts(register_game)
    except Exception as err:
        log.error("Failed to install draughts: %s" % err)

    try:
        install_hex(register_game)
    except Exception as err:
        log.error("Failed to install hex: %s" % err)

    return the_database
Example #3
0
    def cleanup(self, keep_sm=False):
        """Clean up the player and release the native objects held here.

        The player's own ``cleanup()`` is attempted first; any exception it
        raises is logged (with traceback) and does not stop the rest of the
        cleanup.  Afterwards every entry of ``self.states`` is deallocated,
        then ``self.joint_move`` and - unless ``keep_sm`` is true -
        ``self.sm``.  Released attributes are reset to ``[]`` / ``None``.
        """
        try:
            self.player.cleanup()
            if self.verbose:
                log.verbose("done cleanup player: %s" % self.player)
        except Exception as exc:
            # format_exc() already carries the traceback; no need to keep the
            # exc_info tuple in locals (it also shadowed the builtin `type`).
            log.error("FAILED TO CLEANUP PLAYER: %s" % exc)
            log.error(traceback.format_exc())

        # cleanup c++ stuff
        if self.verbose:
            log.warning("cleaning up c++ stuff")

        # all the basestates
        for bs in self.states:
            interface.dealloc_basestate(bs)

        self.states = []

        if self.joint_move:
            interface.dealloc_jointmove(self.joint_move)
            self.joint_move = None

        if self.sm and not keep_sm:
            interface.dealloc_statemachine(self.sm)
            self.sm = None

        if self.verbose:
            log.info("match - done cleaning up")
0
    def handle_stop(self, symbols):
        """Handle a 'stop' request: ``symbols`` is (head, match_id, move).

        Returns "busy" when there is no current match or the match id does
        not match the current one.  Otherwise the final move is played, the
        current match is cleared and "done" is returned.  If playing the
        move does not report "done", that is logged as an error; if it does,
        the gameserver timeout is cancelled and the match is stopped.
        """
        assert len(symbols) == 3
        match_id = symbols[1]
        if self.current_match is None:
            log.warning("rx'd 'stop' for non-current match %s" % match_id)
            return "busy"
        if self.current_match.match_id != match_id:
            log.error("rx'd 'stop' different from current match (%s != %s)" % (match_id, self.current_match.match_id))
            return "busy"

        move = symbols[2]

        # XXX bug with standford 'player checker'??? XXX need to find out what is going on here?
        if isinstance(move, str):
            # The old test compared the bound method `move.lower` (no call)
            # with "nil", which is always unequal, so "nil" got wrapped and
            # symbolized like a real move.
            if move.lower() == "nil":
                # same convention as handle_play(): nil means no move
                move = None
            else:
                move = self.symbol_factory.symbolize("( %s )" % move)

        res = self.current_match.do_play(move)
        if res != "done":
            log.error("Game was NOT done %s" % self.current_match.match_id)

        else:
            # cancel any timeout callbacks
            self.update_gameserver_timeout(None)
            self.current_match.do_stop()

        self.current_match = None
        return "done"
Example #5
0
def lookup_all_games():
    """Look up every known game and run one depth charge on its state machine.

    Games listed in ``known_to_fail`` are skipped.  A game whose lookup
    raises ``lookup.LookupFailed`` is logged and collected; if any were
    collected the function logs them and fails with an assertion carrying
    the list.
    """
    # ensure things are initialised
    from ggplib.util.init import setup_once
    setup_once()

    known_to_fail = [
        'amazonsTorus_10x10', 'atariGoVariant_7x7', 'gt_two_thirds_4p',
        'gt_two_thirds_6p', 'linesOfAction'
    ]
    failed = []

    for game in lookup.get_all_game_names():
        if game in known_to_fail:
            continue

        try:
            game_info = lookup.by_name(game, build_sm=False)
            assert game_info.game == game
            sm = game_info.get_sm()

            # run some depth charges to ensure we have valid statemachine
            interface.depth_charge(sm, 1)

            log.verbose("DONE GETTING Statemachine FOR GAME %s %s" %
                        (game, sm))
        except lookup.LookupFailed as exc:
            log.warning("Failed to lookup %s: %s" % (game, exc))
            failed.append(game)

    if not failed:
        return

    log.error("Failed games %s" % (failed, ))
    assert False, failed
Example #6
0
    def on_configure(self, server, msg):
        """Apply a configure message, blocking until the network is loaded.

        The game is looked up on the first call; on later calls a different
        game name is logged as critical and the process exits with status 1.
        Loading the network is retried in a loop (sleeping one second after
        each failure) until a network is obtained, then self play is
        configured.

        Returns ``msgs.Ok("configured")``.
        """
        attrutil.pprint(msg)

        if self.game_info is not None:
            if self.game_info.game != msg.game:
                log.critical("Game changed to %s" % msg.game)
                sys.exit(1)
        else:
            self.game_info = lookup.by_name(msg.game)
            self.sm = self.game_info.get_sm()

        self.self_play_conf = msg.self_play_conf
        self.latest_generation_name = msg.generation_name

        # refresh the neural network.  May have to run some commands to get it.
        self.nn = None
        while True:
            try:
                self.nn = get_manager().load_network(
                    self.game_info.game, self.latest_generation_name)

            except Exception as exc:
                log.error("error in on_configure(): %s" % exc)
                for tb_line in traceback.format_exc().splitlines():
                    log.error(tb_line)
                time.sleep(1.0)

            if self.nn is not None:
                break

        self.configure_self_play()
        return msgs.Ok("configured")
Example #7
0
    def checkpoint(self):
        """Periodic checkpoint: save samples and request training when due.

        If any samples have accumulated they are saved.  Once more than
        ``conf.num_samples_to_train`` samples are present, the saved data is
        remembered as pending (unless something is already pending) and, if
        no training is in progress, a training request is sent to the
        trainer (or an error is logged when there is no trainer).

        Finally the method reschedules itself via the reactor after
        ``conf.checkpoint_interval``, cancelling any still active callback.
        """
        sample_count = len(self.accumulated_samples)
        log.verbose("entering checkpoint with %s sample accumulated" %
                    sample_count)

        if sample_count > 0:
            saved_samples = self.save_sample_data()
            ready_to_train = sample_count > self.conf.num_samples_to_train

            if ready_to_train and self.pending_gen_samples is None:
                next_name = self.get_generation_name(
                    self.conf.current_step + 1)
                log.info("data done for: %s" % next_name)
                self.pending_gen_samples = saved_samples

            if ready_to_train and not self.training_in_progress:
                if self.the_nn_trainer is not None:
                    self.send_request_to_train_nn()
                else:
                    log.error("There is no trainer - please start")

        # cancel any existing cb
        previous_cb = self.checkpoint_cb
        if previous_cb is not None and previous_cb.active():
            previous_cb.cancel()

        # call checkpoint again in n seconds
        self.checkpoint_cb = reactor.callLater(self.conf.checkpoint_interval,
                                               self.checkpoint)
Example #8
0
def by_gdl(gdl):
    """Look up a game by its GDL, building a temporary game on failure.

    ``gdl`` is either a string or an iterable whose items are joined (one
    ``str(item)`` per line) to form the lookup text.

    Returns ``(mapping, info)`` when the database knows the game.  If
    anything goes wrong the error is logged, a state machine is built from
    ``gdl`` and ``(None, TempGameInfo)`` is returned instead.
    """
    try:
        gdl_str = gdl
        if not isinstance(gdl, str):
            gdl_str = "\n".join(str(s) for s in gdl)

        db = get_database()
        try:
            info, mapping = db.lookup(gdl_str)

        except LookupFailed as exc:
            # print the original traceback, then re-raise with context; the
            # outer handler below turns this into the temporary-game fallback
            traceback.print_exc()
            raise LookupFailed("Did not find game %s" % exc)

        return mapping, info

    except Exception as exc:
        # creates temporary files
        log.error("Lookup failed: %s" % exc)

        model, sm = builder.build_sm(gdl)
        info = TempGameInfo("unknown", gdl, sm, model)
        return None, info
Example #9
0
    def verify_db(self):
        """Open the ``__db__`` store and check its size against the summary.

        Returns False when the ``__db__`` path under ``self.data_path`` does
        not exist, or when anything raises while opening / checking it (the
        error is logged).  When the db holds more rows than the summary it
        is resized down to the summary's total; when it holds fewer, that
        counts as a failure.  Returns True otherwise.
        """
        db_path = os.path.join(self.data_path, "__db__")
        if os.path.exists(db_path):
            try:
                self.db = bcolz.open(db_path, mode='a')

                # XXX check columns are correct types

                db_size = self.db.size
                summary_size = self.summary.total_samples
                if summary_size != db_size:
                    msg = "db and summary file different sizes summary: %s != %s" % (
                        db_size, summary_size)
                    log.warning(msg)

                    if db_size > summary_size:
                        log.warning("resizing")
                        self.db.resize(summary_size)
                    else:
                        raise Check(msg)

            except Exception as exc:
                log.error("error accessing db directory: %s" % exc)
                return False

            return True

        return False
Example #10
0
def by_name(name, build_sm=True):
    """Return the database entry for the game called ``name``.

    Any exception raised while getting the database or the entry is logged
    (message and traceback) and re-raised as ``LookupFailed``.

    NOTE(review): ``build_sm`` is accepted but not used by the code visible
    here - confirm whether it is meant to be forwarded.
    """
    try:
        return get_database(verbose=False).get_by_name(name)

    except Exception as exc:
        failure = "Lookup of %s failed: %s" % (name, exc)
        for text in (failure, traceback.format_exc()):
            log.error(text)
        raise LookupFailed(failure)
Example #11
0
 def init_data_rxd(self, data):
     """Accumulate the challenge response and validate it once complete.

     ``data`` is appended to ``self.start_buf``.  Nothing else happens until
     the buffer is exactly ``CHALLENGE_SIZE`` long; then it is compared with
     ``self.expected_response``.  A match marks the logical connection as
     made and registers this client with the broker; a mismatch marks it as
     not made and disconnects.
     """
     self.start_buf += data
     # NOTE(review): only an exact length triggers the check - a buffer that
     # overshoots CHALLENGE_SIZE is never validated; confirm the sender
     # cannot deliver extra bytes here.
     if len(self.start_buf) == self.CHALLENGE_SIZE:
         if self.expected_response == self.start_buf:
             self.logical_connection = True
             log.info("Logical connection made")
             self.broker.new_broker_client(self)
         else:
             # was set to True here, flagging a failed handshake as connected
             self.logical_connection = False
             log.error("Logical connection failed")
             self.disconnect()
Example #12
0
    def handle_abort(self, symbols):
        """Handle an 'abort' request: ``symbols`` is (head, match_id).

        Returns "busy" (after logging) when there is no current match or the
        id differs from the current match; otherwise aborts and returns
        "aborted".
        """
        assert len(symbols) == 2
        requested_id = symbols[1]

        current = self.current_match
        if current is None:
            log.warning("rx'd 'abort' for non-current match %s" % requested_id)
            return "busy"

        if current.match_id != requested_id:
            log.error("rx'd 'abort' different from current match (%s != %s)" % (requested_id, current.match_id))
            return "busy"

        # cancel any timeout callbacks
        self.abort()
        return "aborted"
Example #13
0
    def compile(self, compile_strategy, learning_rate=None, value_weight=1.0):
        """Compile ``self.keras_model`` with the chosen optimizer.

        ``compile_strategy`` is one of "SGD", "adam" or "amsgrad"; anything
        else is logged and raised as ``Exception``.  A falsy
        ``learning_rate`` means the optimizer default is used (1e-2 for SGD).
        One categorical cross entropy loss (weight 1.0) is used per entry of
        the transformer's ``policy_dist_count``, followed by a mean squared
        error loss weighted by ``value_weight``.
        """
        if compile_strategy == "SGD":
            optimizer = SGD(lr=learning_rate if learning_rate else 1e-2,
                            momentum=0.9)

        elif compile_strategy in ("adam", "amsgrad"):
            adam_kwargs = {}
            if learning_rate:
                adam_kwargs["lr"] = learning_rate
            if compile_strategy == "amsgrad":
                adam_kwargs["amsgrad"] = True
            optimizer = Adam(**adam_kwargs)

        else:
            unknown = "UNKNOWN compile strategy %s" % compile_strategy
            log.error(unknown)
            raise Exception(unknown)

        num_policies = len(self.gdl_bases_transformer.policy_dist_count)

        # policy heads first, the value head last
        loss = ['categorical_crossentropy'] * num_policies + ["mean_squared_error"]
        loss_weights = [1.0] * num_policies + [value_weight]

        if learning_rate is None:
            log.warning("Compiling with %s (value_weight=%.3f)" %
                        (optimizer, value_weight))
        else:
            msg = "Compiling with %s (learning_rate=%.4f, value_weight=%.3f)"
            log.warning(msg % (optimizer, learning_rate, value_weight))

        def top_3_acc(y_true, y_pred):
            return keras_metrics.top_k_categorical_accuracy(y_true,
                                                            y_pred,
                                                            k=3)

        self.keras_model.compile(loss=loss,
                                 optimizer=optimizer,
                                 loss_weights=loss_weights,
                                 metrics=["acc", top_3_acc])
Example #14
0
    def get_network(self, nn_model_config, generation_descr):
        """Pick the network to train and store it on ``self.nn``.

        If the next generation can already be loaded this is logged, and
        raised as ``TrainException`` unless ``overwrite_existing`` is set.
        When ``use_previous`` is set, the previous step's generation is
        searched for under ``self.next_generation_prefix`` first and then
        under ``conf.generation_prefix`` (if different); the first one found
        is loaded and ``self.retraining`` becomes True.  Otherwise a new
        network is created from ``nn_model_config`` / ``generation_descr``.
        """
        # abbreviate, easier on the eyes
        conf = self.train_config

        attrutil.pprint(nn_model_config)

        manager = get_manager()
        if manager.can_load(conf.game, self.next_generation):
            exists_msg = "Generation already exists %s / %s" % (
                conf.game, self.next_generation)
            log.error(exists_msg)
            if not conf.overwrite_existing:
                raise TrainException(exists_msg)

        network = None
        retraining = False
        if conf.use_previous:
            prefixes = [self.next_generation_prefix]
            if conf.generation_prefix != self.next_generation_prefix:
                prefixes.append(conf.generation_prefix)

            for prefix in prefixes:
                prev_generation = "%s_%s" % (prefix, conf.next_step - 1)

                if not manager.can_load(conf.game, prev_generation):
                    log.warning("Previous generation %s not found..." %
                                (prev_generation))
                    continue

                log.info("Previous generation found: %s" % prev_generation)
                network = manager.load_network(conf.game, prev_generation)
                retraining = True
                break

        if network is None:
            network = manager.create_new_network(conf.game, nn_model_config,
                                                 generation_descr)

        network.summary()

        self.nn = network
        self.retraining = retraining
        log.info("Network %s, retraining: %s" % (self.nn, self.retraining))
Example #15
0
    def check_summary(self):
        """Check the summary against the files returned by ``list_files()``.

        Step summaries and files are walked pairwise.  Returns False (after
        logging) when the game names differ, a step number disagrees with
        its file, or the summed sample count differs from the summary total.
        An md5 mismatch is only logged as a warning.  A step summary whose
        step is not its position in the list raises a plain ``Exception``.
        Returns True when all checks pass.
        """
        total_samples = 0

        try:
            if self.summary.game != self.transformer.game:
                raise Check("Game not same %s/%s" %
                            (self.summary.game, self.transformer.game))

            paired = zip(self.summary.step_summaries, self.list_files())
            for expect, (step_sum, (step, file_path, md5sum)) in enumerate(paired):

                # special case exception, this should never happen!
                if step_sum.step != expect:
                    raise Exception(
                        "Weird - step_sum.step != expect, please check %s/%s" %
                        (step_sum.step, expect))

                if step_sum.step != step:
                    raise Check("step_sum(%d) != step(%d)" %
                                (step_sum.step, step))

                if step_sum.md5sum != md5sum:
                    # tolerated on purpose: warn only
                    log.warning(
                        "Summary check: for file %s, md5sum(%s) != md5sum(%s)" % (
                            file_path, step_sum.md5sum, md5sum))

                total_samples += step_sum.num_samples

            if self.summary.total_samples != total_samples:
                raise Check("Total samples mismatch %s != %s" %
                            (self.summary.total_samples, total_samples))

        except Check as exc:
            log.error("Summary check failed: %s" % exc)
            return False

        return True
Example #16
0
    def handle_play(self, symbols):
        """Handle a 'play' request: ``symbols`` is (head, match_id, move).

        Returns "busy" (after logging) when there is no current match or the
        id differs from the current match.  Otherwise a ``ListTerm`` move is
        converted to a list, anything else must be "nil" (any case) and
        becomes None.  The gameserver timeout is refreshed with the match's
        move time and the result of the match's ``do_play`` is returned.
        """
        assert len(symbols) == 3
        requested_id = symbols[1]

        if self.current_match is None:
            log.warning("rx'd play for non-current match %s" % requested_id)
            return "busy"

        if self.current_match.match_id != requested_id:
            log.error("rx'd play different from current match (%s != %s)" % (requested_id, self.current_match.match_id))
            return "busy"

        move = symbols[2]
        if not isinstance(move, ListTerm):
            assert move.lower() == 'nil', "Move is %s" % move
            move = None
        else:
            move = list(move)

        # update gameserver timeout
        self.update_gameserver_timeout(self.current_match.move_time)
        return self.current_match.do_play(move)
Example #17
0
    def handle(self, request):
        """Dispatch one request to the matching ``handle_*`` method.

        The request content is symbolized and the first symbol (case
        insensitive) selects the handler: info, start, play, stop or abort.
        Empty content and an empty symbol list are answered with
        ``handle_info()``.  An unrecognised head is logged and also answered
        with ``handle_info()``.  Any exception while parsing or handling is
        logged, the current match (if any) is aborted and "aborted" is
        returned.
        """
        content = request.content.getvalue()

        # Tiltyard seems to ping with empty content...
        if content == "":
            return self.handle_info()

        try:
            symbols = list(self.symbol_factory.symbolize(content))

            # get head
            if len(symbols) == 0:
                log.warning('Empty symbols')
                return self.handle_info()

            head = symbols[0].lower()
            if head == "info":
                res = self.handle_info()

            elif head == "start":
                log.debug("HEADERS : %s" % pprint.pformat(request.getAllHeaders()))
                log.debug(str(symbols))
                res = self.handle_start(symbols)

            elif head == "play":
                log.debug(str(symbols))
                res = self.handle_play(symbols)

            elif head == "stop":
                log.debug(str(symbols))
                res = self.handle_stop(symbols)

            elif head == "abort":
                log.debug(str(symbols))
                res = self.handle_abort(symbols)

            else:
                log.error("UNHANDLED REQUEST %s" % symbols)
                # `res` used to stay unassigned here, so the final return
                # raised UnboundLocalError.  Answer like the other unusable
                # requests above instead.
                res = self.handle_info()

        except Exception as exc:
            log.error("ERROR - aborting: %s" % exc)
            log.error(traceback.format_exc())

            if self.current_match:
                self.abort()

            res = "aborted"

        return res
Example #18
0
    def onMessage(self, caller, msg):
        """Route ``msg`` to its registered handler and send back any reply.

        A message whose name has no handler is logged and the caller is
        disconnected.  Otherwise the handler is called with the caller and
        the message payload; a non-None result is sent back to the caller.
        If the handler (or sending the reply) raises, the error is logged
        with traceback and the caller is disconnected.
        """
        name = msg.name
        if name not in self.handlers:
            log.error("%s : unknown msg %s" % (caller, str(name)))
            caller.disconnect()
            return

        try:
            reply = self.handlers[msg.name](caller, msg.payload)

            # doesn't necessarily need to have a response
            if reply is not None:
                caller.send_msg(reply)

        except Exception as err:
            for text in ("%s : exception calling method %s.  " %
                         (caller, str(msg.name)),
                         "%s" % err,
                         traceback.format_exc()):
                log.error(text)

            # do this last as might raise also...
            caller.disconnect()
Example #19
0
    def compile(self,
                compile_strategy,
                learning_rate=None,
                value_weight=1.0,
                l2_loss=None,
                l2_non_residual=True):
        """Compile ``self.keras_model``, optionally (un)setting l2 regularizers.

        ``compile_strategy`` is one of "SGD", "adam" or "amsgrad"; anything
        else is logged and raised as ``Exception``.  SGD defaults to a
        learning rate of 0.01; for the Adam variants a falsy
        ``learning_rate`` means the optimizer default.

        When ``l2_loss`` is given, an l2 kernel regularizer is set on every
        eligible layer that has none; when it is None, existing regularizers
        on eligible layers are removed.  With ``l2_non_residual`` only layers
        whose name contains "policy" or "value" (and not "flatten") are
        eligible; without it every layer except those with "_se_" in the
        name is.  Non-eligible layers always have their regularizer removed.
        If any layer was changed the model is rebuilt from its config (with
        the same weights) before compiling.
        """
        # XXX allow l2_loss on final layers.

        if compile_strategy == "SGD":
            if learning_rate is None:
                learning_rate = 0.01
            optimizer = SGD(lr=learning_rate, momentum=0.9)

        elif compile_strategy in ("adam", "amsgrad"):
            adam_kwargs = {}
            if learning_rate:
                adam_kwargs["lr"] = learning_rate
            if compile_strategy == "amsgrad":
                adam_kwargs["amsgrad"] = True
            optimizer = Adam(**adam_kwargs)

        else:
            unknown = "UNKNOWN compile strategy %s" % compile_strategy
            log.error(unknown)
            raise Exception(unknown)

        num_policies = len(self.gdl_bases_transformer.policy_dist_count)

        # policy heads first, the value head last
        loss = ['categorical_crossentropy'] * num_policies + ["mean_squared_error"]
        loss_weights = [1.0] * num_policies + [value_weight]

        if learning_rate is None:
            log.warning("Compiling with %s (value_weight=%.3f)" %
                        (optimizer, value_weight))
        else:
            msg = "Compiling with %s (learning_rate=%.4f, value_weight=%.3f)"
            log.warning(msg % (optimizer, learning_rate, value_weight))

        def top_3_acc(y_true, y_pred):
            return keras_metrics.top_k_categorical_accuracy(y_true,
                                                            y_pred,
                                                            k=3)

        if l2_loss is not None:
            log.warning("Applying l2 loss (%.5f)" % l2_loss)
            l2_loss = keras_regularizers.l2(l2_loss)

        rebuild_model = False
        for layer in self.keras_model.layers:
            # To get global weight decay in keras regularizers have to be added to every layer
            # in the model.
            if not hasattr(layer, 'kernel_regularizer'):
                continue

            if l2_non_residual:
                in_head = "policy" in layer.name or "value" in layer.name
                ignore = not (in_head and "flatten" not in layer.name)
            else:
                ignore = "_se_" in layer.name

            if ignore:
                if layer.kernel_regularizer is not None:
                    log.warning(
                        "Ignoring but regularizer was set @ %s/%s.  Unsetting."
                        % (layer.name, layer))
                    layer.kernel_regularizer = None
                    rebuild_model = True
                continue

            if l2_loss is not None:
                if layer.kernel_regularizer is None:
                    rebuild_model = True
                    log.info("Applying l2 loss to %s/%s" % (layer.name, layer))
                    layer.kernel_regularizer = l2_loss

            elif layer.kernel_regularizer is not None:
                log.info("Unsetting l2 loss on %s/%s" %
                         (layer.name, layer))
                rebuild_model = True
                layer.kernel_regularizer = None

        # This ensures a fresh build of the network (there is no API to do this in keras, hence
        # this hacky workaround).  Furthermore, needing to rebuild the network here, before
        # compiling, is somewhat buggy/idiosyncrasy of keras.
        if rebuild_model:
            config = self.keras_model.get_config()
            weights = self.keras_model.get_weights()
            self.keras_model = keras_models.Model.from_config(config)
            self.keras_model.set_weights(weights)

        self.keras_model.compile(loss=loss,
                                 optimizer=optimizer,
                                 loss_weights=loss_weights,
                                 metrics=["acc", top_3_acc])