def save_sample_data(self):
    ''' Write the samples for the current step to a gzipped json file.

    Nothing is written while training is in progress (returns None).
    Otherwise at most conf.num_samples_to_train of the accumulated samples
    are stored and the populated GenerationSamples object is returned. '''
    if self.training_in_progress:
        log.warning("skip writing json (gzipped): %s" % self.sample_data_filename)
        return

    conf = self.conf
    available = self.accumulated_samples

    gen_samples = datadesc.GenerationSamples()
    gen_samples.game = conf.game
    gen_samples.date_created = get_date_string()
    gen_samples.with_generation = self.get_generation_name(conf.current_step)

    # only save the minimal number for this run
    gen_samples.num_samples = min(len(available), conf.num_samples_to_train)
    gen_samples.samples = available[:gen_samples.num_samples]

    # write json file (floats are emitted with 5 decimal places; note this
    # replaces the module-wide json.encoder.FLOAT_REPR hook)
    json.encoder.FLOAT_REPR = lambda f: ("%.5f" % f)

    log.info("writing json (gzipped): %s" % self.sample_data_filename)
    with gzip.open(self.sample_data_filename, 'w') as out:
        out.write(attrutil.attr_to_json(gen_samples, pretty=False))

    return gen_samples
def handle_stop(self, symbols):
    ''' Handle a 'stop' request of the form (stop <match_id> <move>).

    Returns "busy" when there is no current match or the match id differs
    from the current match; otherwise the move is played, the match is
    released and "done" is returned. '''
    assert len(symbols) == 3
    match_id = symbols[1]

    if self.current_match is None:
        log.warning("rx'd 'stop' for non-current match %s" % match_id)
        return "busy"

    if self.current_match.match_id != match_id:
        log.error("rx'd 'stop' different from current match (%s != %s)" %
                  (match_id, self.current_match.match_id))
        return "busy"

    move = symbols[2]

    # XXX bug with standford 'player checker'??? XXX need to find out what is going on here?
    # A bare string move is wrapped in parentheses and re-parsed; "nil" (any
    # case) is passed through unchanged.  lower() must be *called* here:
    # comparing the bound method itself to "nil" is always unequal.
    if isinstance(move, str) and move.lower() != "nil":
        move = self.symbol_factory.symbolize("( %s )" % move)

    res = self.current_match.do_play(move)
    if res != "done":
        log.error("Game was NOT done %s" % self.current_match.match_id)
    else:
        # cancel any timeout callbacks
        self.update_gameserver_timeout(None)
        self.current_match.do_stop()

    self.current_match = None
    return "done"
def main_wrap(main_fn, logfile_name=None, **kwds):
    ''' Run main_fn with logging set up, dropping into pdb if it raises.

    logfile_name defaults to the basename (without extension) of the file
    that defines main_fn.  kwds are forwarded to init().  main_fn is called
    with no arguments if it declares none, otherwise with sys.argv[1:].
    Returns whatever main_fn returns; returns None after a post-mortem. '''
    if logfile_name is None:
        # if logfile_name not set, derive it from main_fn
        # (__code__ is available on both python 2.6+ and python 3)
        fn = main_fn.__code__.co_filename
        logfile_name = os.path.splitext(os.path.basename(fn))[0]

    setup_once(logfile_name)

    try:
        # we might be running under python with no keras/numpy support
        init(**kwds)
    except ImportError as exc:
        log.warning("ImportError: %s" % exc)

    try:
        if main_fn.__code__.co_argcount == 0:
            return main_fn()
        else:
            return main_fn(sys.argv[1:])

    except Exception as exc:
        # single-argument print() behaves the same as the print statement
        print(exc)
        _, _, tb = sys.exc_info()
        traceback.print_exc()
        pdb.post_mortem(tb)
def lookup_all_games():
    ''' Look up every known game, build its statemachine and run one depth
    charge on it.  Games listed in known_to_fail are skipped.  Asserts if any
    lookup raised LookupFailed. '''
    # ensure things are initialised
    from ggplib.util.init import setup_once
    setup_once()

    known_to_fail = ['amazonsTorus_10x10', 'atariGoVariant_7x7',
                     'gt_two_thirds_4p', 'gt_two_thirds_6p',
                     'linesOfAction']

    failed = []
    for game in lookup.get_all_game_names():
        if game in known_to_fail:
            continue

        try:
            game_info = lookup.by_name(game, build_sm=False)
            assert game_info.game == game
            sm = game_info.get_sm()

            # run some depth charges to ensure we have valid statemachine
            interface.depth_charge(sm, 1)

            log.verbose("DONE GETTING Statemachine FOR GAME %s %s" % (game, sm))

        except lookup.LookupFailed as exc:
            log.warning("Failed to lookup %s: %s" % (game, exc))
            failed.append(game)

    if not failed:
        return

    log.error("Failed games %s" % (failed, ))
    assert False, failed
