def save(self):
    """Save the newly trained network, then optionally a "<gen>_prev" copy.

    First, self.nn is saved through the manager under self.next_generation
    and callbacks are fired.  Then, if self.controller.retrain_best is set,
    a second network is built.  It uses the same model definition as self.nn
    but loads the retrain_best weights.  It is saved under
    "<next_generation>_prev" and callbacks are fired again.  If retrain_best
    is None, a warning is logged and only the first save happens.
    """
    # TODO (was "XXX"): set generation attributes
    manager = get_manager()
    manager.save_network(self.nn, generation_name=self.next_generation)
    self.do_callbacks()

    # --- save a previous model for next time -------------------------------
    controller = self.controller
    if controller.retrain_best is None:
        log.warning("No retraining network")
        return

    log.info("Saving retraining network with val_policy_acc: %.4f" %
             controller.retrain_best_val_policy_acc)

    # Keras has an undocumented clone helper; round-tripping the model
    # through JSON is slow but is known to work.
    from ggpzero.util.keras import keras_models

    prev_name = "%s_prev" % self.next_generation

    model_json = self.nn.keras_model.to_json()
    cloned_model = keras_models.model_from_json(model_json)
    cloned_model.set_weights(controller.retrain_best)

    cloned_descr = attrutil.clone(self.nn.generation_descr)
    cloned_descr.name = prev_name

    prev_network = network.NeuralNetwork(self.nn.gdl_bases_transformer,
                                         cloned_model,
                                         cloned_descr)
    manager.save_network(prev_network, prev_name)
    self.do_callbacks()
def load_network_fixme(self, game):
    """One-off migration of old-style "<game>_<generation>" model files.

    For every file in <data_path>/<game>/models whose name starts with
    "<game>_", this:
      * loads the keras model (JSON definition plus weights);
      * attaches a placeholder description from
        templates.default_generation_desc(game);
      * renames the description to the generation name without the
        "<game>_" prefix;
      * re-saves the network via self.save_network().

    Training statistics are not known for these old models, so they are set
    to the string "unknown".  date_created comes from the model file's
    st_ctime.
    """
    import glob
    import datetime

    models_dir = os.path.join(self.data_path, game, "models")
    prefix = game + "_"

    # Glob with an absolute pattern instead of os.chdir()-ing into the
    # directory.  chdir changes process-wide state, was never undone, and
    # breaks the path helpers below if self.data_path is relative.
    # Several files (e.g. different extensions) can map to the same
    # generation name, so collect unique names and convert each once.
    generations = sorted(set(
        os.path.splitext(os.path.basename(path))[0]
        for path in glob.glob(os.path.join(models_dir, prefix + "*"))))

    for generation in generations:
        print("doing %s %s" % (game, generation))

        # The glob guarantees the prefix, so strip only the leading
        # "<game>_".  str.replace() would also remove later occurrences.
        new_style_gen = generation[len(prefix):]
        print("%s %s" % (generation, new_style_gen))

        # dummy generation_descr; its fields are filled in below
        generation_descr = templates.default_generation_desc(game)

        model_path = self.model_path(game, generation)
        with open(model_path) as f:
            json_str = f.read()
        keras_model = keras_models.model_from_json(json_str)
        keras_model.load_weights(self.weights_path(game, generation))

        transformer = self.get_transformer(game, generation_descr)
        print("%s %s %s" % (transformer, keras_model, generation_descr))
        # NOTE(review): the fields are set after construction, which
        # presumably relies on nn holding a reference to generation_descr
        # (not a copy) -- confirm.
        nn = NeuralNetwork(transformer, keras_model, generation_descr)

        generation_descr.name = new_style_gen
        generation_descr.trained_losses = "unknown"
        generation_descr.trained_validation_losses = "unknown"
        generation_descr.trained_policy_accuracy = "unknown"
        generation_descr.trained_value_accuracy = "unknown"

        # NOTE: st_ctime is platform dependent: creation time on Windows,
        # last metadata change on Unix.  It is only an approximate date.
        ctime = os.stat(model_path).st_ctime
        generation_descr.date_created = datetime.datetime.fromtimestamp(
            ctime).strftime("%Y/%m/%d %H:%M")

        print(generation_descr)
        self.save_network(nn)
def load_network(self, game, generation_name):
    """Load the saved network for *game* / *generation_name*.

    Reads three things:
      * the generation description JSON (generation_path), parsed with
        attrutil.json_to_attr;
      * the keras model definition JSON (model_path);
      * the weights file (weights_path).

    It then wraps them, with the transformer from get_transformer(), in a
    NeuralNetwork, which is returned.
    """
    # Use context managers so the file handles are closed promptly
    # (the original open(...).read() calls never closed them).
    with open(self.generation_path(game, generation_name)) as f:
        descr_json = f.read()
    generation_descr = attrutil.json_to_attr(descr_json)

    with open(self.model_path(game, generation_name)) as f:
        model_json = f.read()
    keras_model = keras_models.model_from_json(model_json)
    keras_model.load_weights(self.weights_path(game, generation_name))

    transformer = self.get_transformer(game, generation_descr)
    return NeuralNetwork(transformer, keras_model, generation_descr)