def load_module(kif_filename):
    ''' Import the propnet module generated for kif_filename, creating it on demand.

    The module is looked up as "ggplib.props.<basename>".  Each time the import fails, the
    next java propnet_convert command (ggp-base) is run to generate the module and the
    import is attempted again.

    Returns the imported module.  Raises ImportError if the module still cannot be imported
    once every conversion command has been tried.
    '''
    basename, props_file = kif_filename_to_propfile(kif_filename)
    module_name = "ggplib.props." + basename

    # Conversion commands, in the order they are tried.
    # NOTE(review): these are shell strings built from the file names - paths containing
    # shell metacharacters are not escaped.
    convert_cmds = [
        "java -XX:+UseSerialGC -Xmx8G propnet_convert.Convert %s %s" % (kif_filename, props_file),
        "java propnet_convert.Convert %s %s" % (kif_filename, props_file),
    ]

    for cmd in convert_cmds:
        try:
            # rather unsafe cache, if kif file changes underneath our feet - tough luck.
            return importlib.import_module(module_name)
        except ImportError:
            pass

        # run java ggp-base to create a propnet.  The resultant propnet will be in
        # props_dir, which can be imported.
        log.debug("Running: %s" % cmd)
        return_code, out, err = run(cmd, shell=True, timeout=60)
        if return_code != 0:
            log.warning("Error code: %s" % err)
        else:
            for line in out.splitlines():
                log.info("... %s" % line)

    # Last attempt after the final conversion.  Previously the diagnostic text itself was
    # handed to the shell as a command before re-raising; now it is only logged.
    try:
        return importlib.import_module(module_name)
    except ImportError:
        log.error("SOMETHING IS BROKEN in install ...")
        raise
def on_request_samples(self, server, msg):
    ''' Handle a request for samples: register the new unique states with the supervisor,
    run its poll loop, log the statistics and send the gathered samples back to server.
    '''
    self.on_request_samples_time = time.time()

    assert self.supervisor is not None
    self.samples = []
    self.supervisor.reset_stats()

    incoming_states = msg.new_states
    log.debug("Got request for sample with number unique states %s" % len(incoming_states))

    # update duplicates
    for encoded in incoming_states:
        self.supervisor.add_unique_state(decode_state(encoded))

    began = time.time()
    self.supervisor.poll_loop(do_stats=True, cb=self.cb_from_superviser)

    # samples / prediction calls / total predictions, then the time spent polling,
    # predicting and overall
    report = (len(self.samples),
              self.supervisor.num_predictions_calls,
              self.supervisor.total_predictions,
              self.supervisor.acc_time_polling,
              self.supervisor.acc_time_prediction,
              time.time() - began)
    log.info("#samp %d, pred()s %d/%d, py/pred/all %.1f/%.1f/%.1f" % report)

    response = msgs.RequestSampleResponse(self.samples, 0)
    server.send_msg(response)
def handle_start(self, symbols):
    ''' Handle a START request.

    symbols must hold six entries; entries 1-5 are read as match id, our role, the gdl,
    meta time and move time (the two times are converted with int()).

    Returns "busy" when a match is already in progress or the match raises match.BadGame
    while starting, otherwise "ready".
    '''
    assert len(symbols) == 6
    match_id, role, gdl = symbols[1], symbols[2], symbols[3]
    meta_time, move_time = int(symbols[4]), int(symbols[5])

    if self.current_match is not None:
        log.debug("GOT A START message for %s while already playing match" % match_id)
        return "busy"

    log.info("Starting new match %s" % match_id)

    # lookup game and create match
    gdl_symbol_mapping, game_info = lookup.by_gdl(gdl)
    self.current_match = match.Match(game_info,
                                     match_id,
                                     role,
                                     meta_time,
                                     move_time,
                                     self.player,
                                     cushion_time=CUSHION_TIME,
                                     gdl_symbol_mapping=gdl_symbol_mapping)

    try:
        # start gameserver timeout
        self.update_gameserver_timeout(self.current_match.meta_time)
        self.current_match.do_start()
    except match.BadGame:
        return "busy"

    return "ready"
def get_with_gdl(gdl, name_hint=""):
    ''' Build a propnet from gdl by writing it to a temporary kif file.

    gdl may be a string (it is symbolized first) or an iterable of terms, each written on
    its own line.  name_hint is used as part of the temporary file name, together with a
    uuid4.

    Returns whatever get_with_filename() returns for the temporary file.  The kif file, the
    props file and matching __pycache__ .pyc files are removed afterwards, whether or not
    the propnet could be created.
    '''
    # create a temporary file:
    name_hint += "__" + str(uuid.uuid4())
    name_hint = name_hint.replace('-', '_').replace('.', '_')

    # ensure we have gdl symbolized
    if isinstance(gdl, str):
        gdl = symbols.SymbolFactory().to_symbols(gdl)

    # this is very very very likely to be unique, but perhaps we should still check XXX
    name = "tmp_%s" % name_hint
    fn = os.path.join(rulesheet_dir, name + ".kif")
    log.debug('writing kif file %s' % fn)

    def remove_quietly(path):
        # best effort: a file that was never created must not mask the real error
        try:
            os.remove(path)
        except OSError as exc:
            log.debug("could not remove temp file %s: %s" % (path, exc))

    try:
        # write file (closed even if writing a term fails)
        with open(fn, "w") as kif_file:
            for term in gdl:
                kif_file.write("%s\n" % (term,))

        propnet = get_with_filename(fn)

    finally:
        # cleanup temp files afterwards
        basename, props_file = kif_filename_to_propfile(fn)
        remove_quietly(fn)
        remove_quietly(props_file)
        pyc_pattern = os.path.join(props_dir, "__pycache__", basename) + '*.pyc'
        for compiled in glob.glob(pyc_pattern):
            remove_quietly(str(compiled))

    return propnet