def init_spaces(self):
    ''' Build the board and control spaces from the game model's base infos
    and index their entries by channel in self.by_channel.  Control channels
    are numbered after the board channels. '''
    base_infos = create_base_infos(self.game_info.model)

    self.board_space = self.create_board_space(base_infos)
    self.raw_channels_per_state = 1 + max(b.channel_id for b in self.board_space)

    self.control_space = self.create_control_space(base_infos)
    self.num_of_controls_channels = len(self.game_desc.control_channels)

    # warn about any unhandled states
    self.num_unhandled_states = sum(1 for b_info in base_infos if not b_info.used)
    if self.num_unhandled_states:
        log.warning("Number of unhandled states %d" % self.num_unhandled_states)

    # sort by channel
    grouped = self.by_channel = {}
    for entry in self.board_space:
        grouped.setdefault(entry.channel_id, []).append(entry)

    for entry in self.control_space:
        channel = entry.channel_id + self.raw_channels_per_state
        grouped.setdefault(channel, []).append(entry)

    if not self.verbose:
        return

    for channel_id, members in self.by_channel.items():
        print()
        print("channel_id: %s" % channel_id)
        for member in members:
            print("%s -> %s" % (base_infos[member.base_indx].terms, member))
def verify_db(self):
    ''' Open the bcolz db under data_path and compare its size with the summary.

    Returns False if the db directory does not exist or anything goes wrong
    while opening/checking it.  A db larger than the summary is resized down
    to summary.total_samples; a smaller one is treated as a failure. '''
    db_path = os.path.join(self.data_path, "__db__")
    if not os.path.exists(db_path):
        return False

    try:
        self.db = bcolz.open(db_path, mode='a')

        # XXX check columns are correct types

        expected = self.summary.total_samples
        if expected != self.db.size:
            msg = "db and summary file different sizes summary: %s != %s" % (self.db.size,
                                                                             expected)
            log.warning(msg)

            if self.db.size <= expected:
                raise Check(msg)

            log.warning("resizing")
            self.db.resize(expected)

    except Exception as exc:
        log.error("error accessing db directory: %s" % exc)
        return False

    return True
def configure_self_play(self):
    ''' (Re)configure self play for the latest generation.

    Loads the network if none is loaded.  The first call creates the
    supervisor and starts self play; later calls either exit the process
    (when conf.exit_on_update_config is set) or refresh the existing
    supervisor. '''
    assert self.self_play_conf is not None

    if self.nn is None:
        man = get_manager()
        self.nn = man.load_network(self.game_info.game,
                                   self.latest_generation_name)

    if self.supervisor is None:
        self.supervisor = cppinterface.Supervisor(self.sm, self.nn,
                                                  batch_size=self.conf.self_play_batch_size,
                                                  sleep_between_poll=self.conf.sleep_between_poll)
        self.supervisor.start_self_play(self.self_play_conf, self.conf.num_workers)
        return

    # force exit of the worker if there was an update to the config
    if self.conf.exit_on_update_config:
        os._exit(0)

    log.info("Latest generation: %s" % self.latest_generation_name)

    # the generation number is the trailing "_<n>" of the generation name
    gen = int(self.latest_generation_name.split("_")[-1])
    if gen % self.conf.replace_network_every_n_gens == 0:
        log.warning("Updating network to: %s" % gen)
        self.supervisor.update_nn(self.nn)

    self.supervisor.clear_unique_states()
def load_module(kif_filename):
    ''' attempts to load a python module with the same filename.  If it does not exist, will run
        java and use ggp-base to create the module.

        Two java invocations are tried in turn (the first with explicit GC/heap
        options).  If the module still cannot be imported after both, the
        ImportError from the final attempt propagates to the caller. '''

    basename, props_file = kif_filename_to_propfile(kif_filename)
    module_name = "ggplib.props." + basename

    convert_cmds = [
        "java -XX:+UseSerialGC -Xmx8G propnet_convert.Convert %s %s" % (kif_filename, props_file),
        "java propnet_convert.Convert %s %s" % (kif_filename, props_file),
    ]

    # not available on python 2; needed on python 3 so a module file created
    # after interpreter start is visible to the import system
    invalidate_caches = getattr(importlib, "invalidate_caches", None)

    for cmd in convert_cmds:
        try:
            # rather unsafe cache, if kif file changes underneath our feet - tough luck.
            return importlib.import_module(module_name)

        except ImportError:
            # run java ggp-base to create a propnet.  The resultant propnet will be in props_dir,
            # which can be imported.
            # NOTE(review): cmd is a shell string built from kif_filename - only
            # safe for trusted filenames.
            log.debug("Running: %s" % cmd)
            return_code, out, err = run(cmd, shell=True, timeout=60)
            if return_code != 0:
                log.warning("Error code: %s" % err)
            else:
                for line in out.splitlines():
                    log.info("... %s" % line)

            if invalidate_caches is not None:
                invalidate_caches()

    # last chance: if this fails something is broken in the install, so let
    # the ImportError propagate (no placeholder command is executed)
    return importlib.import_module(module_name)
