コード例 #1
0
    def save(self):
        """Save the current network, then a "_prev" network built from the best retraining weights.

        Steps:
          1. Save ``self.nn`` under ``self.next_generation`` and run
             ``self.do_callbacks()``.
          2. If ``self.controller.retrain_best`` is None, log a warning and stop.
          3. Otherwise rebuild a keras model with the same architecture as
             ``self.nn.keras_model`` (JSON round trip), load the
             ``retrain_best`` weights into it, and wrap it in a
             ``network.NeuralNetwork``. The wrapper's generation description
             is a clone of ``self.nn.generation_descr`` renamed to
             ``"<next_generation>_prev"``. Save that network under the same
             name and run ``self.do_callbacks()`` a second time.
        """
        # XXX set generation attributes
        manager = get_manager()
        manager.save_network(self.nn, generation_name=self.next_generation)
        self.do_callbacks()

        # ---- keep the best retraining weights around for the next generation ----
        if self.controller.retrain_best is None:
            log.warning("No retraining network")
            return

        best_acc = self.controller.retrain_best_val_policy_acc
        log.info("Saving retraining network with val_policy_acc: %.4f" % best_acc)

        # Keras has an undocumented clone helper; a JSON round trip of the
        # architecture is slower but known to work.
        from ggpzero.util.keras import keras_models

        prev_name = "%s_prev" % self.next_generation

        architecture_json = self.nn.keras_model.to_json()
        prev_keras_model = keras_models.model_from_json(architecture_json)
        prev_keras_model.set_weights(self.controller.retrain_best)

        prev_descr = attrutil.clone(self.nn.generation_descr)
        prev_descr.name = prev_name

        prev_network = network.NeuralNetwork(self.nn.gdl_bases_transformer,
                                             prev_keras_model,
                                             prev_descr)
        manager.save_network(prev_network, prev_name)
        self.do_callbacks()
コード例 #2
0
    def load_network_fixme(self, game):
        """One-off migration: re-save old-style "<game>_<gen>" models under new-style names.

        Every file in ``<data_path>/<game>/models`` whose name starts with
        ``"<game>_"`` is treated as an old-style generation. For each one:

        - Its keras model (``self.model_path``) and weights
          (``self.weights_path``) are loaded and wrapped in a ``NeuralNetwork``.
          The wrapper gets a placeholder generation description from
          ``templates.default_generation_desc``.
        - The description is renamed to the generation name without the
          leading ``"<game>_"``, and its training statistics are set to
          ``"unknown"``.
        - Its ``date_created`` is taken from the model file's ``st_ctime``.
        - The network is written back with ``self.save_network``.

        Raises OSError if the models directory does not exist.
        """
        import glob
        import datetime

        models_dir = os.path.join(self.data_path, game, "models")
        if not os.path.isdir(models_dir):
            # Fail loudly rather than silently migrating nothing.
            raise OSError("models directory not found: %s" % models_dir)

        prefix = game + "_"

        # Glob on the full path instead of os.chdir()-ing into the directory:
        # chdir mutates the process-wide working directory and is never undone.
        # Several files may share a stem (e.g. different extensions), so dedupe
        # the stems, and sort them for a deterministic processing order.
        pattern = os.path.join(models_dir, "%s_*" % game)
        generations = sorted(set(
            os.path.splitext(os.path.basename(path))[0]
            for path in glob.glob(pattern)))

        for generation in generations:
            print("doing %s %s" % (game, generation))

            # Strip only the leading "<game>_"; str.replace() would also remove
            # any later occurrence of it inside the name.
            new_style_gen = generation[len(prefix):]
            print("%s %s" % (generation, new_style_gen))

            # dummy generation_descr
            generation_descr = templates.default_generation_desc(game)

            model_path = self.model_path(game, generation)
            with open(model_path) as f:
                keras_model = keras_models.model_from_json(f.read())

            keras_model.load_weights(self.weights_path(game, generation))
            transformer = self.get_transformer(game, generation_descr)
            print("%s %s %s" % (transformer, keras_model, generation_descr))

            # NOTE(review): the fields below are set after nn is built; this
            # relies on NeuralNetwork keeping a reference to generation_descr
            # rather than a copy -- confirm.
            nn = NeuralNetwork(transformer, keras_model, generation_descr)
            generation_descr.name = new_style_gen
            generation_descr.trained_losses = "unknown"
            generation_descr.trained_validation_losses = "unknown"
            generation_descr.trained_policy_accuracy = "unknown"
            generation_descr.trained_value_accuracy = "unknown"

            # st_ctime: creation time on Windows, last metadata change on Unix.
            ctime = os.stat(model_path).st_ctime
            generation_descr.date_created = datetime.datetime.fromtimestamp(
                ctime).strftime("%Y/%m/%d %H:%M")
            print(generation_descr)
            self.save_network(nn)
コード例 #3
0
    def load_network(self, game, generation_name):
        """Load the saved network ``generation_name`` for ``game``.

        The generation description is read from ``self.generation_path``,
        parsed with ``attrutil.json_to_attr``. The keras model architecture is
        read from ``self.model_path`` and its weights from
        ``self.weights_path``. Returns a ``NeuralNetwork`` combining them with
        the transformer from ``self.get_transformer``.
        """
        # Context managers close the handles deterministically instead of
        # leaving them open until the file objects are garbage collected.
        with open(self.generation_path(game, generation_name)) as f:
            generation_json = f.read()
        generation_descr = attrutil.json_to_attr(generation_json)

        with open(self.model_path(game, generation_name)) as f:
            model_json = f.read()
        keras_model = keras_models.model_from_json(model_json)

        keras_model.load_weights(self.weights_path(game, generation_name))
        transformer = self.get_transformer(game, generation_descr)
        return NeuralNetwork(transformer, keras_model, generation_descr)