def on_request_samples(self, server, msg):
    ''' Handle a request for samples: register the new unique states with the supervisor,
    run its poll loop, log the statistics (including predictions per second) and send the
    gathered samples back to server.
    '''
    self.on_request_samples_time = time.time()

    assert self.supervisor is not None
    self.samples = []
    self.supervisor.reset_stats()

    log.debug("Got request for sample with number unique states %s" % len(msg.new_states))

    # update duplicates
    for s in msg.new_states:
        # note we decode the string and set it rawly.  using decode_state() was too slow.
        # (b64decode replaces the deprecated base64.decodestring)
        self.supervisor.add_unique_state(base64.b64decode(s))

    start_time = time.time()
    self.supervisor.poll_loop(do_stats=True, cb=self.cb_from_superviser)
    time_since_last = time.time() - start_time

    summary = "#samp %d, pred()s %d/%d, py/pred/all %.1f/%.1f/%.1f"
    log.info(summary % (len(self.samples),
                        self.supervisor.num_predictions_calls,
                        self.supervisor.total_predictions,
                        self.supervisor.acc_time_polling,
                        self.supervisor.acc_time_prediction,
                        time_since_last))

    # a coarse clock can report no elapsed time at all - do not divide by zero
    if time_since_last > 0:
        predicts_per_sec = self.supervisor.total_predictions / time_since_last
    else:
        predicts_per_sec = 0.0
    log.info("Average pred p/s %.1f" % predicts_per_sec)

    m = msgs.RequestSampleResponse(self.samples, 0)
    server.send_msg(m)
def perform_mcs(self, finish_by):
    ''' Sample our legal moves with depth charges until finish_by (a time.time() value).

    self.root is rebuilt as a mapping of each of our legal choices to a MoveStat, which
    accumulates the scores of every depth charge started with that choice.  The other
    roles play a random legal move.  Sampling also stops once max_iterations (when > 0)
    is exceeded, or after 100 visits when we only have a single choice.
    '''
    our_index = self.match.our_role_index

    def rewind():
        # return to current state
        self.depth_charge_state.assign(self.match.get_current_state())
        self.sm.update_bases(self.depth_charge_state)

    rewind()

    self.root = {}
    ls = self.sm.get_legal_state(our_index)
    our_choices = [ls.get_legal(ii) for ii in range(ls.get_count())]

    # now create some stats with depth charges
    for choice in our_choices:
        self.root[choice] = MoveStat(choice,
                                     self.sm.legal_to_move(our_index, choice),
                                     self.role_count)

    forced_move = len(our_choices) == 1
    root_visits = 1
    while True:
        if time.time() > finish_by:
            break

        over_iteration_cap = self.max_iterations > 0 and root_visits > self.max_iterations
        enough_for_forced_move = forced_move and root_visits > 100
        if over_iteration_cap or enough_for_forced_move:
            break

        rewind()
        assert not self.sm.is_terminal()

        # select and set our move
        choice = self.select_move(our_choices, root_visits, self.root)
        self.joint_move.set(our_index, choice)

        # and a random move from other players
        for idx, _ in enumerate(self.sm.get_roles()):
            if idx == our_index:
                continue

            ls = self.sm.get_legal_state(idx)
            choices = [ls.get_legal(ii) for ii in range(ls.get_count())]

            # only need to set this once :)
            pick = random.randrange(0, ls.get_count())
            self.joint_move.set(idx, choices[pick])

        # create a new state
        self.sm.next_state(self.joint_move, self.depth_charge_state)

        # do a depth charge, and update scores
        self.root[choice].add(self.do_depth_charge())

        # and update the number of visits
        root_visits += 1

    log.debug("Total visits: %s" % root_visits)
def render_GET(self, request):
    ''' Return self.obj serialised as JSON, after augmenting the response headers.
    Any exception is logged at debug level and an empty body is returned instead.
    '''
    try:
        augment_header(request.responseHeaders)
        as_dict = attr.asdict(self.obj)
        body = json.dumps(as_dict)
    except Exception as exc:
        log.debug("ERROR %s" % exc)
        body = ""
    return body
def render_GET(self, request):
    ''' Read the match stored at self.match_path and return it serialised as JSON, after
    augmenting the response headers.  Any exception is logged at debug level and an empty
    body is returned instead.
    '''
    try:
        # close the file straight away rather than leaving it to the garbage collector
        with open(self.match_path) as match_file:
            obj = attrutil.json_to_attr(match_file.read())

        augment_header(request.responseHeaders)
        return json.dumps(attr.asdict(obj))

    except Exception as exc:
        log.debug("ERROR %s" % exc)
        return ""