def save(self):
    ''' Save the trained network as self.next_generation and run callbacks.

    If the controller kept a best "retrain" set of weights, a second network
    named "<next_generation>_prev" is saved with those weights and the
    callbacks run again. '''
    # XXX set generation attributes

    man = get_manager()
    man.save_network(self.nn, generation_name=self.next_generation)
    self.do_callbacks()

    ###############################################################################
    # save a previous model for next time
    best_weights = self.controller.retrain_best
    if best_weights is None:
        log.warning("No retraining network")
        return

    log.info("Saving retraining network with val_policy_acc: %.4f" % (
        self.controller.retrain_best_val_policy_acc))

    # there is an undocumented keras clone function, but this is sure to work (albeit slow and evil)
    from ggpzero.util.keras import keras_models

    for_next_generation = "%s_prev" % self.next_generation

    # rebuild the model from its json description, then load the kept weights
    model_json = self.nn.keras_model.to_json()
    prev_model = keras_models.model_from_json(model_json)
    prev_model.set_weights(self.controller.retrain_best)

    prev_generation_descr = attrutil.clone(self.nn.generation_descr)
    prev_generation_descr.name = for_next_generation

    prev_nn = network.NeuralNetwork(self.nn.gdl_bases_transformer,
                                    prev_model,
                                    prev_generation_descr)
    man.save_network(prev_nn, for_next_generation)
    self.do_callbacks()
def run(args, verbose=False, cwd=None, shell=False, kill_tree=True, timeout=-1, env=None):
    ''' Run a command with a timeout after which it will be forcibly killed.

    timeout is in seconds; -1 disables it.  Returns a tuple
    (returncode, stdout, stderr).  When the timeout fires the process (and,
    with kill_tree, its children) is sent SIGKILL and (-9, '', '') is
    returned. '''
    p = Popen(args, shell=shell, cwd=cwd, stdout=PIPE, stderr=PIPE, env=env)

    use_alarm = timeout != -1
    if use_alarm:
        signal(SIGALRM, alarm_handler)
        alarm(timeout)

    try:
        stdout, stderr = p.communicate()

    except Alarm as e:
        log.warning("Alarm triggered: %s" % e)
        pids = [p.pid]
        if kill_tree:
            pids.extend(get_process_children(p.pid))

        for pid in pids:
            # process might have died before getting to this line
            # so wrap to avoid OSError: no such process
            try:
                log.warning("killing %s" % pid)
                kill(pid, SIGKILL)
            except OSError:
                pass

        return -9, '', ''

    finally:
        # always cancel a pending alarm, also when communicate() failed
        if use_alarm:
            alarm(0)

    if verbose:
        print("%s %s" % (stdout, stderr))

    # callers unpack three values on success as well as on timeout
    return p.returncode, stdout, stderr
def cleanup(self, keep_sm=False):
    ''' Clean up the player and release the c++ side objects held by the match.

    A failure in player.cleanup() is logged (with traceback) and does not
    stop the rest of the cleanup.  With keep_sm=True the statemachine is
    left allocated. '''
    try:
        self.player.cleanup()
        if self.verbose:
            log.verbose("done cleanup player: %s" % self.player)

    except Exception as exc:
        log.error("FAILED TO CLEANUP PLAYER: %s" % exc)
        log.error(traceback.format_exc())

    # cleanup c++ stuff
    if self.verbose:
        log.warning("cleaning up c++ stuff")

    # all the basestates
    for bs in self.states:
        # cleanup bs
        interface.dealloc_basestate(bs)

    self.states = []

    if self.joint_move:
        interface.dealloc_jointmove(self.joint_move)
        self.joint_move = None

    if self.sm and not keep_sm:
        interface.dealloc_statemachine(self.sm)
        self.sm = None

    if self.verbose:
        log.info("match - done cleaning up")
def on_meta_gaming(self, finish_time):
    ''' Prepare the player for a new match.

    The statemachine, network and poller are (re)created when there is no
    statemachine yet or when conf.generation contains "*".  The index of the
    first action containing "noop" is recorded for each of the two roles.
    finish_time is not used here. '''
    if self.conf.verbose:
        log.info("PUCTPlayer, match id: %s" % self.match.match_id)

    generation = self.conf.generation
    if self.sm is None or "*" in generation:
        if "*" in generation:
            log.warning("Using recent generation %s" % self.conf.generation)

        game_info = self.match.game_info
        self.sm = game_info.get_sm()

        self.nn = get_manager().load_network(game_info.game, self.conf.generation)
        self.poller = PlayPoller(self.sm, self.nn, attr.asdict(self.conf))

        def get_noop_idx(actions):
            for pos, action in enumerate(actions):
                if "noop" in action:
                    return pos
            assert False, "did not find noop"

        # unpacking requires exactly two roles
        noop_indices = [get_noop_idx(actions) for actions in game_info.model.actions]
        self.role0_noop_legal, self.role1_noop_legal = noop_indices

    self.poller.player_reset(self.match.game_depth)
def init(data_format='channels_first'):
    ''' Check the keras backend is tensorflow, switch keras to the requested
    image data format if it is using a different one, then constrain
    tensorflow resources. '''
    assert K.backend() == "tensorflow"

    was = K.image_data_format()
    if was != data_format:
        K.set_image_data_format(data_format)
        log.warning("Changing image_data_format: %s -> %s" % (was, K.image_data_format()))

    constrain_resources_tf()
def handle(self, request):
    ''' Dispatch an incoming request to the matching handle_xxx() method.

    The request body is symbolized and its first symbol (case insensitive)
    selects the handler: info, start, play, stop or abort.  Empty content,
    an empty symbol list and an unrecognised head are all answered with
    handle_info().  Any exception aborts the current match (if there is one)
    and "aborted" is returned. '''
    content = request.content.getvalue()

    # Tiltyard seems to ping with empty content...
    if content == "":
        return self.handle_info()

    try:
        symbols = list(self.symbol_factory.symbolize(content))

        # get head
        if len(symbols) == 0:
            log.warning('Empty symbols')
            return self.handle_info()

        head = symbols[0].lower()

        if head == "info":
            res = self.handle_info()

        elif head == "start":
            log.debug("HEADERS : %s" % pprint.pformat(request.getAllHeaders()))
            log.debug(str(symbols))
            res = self.handle_start(symbols)

        elif head == "play":
            log.debug(str(symbols))
            res = self.handle_play(symbols)

        elif head == "stop":
            log.debug(str(symbols))
            res = self.handle_stop(symbols)

        elif head == "abort":
            log.debug(str(symbols))
            res = self.handle_abort(symbols)

        else:
            log.error("UNHANDLED REQUEST %s" % symbols)
            # previously 'res' was left unbound here and the return below
            # raised UnboundLocalError; answer like the other requests we
            # cannot make sense of.
            # NOTE(review): confirm this is the reply wanted for unknown requests.
            res = self.handle_info()

    except Exception as exc:
        log.error("ERROR - aborting: %s" % exc)
        log.error(traceback.format_exc())

        if self.current_match:
            self.abort()

        res = "aborted"

    return res
def test_building_desc_variations():
    ''' Build and pretty-print every state machine description variant for
    each available propnet. '''
    builder_names = ("build_standard_sm",
                     "build_goals_only_sm",
                     "build_combined_state_machine",
                     "build_goalless_sm")

    for game, propnet in get_propnets():
        log.warning("test_building_desc() for: %s" % game)

        for name in builder_names:
            # looked up lazily, one variant at a time
            build = getattr(builder, name)
            pprint.pprint(build(propnet))
def handle_abort(self, symbols):
    ''' Handle an 'abort' request of the form (abort <match_id>).

    Returns "busy" if there is no current match or the id does not match
    the current match; otherwise aborts it and returns "aborted". '''
    assert len(symbols) == 2
    match_id = symbols[1]

    current = self.current_match
    if current is None:
        log.warning("rx'd 'abort' for non-current match %s" % match_id)
        return "busy"

    if current.match_id != match_id:
        log.error("rx'd 'abort' different from current match (%s != %s)" % (match_id, current.match_id))
        return "busy"

    # cancel any timeout callbacks
    self.abort()
    return "aborted"