def getChild(self, game, request):
    ''' Resolve a child resource by name: 'summary' is this resource itself, a known game
    name gets a SummaryForGame, anything else is logged and served the viewer file.
    '''
    if game == 'summary':
        return self

    games = "breakthrough cittaceot checkers connectFour escortLatch hex reversi speedChess"
    known_games = games.split()
    if game in known_games:
        return SummaryForGame(game)

    origin = "Got GET request from: %s/%s" % (request.getClientIP(), request)
    log.debug(origin)
    headers = "HEADERS : %s" % pprint.pformat(request.getAllHeaders())
    log.debug(headers)
    return File(self.path_to_viewer)
def handle_info(self):
    ''' Answer an INFO request with our name and whether we are available or busy.

    Also counts the requests, logging and resetting the count when more than 60 seconds
    have passed since the count was last reset.
    '''
    now = time.time()

    # do info_counts or we get reports of "0 infos in the last minute"
    self.info_counts += 1
    if now - self.last_info_time > 60:
        log.debug("Got %s infos in last minute" % self.info_counts)
        self.info_counts = 0
        self.last_info_time = now

    if self.current_match is None:
        reply = "((name %s) (status available))"
    else:
        reply = "((name %s) (status busy))"
    return reply % self.player.get_name()
def handle(self, request):
    ''' Dispatch a request body to the matching handle_xxx() method.

    Empty content and content without symbols are answered like an INFO request.  The
    head symbol (case insensitive) selects info / start / play / stop / abort.

    Returns the handler's response.  If anything raises, the error is logged, a running
    match is aborted and "aborted" is returned.  An unknown head symbol is logged and
    answered with an empty string.
    '''
    content = request.content.getvalue()

    # Tiltyard seems to ping with empty content...
    if content == "":
        return self.handle_info()

    # default response, used when the request is not recognised (previously the
    # response was left unbound in that case and returning it raised)
    res = ""
    try:
        symbols = list(self.symbol_factory.symbolize(content))

        # get head
        if len(symbols) == 0:
            log.warning('Empty symbols')
            return self.handle_info()

        head = symbols[0].lower()
        if head == "info":
            res = self.handle_info()

        elif head == "start":
            log.debug("HEADERS : %s" % pprint.pformat(request.getAllHeaders()))
            log.debug(str(symbols))
            res = self.handle_start(symbols)

        elif head == "play":
            log.debug(str(symbols))
            res = self.handle_play(symbols)

        elif head == "stop":
            log.debug(str(symbols))
            res = self.handle_stop(symbols)

        elif head == "abort":
            log.debug(str(symbols))
            res = self.handle_abort(symbols)

        else:
            log.error("UNHANDLED REQUEST %s" % symbols)

    except Exception as exc:
        log.error("ERROR - aborting: %s" % exc)
        log.error(traceback.format_exc())

        if self.current_match:
            self.abort()

        res = "aborted"

    return res
def do_play(self, move):
    ''' Apply the previous joint move (if any) and ask the player for our next move.

    move is passed to apply_move() unless it is None.  Returns "done" when the resulting
    state is terminal, otherwise the chosen move mapped with legal_to_gamemaster_move().
    The player is given until enter time + move_time (less cushion_time when positive).

    Raises CriticalError if the mapped move is not among the mapped legal moves.
    '''
    started = time.time()
    if self.verbose:
        log.debug("do_play: %s" % (move, ))

    if move is not None:
        self.apply_move(move)

    state = self.get_current_state()
    if self.verbose:
        current_str = self.game_info.model.basestate_to_str(state)
        log.info("Current state : '%s'" % current_str)

    self.sm.update_bases(state)
    if self.sm.is_terminal():
        return "done"

    deadline = started + self.move_time
    if self.cushion_time > 0:
        deadline -= self.cushion_time

    legal_choice = self.player.on_next_move(deadline)

    # we have no idea what on_next_move() left the state machine.  So reverting it back to
    # correct state here.
    self.sm.update_bases(self.get_current_state())

    # get possible legal moves, to check the chosen move is a valid one
    ls = self.sm.get_legal_state(self.our_role_index)

    # store last move (in our own mapping, *not* gamemaster)
    self.last_played_move = self.sm.legal_to_move(self.our_role_index, legal_choice)

    # check the move remaps and is a legal choice
    move = self.legal_to_gamemaster_move(legal_choice)
    legal_moves = []
    for ii in range(ls.get_count()):
        legal_moves.append(self.legal_to_gamemaster_move(ls.get_legal(ii)))

    if move in legal_moves:
        if self.verbose:
            log.info("(%s) do_play '%s' sending move: %s" % (self.player.name, self.role, move))
        return move

    msg = "Choice was %s not in legal choices %s" % (move, legal_moves)
    log.critical(msg)
    raise CriticalError(msg)
def schedule_players(self):
    ''' Hand out work to the free self play workers.

    Triggers a checkpoint when more samples than conf.num_samples_to_train have been
    accumulated and no training is in progress.  Each valid free worker is then either
    sent the self play configuration, sent a request for samples, or (when no more samples
    are needed) kept in the free list.  While free workers remain, this method reschedules
    itself via the reactor in 10 seconds.
    '''
    if len(self.accumulated_samples) > self.conf.num_samples_to_train:
        # if we haven't started training yet, lets speed things up...
        if not self.training_in_progress:
            self.checkpoint()

    if not self.free_players:
        return

    # workers that stay free after this pass
    new_free_players = []
    for worker_info in self.free_players:
        # invalid workers are dropped from the free list
        if not worker_info.valid:
            continue

        if self.training_in_progress and worker_info.conf.do_training:
            # will be added back in at the end of training
            continue

        if not worker_info.self_play_configured:
            worker_info.worker.send_msg(self.configure_selfplay_msg)

        else:
            if self.need_more_samples():
                updates = worker_info.get_and_update(self.unique_states)
                m = msgs.RequestSamples(updates)
                if updates:
                    log.debug("sending request with %s updates" % len(updates))

                worker_info.worker.send_msg(m)
            else:
                # nothing to do for this worker right now - keep it in the free list
                log.warning("capacity full! %d" % len(self.accumulated_samples))
                new_free_players.append(worker_info)

        if time.time() > worker_info.ping_time_sent:
            self.do_ping(worker_info)

    self.free_players = new_free_players

    # poll again later while there are still idle workers
    if self.free_players:
        reactor.callLater(10.0, self.schedule_players)

    if self.the_nn_trainer is None:
        log.warning("There is no nn trainer - please start")
def choose(self):
    ''' Log every move statistic in self.root, best score for our role first, and return
    the choice of the best scoring one.  Scores must be greater than -1 to be selected.
    '''
    assert self.root is not None
    our_index = self.match.our_role_index

    # ok - now we dump everything for debug, and return the best score
    ranked = list(self.root.values())
    ranked.sort(key=lambda entry: entry.get(our_index), reverse=True)

    best_selection = None
    best_score = -1
    for stat in ranked:
        per_role = ("%.2f" % stat.get(ii) for ii in range(self.role_count))
        score_str = " / ".join(per_role)
        log.info("Move %s, visits %d, scored %s" % (stat.move, stat.visits, score_str))

        score = stat.get(our_index)
        if score > best_score:
            best_selection, best_score = stat, score

    assert best_selection is not None
    log.debug("choice move = %s" % best_selection.move)
    return best_selection.choice
def on_epoch_end(self, _, logs=None):
    ''' Record accuracies for the epoch, keep the best weights and decide on early stopping.

    The weights are stored as self.best on the first call, or whenever the validation
    policy accuracy is within 0.001 of the best seen and the policy is not overfitting
    (training accuracy more than 0.02 above validation accuracy).  self.stop_training is
    set when overfitting (if stop_after_n_epochs_overfit is set) or when the best weights
    were last set more than stop_after_n_epochs_improving epochs ago.
    '''
    epoch = self.at_epoch
    self.set_value_overfitting(logs)

    # deals with more than one head
    policy_acc, val_policy_acc = self.policy_acc(logs)

    policy_report = "combined policy accuracy %.4f/%.4f" % (policy_acc, val_policy_acc)
    self.last_policy_accuracy = policy_report
    log.debug(self.last_policy_accuracy)

    value_accs = (logs["value_acc"], logs["val_value_acc"])
    self.last_value_accuracy = "score accuracy %.4f / %.4f" % value_accs

    # are we overfitting?
    overfitting = policy_acc - 0.02 > val_policy_acc

    # store best weights as best val_policy_acc
    # allow small decrease
    allow_acc = self.best_val_policy_acc - 0.001
    never_set = self.epoch_last_set_at is None
    good_enough = val_policy_acc > allow_acc and not overfitting
    if never_set or good_enough:
        log.debug("Setting best to last val_policy_acc %.4f" % val_policy_acc)
        self.best = self.model.get_weights()
        self.best_val_policy_acc = max(val_policy_acc, self.best_val_policy_acc)
        self.epoch_last_set_at = epoch

    # stop training:
    if overfitting and self.stop_after_n_epochs_overfit:
        log.info("Early stopping... since policy accuracy overfitting")
        self.stop_training = True

    # if things havent got better - STOP.  We can go on forever without improving.
    if self.epoch_last_set_at is not None:
        patience_ends_at = self.epoch_last_set_at + self.stop_after_n_epochs_improving
        if epoch > patience_ends_at:
            log.info("Early stopping... since not improving (disabled)")
            self.stop_training = True
def gather_data(self):
    ''' Load the sample data for training, one LevelData per sample file.

    The samples buffer (and buckets) are created on first use.  Every file's samples are
    transformed if needed, added to a LevelData and split for validation.  Returns the list
    of LevelData, in the order the files were provided.
    '''
    # abbreviate, easier on the eyes
    conf = self.train_config

    if self.samples_buffer is None:
        print("Recreating samples buffer")
        self.samples_buffer = SamplesBuffer()
        self.buckets = Buckets(conf.resample_buckets)

    leveled_data = []
    total_samples = 0
    for fn, sample_data in self.samples_buffer.files_to_sample_data(conf):
        assert sample_data.game == conf.game

        log.debug("Proccesing %s" % fn)
        description = (sample_data.game,
                       sample_data.with_generation,
                       sample_data.num_samples)
        log.debug("Game %s, with gen: %s and sample count %s" % description)

        if not sample_data.transformed:
            # sample_data.verify_samples(self.game_info.get_sm())
            sample_data.transform_all(self.transformer)

        # the level is the position of the file in the list
        level_data = LevelData(len(leveled_data))
        for ins, outs in sample_data:
            level_data.add(ins, outs)

        log.verbose("Validation split")
        level_data.validation_split(conf.validation_split)

        leveled_data.append(level_data)
        total_samples += len(level_data)

    log.info("total samples: %s" % total_samples)
    return leveled_data