def get_puct_config(gen, **kwds):
    ''' Create a PUCTPlayerConfig for generation gen with default settings.

    Each keyword overrides the attribute of the same name on the evaluator
    config and/or the player config; keywords matching neither are logged
    as unused. '''
    eval_config = confs.PUCTEvaluatorConfig(verbose=True,
                                            puct_constant=0.85,
                                            puct_constant_root=3.0,

                                            dirichlet_noise_pct=-1,

                                            fpu_prior_discount=0.25,
                                            fpu_prior_discount_root=0.15,

                                            choose="choose_temperature",
                                            temperature=2.0,
                                            depth_temperature_max=10.0,
                                            depth_temperature_start=0,
                                            depth_temperature_increment=0.75,
                                            depth_temperature_stop=1,
                                            random_scale=1.0,

                                            max_dump_depth=2,

                                            top_visits_best_guess_converge_ratio=0.8,
                                            think_time=2.0,
                                            converged_visits=2000,

                                            batch_size=32)

    config = confs.PUCTPlayerConfig(name="puct",
                                    verbose=True,
                                    generation=gen,
                                    playouts_per_iteration=-1,
                                    playouts_per_iteration_noop=0,
                                    evaluator_config=eval_config)

    for key, value in kwds.items():
        # a setting may exist on either (or both) of the two config objects
        targets = [c for c in (eval_config, config) if at.has(c, key)]
        for target in targets:
            setattr(target, key, value)

        if not targets:
            log.warning("Unused setting %s:%s" % (key, value))

    return config
def compile(self, compile_strategy, learning_rate=None, value_weight=1.0):
    ''' Compile the keras model with the chosen optimizer.

    compile_strategy is one of "SGD", "adam" or "amsgrad"; anything else
    raises Exception.  There is one categorical crossentropy loss per policy
    head (weight 1.0) plus a mean squared error loss for the value head
    (weight value_weight). '''
    value_objective = "mean_squared_error"
    policy_objective = 'categorical_crossentropy'

    # a falsy learning_rate (None or 0) means: 1e-2 for SGD, and the
    # optimizer's own default for the Adam variants
    if compile_strategy == "SGD":
        optimizer = SGD(lr=learning_rate if learning_rate else 1e-2, momentum=0.9)

    elif compile_strategy == "adam":
        optimizer = Adam(lr=learning_rate) if learning_rate else Adam()

    elif compile_strategy == "amsgrad":
        if learning_rate:
            optimizer = Adam(lr=learning_rate, amsgrad=True)
        else:
            optimizer = Adam(amsgrad=True)

    else:
        msg = "UNKNOWN compile strategy %s" % compile_strategy
        log.error(msg)
        raise Exception(msg)

    num_policies = len(self.gdl_bases_transformer.policy_dist_count)

    loss = [policy_objective] * num_policies + [value_objective]
    loss_weights = [1.0] * num_policies + [value_weight]

    if learning_rate is None:
        log.warning("Compiling with %s (value_weight=%.3f)" % (optimizer, value_weight))
    else:
        msg = "Compiling with %s (learning_rate=%.4f, value_weight=%.3f)"
        log.warning(msg % (optimizer, learning_rate, value_weight))

    def top_3_acc(y_true, y_pred):
        return keras_metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)

    self.keras_model.compile(loss=loss, optimizer=optimizer,
                             loss_weights=loss_weights,
                             metrics=["acc", top_3_acc])
def on_sample_response(self, worker, msg):
    ''' Handle samples sent back by a worker: add them to the accumulated
    samples, mark the worker as free again and reschedule the players. '''
    info = self.workers[worker]

    samples = msg.samples
    if len(samples) > 0:
        dupe_count = self.add_new_samples(samples)
        if dupe_count:
            log.warning("dropping %s inflight duplicate state(s)" % dupe_count)

        if msg.duplicates_seen:
            log.info("worker saw %s duplicates" % msg.duplicates_seen)

        log.info("len accumulated_samples: %s" % len(self.accumulated_samples))

    # the worker is available for more work
    self.free_players.append(info)
    reactor.callLater(0, self.schedule_players)
def add_new_samples(self, samples):
    ''' Add samples whose state has not been seen before.

    A sample whose state is already in unique_states_set is dropped.  New
    states are recorded in unique_states_set / unique_states and the sample
    is appended to accumulated_samples.

    Returns the number of duplicate samples that were dropped (callers test
    this value, so it must not be None). '''
    dupe_count = 0
    for sample in samples:
        state = sample.state

        # need to check it isn't a duplicate and drop it
        if state in self.unique_states_set:
            dupe_count += 1
            continue

        self.unique_states_set.add(state)
        self.unique_states.append(state)
        self.accumulated_samples.append(sample)

    if dupe_count:
        log.warning("seen %s duplicate state(s)" % dupe_count)

    return dupe_count
def load(self, verbose=True):
    ''' Populate idx_mapping and game_mapping from the *.kif rulesheets.

    For each rulesheet (files starting with "tmp" are skipped) the game's
    idx is read from its cached sig.json, or computed via get_symbol_map()
    and then cached.  Raises Exception if a signature cannot be created or
    if two games share the same idx. '''
    if verbose:
        log.info("Building the database")

    for fn in sorted(self.rulesheets_store.listdir("*.kif")):
        # skip tmp files
        if fn.startswith("tmp"):
            continue

        game = fn.replace(".kif", "")

        # get the gdl
        info = GameInfo(game, self.rulesheets_store.load_contents(fn))

        # first does the game directory exist?
        the_game_store = self.games_store.get_directory(game, create=True)

        if the_game_store.file_exists("sig.json"):
            info.idx = the_game_store.load_json("sig.json")['idx']
        else:
            if verbose:
                log.verbose("Creating signature for %s" % game)

            info.get_symbol_map()
            if info.symbol_map is None:
                log.warning("FAILED to add: %s" % game)
                raise Exception("FAILED TO add %s" % game)

            # save as json
            assert info.idx is not None
            the_game_store.save_json("sig.json", dict(idx=info.idx))

        assert info.idx is not None

        other_info = self.idx_mapping.get(info.idx) if info.idx in self.idx_mapping else None
        if info.idx in self.idx_mapping:
            log.warning("DUPE GAMES: %s %s!=%s" % (info.idx, game, other_info.game))
            raise Exception("Dupes not allowed in database")

        self.idx_mapping[info.idx] = info
        self.game_mapping[info.game] = info
def create_db(self):
    ''' Create a fresh, empty bcolz db under data_path.

    An existing db is NOT deleted: a warning is logged and the process exits
    with status 1. '''
    db_path = os.path.join(self.data_path, "__db__")
    if os.path.exists(db_path):
        log.warning("Please delete old db")
        sys.exit(1)

    # these are columns for bcolz table
    column_names = ["channels", "policy0", "policy1", "value"]
    cols = fake_columns(self.transformer)

    # and create a table
    self.db = bcolz.ctable(cols, names=column_names, rootdir=db_path)

    # remove the single row
    self.db.resize(0)
    self.db.flush()

    log.info("Created new db")
def schedule_players(self):
    ''' Hand out work to the free players.

    Players that are sent a message are removed from free_players; a player
    stays free only when no more samples are needed.  While any player is
    still free this method reschedules itself in 10 seconds. '''
    # if we haven't started training yet, lets speed things up...
    have_enough = len(self.accumulated_samples) > self.conf.num_samples_to_train
    if have_enough and not self.training_in_progress:
        self.checkpoint()

    if not self.free_players:
        return

    still_free = []
    for player in self.free_players:
        if not player.valid:
            continue

        # will be added back in at the end of training
        if self.training_in_progress and player.conf.do_training:
            continue

        if not player.self_play_configured:
            player.worker.send_msg(self.configure_selfplay_msg)

        elif self.need_more_samples():
            updates = player.get_and_update(self.unique_states)
            request = msgs.RequestSamples(updates)
            if updates:
                log.debug("sending request with %s updates" % len(updates))

            player.worker.send_msg(request)

        else:
            log.warning("capacity full! %d" % len(self.accumulated_samples))
            still_free.append(player)

        if time.time() > player.ping_time_sent:
            self.do_ping(player)

    self.free_players = still_free

    if self.free_players:
        reactor.callLater(10.0, self.schedule_players)

    if self.the_nn_trainer is None:
        log.warning("There is no nn trainer - please start")