def check_running_processes(self):
    ''' Poll the running processes, reporting those that exited and signalling stragglers.

    Exited processes have their return code and output logged and are dropped from
    self.procs.  After self.timeout_time the remaining ones are sent SIGTERM (once), and
    a second later SIGKILL (once).  Reschedules itself every 0.1 seconds while processes
    remain, otherwise calls self.cb_on_completion().
    '''
    pending, self.procs = self.procs, []

    for cmd, proc in pending:
        retcode = proc.poll()
        if retcode is None:
            # still running
            self.procs.append((cmd, proc))
            continue

        log.debug("cmd '%s' exited with return code: %s" % (cmd, retcode))
        stdout = proc.stdout.read().strip()
        stderr = proc.stderr.read().strip()
        if stdout:
            log.verbose("stdout:%s" % stdout)
        if stderr:
            log.warning("stderr:%s" % stderr)

    if time.time() > self.timeout_time:
        for cmd, proc in self.procs:
            if cmd in self.killing:
                continue
            self.killing.add(cmd)
            log.warning("cmd '%s' taking too long, terminating" % cmd)
            os.kill(proc.pid, SIGTERM)

    if time.time() > self.timeout_time + 1:
        for cmd, proc in self.procs:
            if cmd in self.terminating:
                continue
            self.terminating.add(cmd)
            log.warning("cmd '%s' didn't terminate gracefully, killing" % cmd)
            os.kill(proc.pid, SIGKILL)

    if not self.procs:
        self.cb_on_completion()
    else:
        reactor.callLater(0.1, self.check_running_processes)
def on_epoch_end(self, _, logs=None):
    ''' Keep the best weights seen so far and decide whether to stop training early.

    Two sets of weights are tracked: self.best (best validation policy accuracy while not
    overfitting) and self.retrain_best (best validation policy accuracy while validation
    accuracy is more than 0.01 above training accuracy).  Early stopping is only
    considered from epoch 4 onwards, or from epoch 2 when retraining.
    '''
    epoch = self.at_epoch

    self.set_value_overfitting(logs)

    # deals with more than one head
    policy_acc, val_policy_acc = self.policy_acc(logs)

    log.debug("combined policy accuracy %.4f/%.4f" % (policy_acc, val_policy_acc))

    # are we overfitting?
    overfitting = policy_acc - 0.02 > val_policy_acc

    # store best weights as best val_policy_acc
    if (self.epoch_last_set_at is None or
        (val_policy_acc > self.best_val_policy_acc and not overfitting)):
        log.debug("Setting best to last val_policy_acc %.4f" % val_policy_acc)
        self.best = self.model.get_weights()
        self.best_val_policy_acc = val_policy_acc
        self.epoch_last_set_at = epoch

    # weights kept for retraining: validation must beat training accuracy by over 0.01
    store_retraining_weights = ((policy_acc + 0.01) < val_policy_acc and
                                val_policy_acc > self.retrain_best_val_policy_acc)

    if store_retraining_weights:
        log.debug("Setting retraining_weights to val_policy_acc %.4f" % val_policy_acc)
        self.retrain_best = self.model.get_weights()
        self.retrain_best_val_policy_acc = val_policy_acc

    # stop training:
    if (not self.retraining and epoch >= 4 or
        self.retraining and epoch >= 2):
        if overfitting:
            log.info("Early stopping... since policy accuracy overfitting")
            self.stop_training = True

        # if things havent got better - STOP.  We can go on forever without improving.
        if self.epoch_last_set_at is not None and epoch > self.epoch_last_set_at + 3:
            log.info("Early stopping... since not improving")
            self.stop_training = True
def render_GET(self, request):
    ''' Log the client address and answer the GET request with self.handle(). '''
    client_ip = request.getClientIP()
    log.debug("Got GET request from: %s" % client_ip)
    response = self.handle(request)
    return response
def connectionMade(self):
    ''' Called when the connection is made: the logical connection starts out as down. '''
    self.logical_connection = False
    message = "Client::connectionMade()"
    log.debug(message)
def connectionLost(self, reason=""):
    ''' Called when the connection is lost: mark the logical connection as down and log
    the reason.
    '''
    self.logical_connection = False
    message = "Client::connectionLost() : %s" % reason
    log.debug(message)
def buildProtocol(self, addr):
    ''' Log the address of the new connection and return a ServerClient bound to our
    broker.
    '''
    log.debug("Connection made from: %s" % addr)
    protocol = ServerClient(self.broker)
    return protocol