def files_to_sample_data(self, conf):
    ''' Generator yielding (filename, SamplesData) per step, newest first.

    Steps run from conf.next_step - 1 down to conf.starting_step; a negative
    starting_step is taken relative to the newest step (clamped at 0).
    Loaded files are kept in self.sample_data_cache and reused. '''
    man = get_manager()

    assert isinstance(conf, confs.TrainNNConfig)

    step = conf.next_step - 1

    starting_step = conf.starting_step
    if starting_step < 0:
        starting_step = max(step + starting_step, 0)

    while step >= starting_step:
        store_path = man.samples_path(conf.game, conf.generation_prefix)
        fn = os.path.join(store_path, "gendata_%s_%s.json.gz" % (conf.game, step))

        if fn not in self.sample_data_cache:
            # close the file handle deterministically
            with gzip.open(fn) as f:
                raw_data = attrutil.json_to_attr(f.read())

            data = SamplesData(raw_data.game,
                               raw_data.with_generation,
                               raw_data.num_samples)

            # a sample counts as a draw when its first final score is ~0.5
            total_draws = sum(1 for s in raw_data.samples
                              if abs(s.final_score[0] - 0.5) < 0.01)

            # guard against a file with no samples
            num_raw = len(raw_data.samples)
            draws_ratio = total_draws / float(num_raw) if num_raw else 0.0
            log.info("Draws ratio %.2f" % draws_ratio)

            for s in raw_data.samples:
                data.add_sample(s)

            if len(data.samples) != data.num_samples:
                # pretty inconsequential, but we should at least notify
                msg = "num_samples (%d) versus actual samples (%s) differ... trimming"
                log.warning(msg % (data.num_samples, len(data.samples)))

                data.num_samples = min(len(data.samples), data.num_samples)
                data.samples = data.samples[:data.num_samples]

            self.sample_data_cache[fn] = data

        yield fn, self.sample_data_cache[fn]
        step -= 1
def check_files_exist(self):
    ''' Ensure the model, weights, generation and samples directories exist
    (creating any that are missing), then check the network files.

    critical_error() is called if a path exists but is not a directory, or
    if a directory could not be created. '''
    # first check that the directories exist
    man = get_manager()
    for p in (man.model_path(self.conf.game),
              man.weights_path(self.conf.game),
              man.generation_path(self.conf.game),
              man.samples_path(self.conf.game, self.conf.generation_prefix)):

        if os.path.exists(p):
            if not os.path.isdir(p):
                # the path was previously missing from this message
                critical_error("Path exists and not directory: %s" % p)
        else:
            log.warning("Attempting to create path: %s" % p)
            os.makedirs(p)
            # isdir() is False for a missing path as well
            if not os.path.isdir(p):
                critical_error("Failed to create directory: %s" % p)

    self.check_nn_files_exist()
def get_network(self, nn_model_config, generation_descr):
    ''' Set self.nn to the network to train and self.retraining accordingly.

    With train_config.use_previous, the previous step's generation is
    loaded if it can be found (retraining=True); otherwise a new network is
    created.  Raises TrainException if the next generation already exists
    and overwrite_existing is not set. '''
    # abbreviate, easier on the eyes
    conf = self.train_config

    attrutil.pprint(nn_model_config)

    man = get_manager()
    if man.can_load(conf.game, self.next_generation):
        msg = "Generation already exists %s / %s" % (conf.game, self.next_generation)
        log.error(msg)
        if not conf.overwrite_existing:
            raise TrainException(msg)

    nn = None
    retraining = False
    if conf.use_previous:
        # default to next_generation_prefix, otherwise use conf.generation_descr
        prefixes = [self.next_generation_prefix]
        if conf.generation_prefix != self.next_generation_prefix:
            prefixes.append(conf.generation_prefix)

        for prefix in prefixes:
            prev_generation = "%s_%s" % (prefix, conf.next_step - 1)

            if not man.can_load(conf.game, prev_generation):
                log.warning("Previous generation %s not found..." % (prev_generation))
                continue

            log.info("Previous generation found: %s" % prev_generation)
            nn = man.load_network(conf.game, prev_generation)
            retraining = True
            break

    if nn is None:
        nn = man.create_new_network(conf.game, nn_model_config, generation_descr)

    nn.summary()

    self.nn = nn
    self.retraining = retraining
    log.info("Network %s, retraining: %s" % (self.nn, self.retraining))
def check_summary(self):
    ''' Check the summary against the files returned by list_files().

    Returns False (after logging) when the game, a step number or the total
    sample count disagrees.  An md5sum mismatch is only logged.  Step
    summaries that are not numbered 0, 1, 2, ... raise Exception. '''
    total_samples = 0
    try:
        if self.summary.game != self.transformer.game:
            raise Check("Game not same %s/%s" % (self.summary.game, self.transformer.game))

        paired = zip(self.summary.step_summaries, self.list_files())
        for expect, (step_sum, (step, file_path, md5sum)) in enumerate(paired):

            # special case exception, this should never happen!
            if step_sum.step != expect:
                raise Exception("Weird - step_sum.step != expect, please check %s/%s" % (step_sum.step,
                                                                                         expect))
            if step_sum.step != step:
                raise Check("step_sum(%d) != step(%d)" % (step_sum.step, step))

            if step_sum.md5sum != md5sum:
                msg = "Summary check: for file %s, md5sum(%s) != md5sum(%s)" % (file_path,
                                                                                step_sum.md5sum,
                                                                                md5sum)
                log.warning(msg)
                # raise Check(msg)

            total_samples += step_sum.num_samples

        if self.summary.total_samples != total_samples:
            raise Check("Total samples mismatch %s != %s" % (self.summary.total_samples,
                                                             total_samples))

    except Check as exc:
        log.error("Summary check failed: %s" % exc)
        return False

    return True
def add_new_samples(self, samples, dedupe=True):
    # Adds samples to accumulated_samples, randomly thinning out samples from
    # drawn matches and samples whose state was already seen.  With
    # dedupe=False the "dropped" counters are still updated (and logged) but
    # nothing is actually skipped.
    # NOTE: the order of the random.random() calls determines the outcome.
    dupe_counts = Counter()
    dropped_dupe_count = 0
    dropped_draw_count = 0
    num_matches = 0
    cur_match = None
    cur_match_is_draw = False
    for sample in samples:
        # a change of match_identifier marks the first sample of a new match
        if sample.match_identifier != cur_match:
            cur_match = sample.match_identifier
            num_matches += 1
            # a match counts as a draw when its first final score is ~0.5
            cur_match_is_draw = math.fabs(sample.final_score[0] - 0.5) < 0.01

        state = sample.state

        if cur_match_is_draw:
            # samples from drawn matches are dropped about half of the time
            if random.random() > 0.5:
                dropped_draw_count += 1
                if dedupe:
                    continue

        # need to check it isn't a duplicate and drop it
        if state in self.unique_states_set:
            # duplicates are counted per sample depth
            dupe_counts[sample.depth] += 1
            if dupe_counts[sample.depth] > 1:
                # the more duplicates seen at this depth, the likelier a drop
                if random.random() > 1.25 / dupe_counts[sample.depth]:
                    dropped_dupe_count += 1
                    if dedupe:
                        continue
        else:
            self.unique_states_set.add(state)
            self.unique_states.append(state)

        self.accumulated_samples.append(sample)

    log.info("Rx'd matches %s" % num_matches)
    if dropped_dupe_count or dropped_draw_count:
        log.warning("duplicate: %s, dropped dupe %s, dropped_draw %s" % (dupe_counts, dropped_dupe_count, dropped_draw_count))
def roll_generation(self):
    ''' Move on to the next generation once training is done: bump the step,
    reset the workers, discard the samples that were trained on and resume
    scheduling the players. '''
    # training is done
    self.conf.current_step += 1
    self.check_nn_files_exist()

    # reconfigure player workers
    for info in self.workers.values():
        info.reset()

    self.create_self_play_config()

    # rotate these
    consumed = self.conf.num_samples_to_train
    self.accumulated_samples = self.accumulated_samples[consumed:]
    self.unique_states = self.unique_states[consumed:]
    self.unique_states_set = set(self.unique_states)

    self.checkpoint()

    assert len(self.unique_states) == len(self.unique_states_set)

    training_conf = self.conf.base_training_config
    if not training_conf.use_previous:
        log.warning("use_previous was False, setting to True")
        training_conf.use_previous = True

    # store the server config
    self.save_our_config(rolled=True)

    self.pending_gen_samples = None
    self.training_in_progress = False

    # the trainer can go back to self play, unless it is already listed as free
    trainer = self.the_nn_trainer
    if trainer.conf.do_self_play and trainer not in self.free_players:
        self.free_players.append(trainer)

    log.warning("roll_generation() complete.  We have %s samples leftover" % len(self.accumulated_samples))

    self.schedule_players()
def handle_play(self, symbols):
    ''' Handle a 'play' request of the form (play <match_id> <move>).

    Returns "busy" if there is no current match or the id does not match
    the current match.  Otherwise the move (a list term, or 'nil' which is
    mapped to None) is passed to the match and its reply is returned. '''
    assert len(symbols) == 3
    match_id = symbols[1]

    current = self.current_match
    if current is None:
        log.warning("rx'd play for non-current match %s" % match_id)
        return "busy"

    if current.match_id != match_id:
        log.error("rx'd play different from current match (%s != %s)" % (match_id, current.match_id))
        return "busy"

    raw_move = symbols[2]
    if isinstance(raw_move, ListTerm):
        move = list(raw_move)
    else:
        assert raw_move.lower() == 'nil', "Move is %s" % raw_move
        move = None

    # update gameserver timeout
    self.update_gameserver_timeout(self.current_match.move_time)
    return self.current_match.do_play(move)