def sync(self): # check summary matches current set of files if not self.check_summary() or not self.verify_db(): self.get_summary(create=True) self.create_db() for step, file_path, md5sum in self.files_to_process(): # lets delete any spurious memory gc.collect() log.debug("Processing %s" % file_path) data = attrutil.json_to_attr(gzip.open(file_path).read()) if len(data.samples) != data.num_samples: # pretty inconsequential, but we should at least notify msg = "num_samples (%d) versus actual samples (%s) differ... trimming" log.warning(msg % (data.num_samples, len(data.samples))) data.num_samples = min(len(data.samples), data.num_samples) data.samples = data.samples[:data.num_samples] log.debug("Game %s, with gen: %s and sample count %s" % (data.game, data.with_generation, data.num_samples)) indx = self.db.size stats = StatsAccumulator() t = self.transformer # ZZZ really slow # ZZZ profile/gather times in loop... (guessing the time is in decoding state) time_check = 0 time_stats = 0 time_decode = 0 time_decode_prevs = 0 time_channels = 0 time_outputs = 0 time_db_resize = 0 time_db_insert = 0 cur_size = indx for sample in self.augment_data(data.samples): # ensure that final scores are clamped before adding to db sample.final_score = [min(1.0, v) for v in sample.final_score] sample.final_score = [max(0.0, v) for v in sample.final_score] sample_is_draw = False if abs(sample.final_score[0] - 0.5) < 0.01: assert abs(sample.final_score[1] - 0.5) < 0.01 sample_is_draw = True # XXX highly experimental if sample_is_draw and self.score_draw_as_random_hack: # the idea is just to randomly asign a win or loss to train on. Then the # network can average out over a 'bazillion' draw samples and determine that # the value should be 0.5. In theory. XXX Who knows? 
if random.random() > 0.5: sample.final_score = [1.0, 0] else: sample.final_score = [0, 1.0] et = ElaspedTime() # XXX too slow, and only useful for debugging serious bugs - disable # t.check_sample(sample) time_check += et.update() stats.add(sample, was_draw=sample_is_draw) time_stats += et.update() # add channels # only decode if not already decoded (as in the case of augmentation) state = fast_decode_state(sample.state) time_decode += et.update() prev_states = [fast_decode_state(s) for s in sample.prev_states] time_decode_prevs += et.update() cols = [t.state_to_channels(state, prev_states)] time_channels += et.update() for ri, policy in enumerate(sample.policies): cols.append(t.policy_to_array(policy, ri)) time_outputs += et.update() cols.append(t.value_to_array(sample.final_score)) # is this an efficient way to do things? if indx >= cur_size: cur_size += 20 self.db.resize(cur_size) time_db_resize += et.update() for ii, name in enumerate(self.db.names): self.db[name][indx] = cols[ii] indx += 1 time_db_insert += et.update() print "time_check: %.2f" % time_check print "time_stats: %.2f" % time_stats print "time_decode: %.2f" % time_decode print "time_decode_prevs: %.2f" % time_decode_prevs print "time_channels: %.2f" % time_channels print "time_outputs: %.2f" % time_outputs print "time_db_resize: %.2f" % time_db_resize print "time_db_insert: %.2f" % time_db_insert if indx != cur_size: cur_size = indx self.db.resize(indx) self.db.flush() log.debug("Added %d samples to db" % stats.num_samples) # add to the summary and save it step_sum = datadesc.StepSummary(step=step, filename=file_path, with_generation=data.with_generation, num_samples=stats.num_samples, md5sum=md5sum, stats_unique_matches=stats.unique_matches, stats_draw_ratio=stats.draw_ratio, stats_bare_policies_ratio=stats.bare_policies_ratio, stats_av_starting_depth=stats.av_starting_depth, stats_av_ending_depth=stats.av_ending_depth, stats_av_resigns=stats.av_resigns, 
stats_av_resign_false_positive=stats.av_resign_false_positive, stats_av_puct_visits=stats.av_puct_visits, stats_ratio_of_roles=stats.ratio_of_roles, stats_av_final_scores=stats.av_final_scores, stats_av_puct_score_dist=stats.av_puct_score_dist) print attrutil.attr_to_json(step_sum, pretty=True) self.summary.last_updated = timestamp() self.summary.total_samples = self.db.size self.summary.step_summaries.append(step_sum) self.save_summary_file() log.debug("Saved summary file") # lets delete any spurious memory gc.collect() self.save_summary_file() log.info("Data cache synced, saved summary file.")
def get_indices(self, max_size=None, validation=False, include_all=None): ''' same bucket algorithm as old way, but with indexing: figure out the sizes required from each generation based on buckets (any rounding issues, drop from oldest generation) [also works for scaling down if we add max_number_of_samples] create a range(n) where n is the size of a generation. shuffle. remove head or tail until size. [old version removed tail, but it doesn't matter] combine all (need to offset start index of each generation data] shuffle. ''' do_debug = False # XXX add config option (actually best just an argument here) include_pct = 0.5 levels = self.validation_levels if validation else self.train_levels if do_debug: print "levels", levels sizes = [end - start for _, start, end in levels] if do_debug: print "sizes1", sizes # apply buckets bucket_sizes = [] for depth, sz in enumerate(sizes): percent = self.buckets.get(depth) if percent < 0: continue sz *= percent bucket_sizes.append(int(sz)) # do we have more data than needed for epoch? sizes = bucket_sizes if do_debug: print "sizes2", sizes if max_size is not None or max_size > 0: # XXX whole thing needs a rewrite... if sum(sizes) > max_size: include_sizes = [] remaining_sizes = sizes if include_all is not None: assert include_all == 1, "include_all == 1 ... 
XXXX only 1 supported" scale = max_size / float(sum(sizes)) sz = sizes[0] if int(math.ceil(sz * scale)) < int(include_pct * sz): include_sizes = [int(include_pct * s) for s in sizes[:include_all]] remaining_sizes = sizes[include_all:] include_total_size = sum(include_sizes) remaining_total_size = sum(remaining_sizes) if max_size > include_total_size: max_remaining_size = max_size - include_total_size scale = max_remaining_size / float(remaining_total_size) remaining_sizes = [int(math.ceil(s * scale)) for s in remaining_sizes] if sum(remaining_sizes) > max_remaining_size: remaining_sizes[-1] -= sum(remaining_sizes) - max_remaining_size if remaining_sizes[-1] <= 0: if DEBUG: print 'uggh XXX.... FIXME', remaining_sizes[-1] print "max_size = sum(sizes)", max_size, sum(sizes) remaining_sizes.pop(-1) max_size = sum(sizes) sizes = include_sizes + remaining_sizes assert sum(sizes) <= max_size if do_debug: print "sizes3", sizes self.debug_create_indices = [] all_indices = [] for ii, s in enumerate(sizes): all_indices += self.create_indices_for_level(ii, validation=validation, max_size=s) log.debug("debug_create_indices (depth, step, start, finish, sz, #games)") if len(self.debug_create_indices) > 5: log.debug("[%s, %s ... %s, %s]" % (self.debug_create_indices[0], self.debug_create_indices[1], self.debug_create_indices[-2], self.debug_create_indices[-1])) else: log.debug("%s" % (self.debug_create_indices,)) total_unique_games = sum([dci[5] for dci in self.debug_create_indices]) def f(x): return (x[4] / float(x[3] - x[2])) * x[5] pct_total_unique_games = int(sum([f(dci) for dci in self.debug_create_indices])) print "pct_total_unique_games", pct_total_unique_games log.info("Considering %s levels, total/pct unique games: %d/%d" % (len(self.debug_create_indices), total_unique_games, pct_total_unique_games)) random.shuffle(all_indices) return all_indices
def new_broker_client(self, worker):
    ''' Register a newly connected worker, then send it a Ping followed by a
    RequestConfig.
    '''
    connected_at = time.time()
    self.workers[worker] = WorkerInfo(worker, connected_at)
    log.debug("New worker %s" % worker)

    ping = msgs.Ping()
    worker.send_msg(ping)

    config_request = msgs.RequestConfig()
    worker.send_msg(config_request)
def apply_move(self, moves):
    ''' Advance the match by the joint move reported by the gamemaster.

    moves holds one move per role, in role order.  Each move is remapped with
    gdl_symbol_mapping (when set) and matched against the legal moves of that role to
    fill in self.joint_move.  The new state and the remapped moves are recorded.

    Raises CriticalError if our move differs from the move we last played.
    '''
    self.game_depth += 1
    if self.verbose:
        log.debug("apply moves: %s" % (moves, ))

    # we give the player an one time opportunity to return debug/extra information
    # about the move it just played
    extra_info = self.player.before_apply_info()
    self.move_info.append(extra_info)

    # get the previous state - incase our statemachine is out of sync
    self.sm.update_bases(self.get_current_state())

    # fish tediously for move in available legals
    our_move = None
    remapped_moves = []
    for role_index, gamemaster_move in enumerate(moves):
        move = gamemaster_move

        # map the gamemaster move
        if self.gdl_symbol_mapping:
            for old_symbol, new_symbol in self.gdl_symbol_mapping.items():
                move = replace_symbols(move, old_symbol, new_symbol)

            if self.verbose:
                log.debug("remapped move from '%s' -> '%s'" % (gamemaster_move, move))

        remapped_moves.append(move)

        # find the move
        ls = self.sm.get_legal_state(role_index)
        choices = [ls.get_legal(ii) for ii in range(ls.get_count())]

        found = False
        for choice in choices:
            choice_move = self.sm.legal_to_move(role_index, choice)
            if choice_move == str(move):
                if role_index == self.our_role_index:
                    our_move = choice_move

                self.joint_move.set(role_index, choice)
                found = True
                break

        assert found, move

    assert our_move is not None

    # check that our move was the same.  May be a timeout or the gamemaster due to bad
    # network.  In these cases the mismatch is logged and raised as a CriticalError.
    if self.last_played_move is not None and self.last_played_move != our_move:
        msg = "Gamemaster sent back a different move from played move %s != %s" % (self.last_played_move,
                                                                                 our_move)
        log.critical(msg)
        raise CriticalError(msg)

    next_base_state = self.sm.new_base_state()
    self.sm.next_state(self.joint_move, next_base_state)
    self.sm.update_bases(next_base_state)

    # save for next time / posterity
    self.moves.append(remapped_moves)
    self.states.append(next_base_state)

    # in case player needs to cleanup some state
    self.player.on_apply_move(self.joint_move)
def sync(self): # check summary matches current set of files if not self.check_summary() or not self.verify_db(): self.get_summary(create=True) self.create_db() for step, file_path, md5sum in self.files_to_process(): # lets delete any spurious memory gc.collect() log.debug("Processing %s" % file_path) data = attrutil.json_to_attr(gzip.open(file_path).read()) if len(data.samples) != data.num_samples: # pretty inconsequential, but we should at least notify msg = "num_samples (%d) versus actual samples (%s) differ... trimming" log.warning(msg % (data.num_samples, len(data.samples))) data.num_samples = min(len(data.samples), data.num_samples) data.samples = data.samples[:data.num_samples] log.debug("Game %s, with gen: %s and sample count %s" % (data.game, data.with_generation, data.num_samples)) indx = self.db.size stats = StatsAccumulator() t = self.transformer # ZZZ really slow # ZZZ profile/gather times in loop... (guessing the time is in decoding state) time_check = 0 time_stats = 0 time_decode = 0 time_decode_prevs = 0 time_channels = 0 time_outputs = 0 time_db_resize = 0 time_db_insert = 0 cur_size = indx for sample in self.augment_data(data.samples): et = ElaspedTime() #t.check_sample(sample) time_check += et.update() stats.add(sample) time_stats += et.update() # add channels # only decode if not already decoded (as in the case of augmentation) state = fast_decode_state(sample.state) time_decode += et.update() prev_states = [ fast_decode_state(s) for s in sample.prev_states ] time_decode_prevs += et.update() cols = [t.state_to_channels(state, prev_states)] time_channels += et.update() for ri, policy in enumerate(sample.policies): cols.append(t.policy_to_array(policy, ri)) time_outputs += et.update() cols.append(t.value_to_array(sample.final_score)) # is this an efficient way to do things? 
if indx >= cur_size: cur_size += 20 self.db.resize(cur_size) time_db_resize += et.update() for ii, name in enumerate(self.db.names): self.db[name][indx] = cols[ii] indx += 1 time_db_insert += et.update() print "time_check: %.2f" % time_check print "time_stats: %.2f" % time_stats print "time_decode: %.2f" % time_decode print "time_decode_prevs: %.2f" % time_decode_prevs print "time_channels: %.2f" % time_channels print "time_outputs: %.2f" % time_outputs print "time_db_resize: %.2f" % time_db_resize print "time_db_insert: %.2f" % time_db_insert if indx != cur_size: cur_size = indx self.db.resize(indx) self.db.flush() log.debug("Added %d samples to db" % stats.num_samples) # add to the summary and save it step_sum = datadesc.StepSummary( step=step, filename=file_path, with_generation=data.with_generation, num_samples=stats.num_samples, md5sum=md5sum, stats_unique_matches=stats.unique_matches, stats_draw_ratio=stats.draw_ratio, stats_bare_policies_ratio=stats.bare_policies_ratio, stats_av_starting_depth=stats.av_starting_depth, stats_av_ending_depth=stats.av_ending_depth, stats_av_resigns=stats.av_resigns, stats_av_resign_false_positive=stats.av_resign_false_positive, stats_av_puct_visits=stats.av_puct_visits, stats_ratio_of_roles=stats.ratio_of_roles, stats_av_final_scores=stats.av_final_scores, stats_av_puct_score_dist=stats.av_puct_score_dist) print attrutil.attr_to_json(step_sum, pretty=True) self.summary.last_updated = timestamp() self.summary.total_samples = self.db.size self.summary.step_summaries.append(step_sum) self.save_summary_file() log.debug("Saved summary file") # lets delete any spurious memory gc.collect() self.save_summary_file() log.info("Data cache synced, saved summary file.")
def do_start(self, initial_basestate=None, game_depth=0):
    ''' Optional initial_basestate.  Used mostly for testing. If none will use the initial
    state of state machine (and the game_depth will be zero).  Game depth may not be
    handled by the base player.

    Sets up the state machine, the first state, the joint move and our role index, then
    resets the player and runs its meta gaming stage, which is given until enter time +
    meta_time (less cushion_time when positive).

    Raises BadGame if our (mapped) role is not one of the state machine's roles.
    '''

    enter_time = time.time()
    end_time = enter_time + self.meta_time
    if self.cushion_time > 0:
        end_time -= self.cushion_time

    if self.verbose:
        log.debug("Match.do_start(), time = %.1f" % (end_time - enter_time))

    if self.sm is None:
        self.sm = self.game_info.get_sm()
    self.sm.reset()

    if self.verbose:
        log.debug("Got state machine %s for game '%s' and match_id: %s" % (self.sm,
                                                                           self.game_info.game,
                                                                           self.match_id))

    if initial_basestate:
        # dupe the state - it could be deleted under our feet
        bs = self.sm.new_base_state()
        bs.assign(initial_basestate)
        initial_basestate = bs

        if self.verbose:
            initial_str = self.game_info.model.basestate_to_str(initial_basestate)
            log.debug("The start state is %s" % initial_str)

        # update the statemachine
        self.sm.update_bases(initial_basestate)

        # check it is not actually finished
        assert not self.sm.is_terminal()

    else:
        initial_basestate = self.sm.get_initial_state()

    self.states.append(initial_basestate)

    # store a joint move internally
    self.joint_move = self.sm.get_joint_move()

    # set our role index
    if self.gdl_symbol_mapping:
        our_role = self.gdl_symbol_mapping[self.role]
    else:
        our_role = self.role

    self.our_mapped_role = our_role

    roles = self.sm.get_roles()
    if our_role not in roles:
        # the message is formatted here; it used to be passed unformatted, with the
        # values as a second argument
        raise BadGame("Our role not found. %s in %s" % (our_role, roles))

    self.our_role_index = self.sm.get_roles().index(our_role)
    if self.verbose:
        log.info('roles : %s, our_role : %s, role_index : %s' % (self.sm.get_roles(),
                                                                 our_role,
                                                                 self.our_role_index))
    assert self.our_role_index != -1

    # starting point for the game (normally zero)
    self.game_depth = game_depth

    # FINALLY : call the meta gaming stage on the player
    # note: on_meta_gaming must use self.match.get_current_state()
    self.player.reset(self)
    self.player.on_meta_gaming(end_time)