def load_or_train(train_config, var_scope, path, target_var_scope=None, sess=None):
    """ Attempts to load variables into ``var_scope`` from the checkpoint stored at ``path``.

    If no such checkpoint is found, trains a model by running ``training_loop`` under
    ``train_config`` and stores the resulting variables at ``path`` for future use.
    Returns True iff the model was successfully loaded, False otherwise.

    If ``target_var_scope`` is not None, look for the variables under that scope name
    in the checkpoint file, instead of under ``var_scope``.

    """
    sess = sess or tf.get_default_session()

    to_be_loaded = trainable_variables(var_scope, for_opt=False)

    if target_var_scope is not None:
        # Map each variable to the name it has in the checkpoint file, which lives
        # under `target_var_scope` rather than `var_scope`.
        _tbl = {}
        for var in to_be_loaded:
            assert var.name.startswith(var_scope.name)
            bare_name = var.name[len(var_scope.name):]
            while bare_name.startswith('/'):
                bare_name = bare_name[1:]
            name_in_file = target_var_scope + '/' + bare_name
            _tbl[name_in_file] = var
        to_be_loaded = _tbl
    else:
        to_be_loaded = {v.name: v for v in to_be_loaded}

    saver = tf.train.Saver(to_be_loaded)

    if path is not None:
        os.makedirs(os.path.dirname(path), exist_ok=True)

    success = False
    try:
        saver.restore(sess, path)
        success = True
    except tf.errors.NotFoundError:
        # No checkpoint found: train the model, then restore the freshly saved weights.
        with ExitStack() as stack:
            stack.enter_context(ClearConfig())
            stack.enter_context(train_config.copy(save_path=path))

            output = training_loop(var_scope.name)

            stem = os.path.splitext(path)[0]
            shutil.copyfile(output.path_for('stdout'), stem + '.stdout')
            shutil.copyfile(output.path_for('stderr'), stem + '.stderr')

            saver.restore(sess, path)

    return success
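# A minimal usage sketch of `load_or_train` (the config object, scope and checkpoint path
# below are illustrative assumptions, not part of this module; a default session is assumed
# to be active): restore pretrained weights for a scope if a checkpoint exists, otherwise
# train them and cache the result at `path`.
#
#     with tf.variable_scope("encoder") as encoder_scope:
#         build_encoder()  # hypothetical graph-construction call
#
#     loaded = load_or_train(
#         pretrain_config, encoder_scope,
#         path="./checkpoints/encoder/weights.ckpt")
#     if not loaded:
#         print("No checkpoint found; weights were trained from scratch and saved.")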
def framework_load_weights(self):
    for var_scope, path in self.get_load_paths():
        _print("Loading var scope \"{}\" from {}.".format(var_scope, path))

        start = time.time()

        variables = {v.name: v for v in trainable_variables(var_scope, for_opt=False)}
        if not variables:
            _print("No variables to load in scope {}.".format(str(var_scope)))
            continue

        saver = tf.train.Saver(variables)
        saver.restore(tf.get_default_session(), path)

        _print("Done loading var scope, took {} seconds.".format(time.time() - start))
def trainable_variables(self, for_opt):
    return trainable_variables(self.controller.scope.name, for_opt=for_opt)
def _run(self):
    print(cfg.to_string())

    threshold_reached = True
    self.global_step = 0
    self.n_global_experiences = 0
    self.curriculum_remaining = self.curriculum + []
    self.curriculum_complete = []

    stage_idx = 0
    while self.curriculum_remaining:
        print("\n" + "=" * 50)
        self.timestamp("Starting stage {}".format(stage_idx))
        print("\n")

        if cfg.start_tensorboard:
            restart_tensorboard(self.experiment_store.path, cfg.tbport, cfg.reload_interval)

        stage_config = self.curriculum_remaining.pop(0)
        stage_config = Config(stage_config)

        self.data.start_stage(stage_idx, stage_config)

        with ExitStack() as stack:

            # --------------- Stage set-up -------------------

            print("\n" + "-" * 10 + " Stage set-up " + "-" * 10)

            print("\nNew config values for this stage are: \n{}\n".format(pformat(stage_config)))
            stack.enter_context(stage_config)

            stage_prepare_func = cfg.get("stage_prepare_func", None)
            if callable(stage_prepare_func):
                stage_prepare_func()  # Modify the stage config in arbitrary ways before starting stage

            self.mpi_context.start_stage()

            # Configure and create session and graph for stage.
            session_config = tf.ConfigProto()
            session_config.intra_op_parallelism_threads = cfg.get('intra_op_parallelism_threads', 0)
            session_config.inter_op_parallelism_threads = cfg.get('inter_op_parallelism_threads', 0)
            session_config.log_device_placement = cfg.get('log_device_placement', 0)

            if cfg.use_gpu:
                per_process_gpu_memory_fraction = getattr(cfg, 'per_process_gpu_memory_fraction', None)
                if per_process_gpu_memory_fraction:
                    session_config.gpu_options.per_process_gpu_memory_fraction = per_process_gpu_memory_fraction

                gpu_allow_growth = getattr(cfg, 'gpu_allow_growth', None)
                if gpu_allow_growth:
                    session_config.gpu_options.allow_growth = gpu_allow_growth

                print("Using GPU if available.")
                print("Using {}% of GPU memory.".format(
                    100 * session_config.gpu_options.per_process_gpu_memory_fraction))
                print("Allowing growth of GPU memory: {}".format(session_config.gpu_options.allow_growth))

            graph = tf.Graph()
            sess = tf.Session(graph=graph, config=session_config)

            # This HAS to come after the creation of the session, otherwise
            # it allocates all GPU memory if using the GPU.
            print("\nAvailable devices: ")
            from tensorflow.python.client import device_lib
            print(device_lib.list_local_devices())

            if not cfg.use_gpu:
                print("Not using GPU.")
                stack.enter_context(graph.device("/cpu:0"))

            stack.enter_context(graph.as_default())
            stack.enter_context(sess)
            stack.enter_context(sess.as_default())

            # Set the seed for the stage. Notice we generate a new tf seed for each stage.
            tf_seed = gen_seed()
            print("Setting tensorflow seed to generated seed: {}\n".format(tf_seed))
            tf.set_random_seed(tf_seed)

            # Set limit on CPU RAM for the stage
            cpu_ram_limit_mb = cfg.get("cpu_ram_limit_mb", None)
            if cpu_ram_limit_mb is not None:
                stack.enter_context(memory_limit(cfg.cpu_ram_limit_mb))

            print("Building env...\n")

            # Maybe build env
            if stage_idx == 0 or not cfg.preserve_env:
                if getattr(self, 'env', None):
                    self.env.close()

                self.env = cfg.build_env()

            if hasattr(self.env, "print_memory_footprint"):
                self.env.print_memory_footprint()

            print("\nDone building env.\n")
            print("Building updater...\n")

            import warnings
            with warnings.catch_warnings():
                warnings.simplefilter('once')

                if cfg.n_procs > 1:
                    updater = cfg.get_updater(self.env, mpi_context=self.mpi_context)
                else:
                    updater = cfg.get_updater(self.env)

                updater.stage_idx = stage_idx
                updater.exp_dir = self.exp_dir

                updater.build_graph()
                print("\nDone building updater.\n")

            walk_variable_scopes(max_depth=3)

            # Maybe initialize network weights.
            # Let a *path_specification* be one of three things:
            #   1. An integer specifying a stage to load the best hypothesis from.
            #   2. A string of format "stage_idx,kind", where `stage_idx` specifies a stage to load
            #      from and `kind` is either "final" or "best", specifying whether to load the final
            #      or best hypothesis from that stage.
            #   3. A path on the filesystem giving a prefix for a tensorflow checkpoint file to load from.
            #
            # Then cfg.load_path can either be a path_specification itself, in which case all variables
            # in the network will be loaded from that path_specification, or a dictionary mapping from
            # variable scope names to path specifications, in which case all variables in each supplied
            # variable scope name will be loaded from the path_specification paired with that scope name.
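            # Illustrative examples of accepted values (the scope names and paths here are
            # hypothetical, shown only to clarify the formats described above):
            #
            #     cfg.load_path = 2                                    # best hypothesis from stage 2
            #     cfg.load_path = "2,final"                            # final hypothesis from stage 2
            #     cfg.load_path = "/tmp/exp/weights/best_of_stage_2"   # checkpoint prefix on disk
            #     cfg.load_path = {
            #         "network": 0,                                    # outer scope restored first
            #         "network/encoder": "/tmp/exp/weights/encoder",   # nested scope restored afterwards
            #     }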
            load_path = cfg.load_path
            if load_path is not None:
                if isinstance(load_path, str) or isinstance(load_path, int):
                    load_path = {"": load_path}

                load_path = dict(load_path)

                # Sort in increasing order, so that if one variable scope lies within another scope,
                # the outer scope gets loaded before the inner scope, rather than having the outer
                # scope wipe out the inner scope.
                items = sorted(load_path.items())
                for var_scope, path in items:
                    variables = {v.name: v for v in trainable_variables(var_scope, for_opt=False)}
                    if not variables:
                        print("No variables to load in scope {}.".format(str(var_scope)))
                        continue

                    saver = tf.train.Saver(variables)

                    load_stage, kind = None, None
                    if isinstance(path, int):
                        load_stage = path
                        kind = "best"
                    elif isinstance(path, str):
                        try:
                            split = path.split(',')
                            load_stage = int(split[0])
                            kind = split[1] if len(split) > 1 else 'best'
                            assert kind in 'best final'.split(), "path={}".format(path)
                        except Exception:
                            load_stage, kind = None, None

                    if load_stage is not None:
                        if stage_idx == 0:
                            print(
                                "Not loading var scope \"{}\" from stage {}, "
                                "currently in stage 0.".format(var_scope, load_stage))
                            continue
                        else:
                            key = kind + '_path'
                            completed_history = self.data.history[:-1]
                            path = completed_history[load_stage][key]

                    path = os.path.realpath(path)

                    saver.restore(tf.get_default_session(), path)

                    print("Loading var scope \"{}\" from {}.".format(var_scope, path))
            else:
                print("Using a fresh set of weights, not loading anything.")

            tf.train.get_or_create_global_step()
            sess.run(uninitialized_variables_initializer())
            sess.run(tf.assert_variables_initialized())

            for hook in cfg.hooks:
                assert isinstance(hook, Hook)
                hook.start_stage(self, updater, stage_idx)

            threshold_reached = False
            reason = None

            try:
                # --------------- Run stage -------------------

                start = time.time()
                phys_memory_before = memory_usage(physical=True)
                gpu_memory_before = gpu_memory_usage()

                threshold_reached, reason = self._run_stage(stage_idx, updater)

            except KeyboardInterrupt:
                reason = "User interrupt"

            except NotImplementedError as e:
                # There is a bug in pdb_postmortem that prevents instances of `NotImplementedError`
                # from being handled properly, so replace it with an instance of `Exception`.
                if cfg.robust:
                    traceback.print_exc()
                    reason = "Exception occurred ({})".format(repr(e))
                else:
                    raise Exception("NotImplemented") from e

            except Exception as e:
                reason = "Exception occurred ({})".format(repr(e))
                if cfg.robust:
                    traceback.print_exc()
                else:
                    raise

            except Alarm:
                reason = "Time limit exceeded"
                raise

            finally:
                phys_memory_after = memory_usage(physical=True)
                gpu_memory_after = gpu_memory_usage()

                self.data.record_values_for_stage(
                    stage_duration=time.time() - start,
                    phys_memory_before_mb=phys_memory_before,
                    phys_memory_delta_mb=phys_memory_after - phys_memory_before,
                    gpu_memory_before_mb=gpu_memory_before,
                    gpu_memory_delta_mb=gpu_memory_after - gpu_memory_before,
                )

            self.data.record_values_for_stage(reason=reason)

            print("\n" + "-" * 10 + " Optimization complete " + "-" * 10)
            print("\nReason: {}.\n".format(reason))

            final_path = self.data.path_for('weights/final_for_stage_{}'.format(stage_idx))
            final_path = cfg.get('save_path', final_path)
            final_path = updater.save(tf.get_default_session(), final_path)
            self.data.record_values_for_stage(final_path=final_path)

            # --------------- Maybe render performance of best hypothesis -------------------

            do_final_testing = (
                "Exception occurred" not in reason
                and reason != "Time limit exceeded"
                and 'best_path' in self.data.current_stage_record)

            if do_final_testing:
                try:
                    print("\n" + "-" * 10 + " Final testing/rendering " + "-" * 10)

                    print("Best hypothesis for this stage was found on "
                          "step (l: {best_local_step}, g: {best_global_step}) "
                          "with stopping criteria ({sc_name}) of {best_stopping_criteria}.".format(
                              sc_name=self.stopping_criteria_name, **self.data.current_stage_record))

                    best_path = self.data.current_stage_record['best_path']
                    print("Loading best hypothesis for this stage "
                          "from file {}...".format(best_path))

                    updater.restore(sess, best_path)

                    test_record = updater.evaluate(cfg.batch_size, mode="test")

                    for hook in cfg.hooks:
                        if hook.call_per_timestep and hook.final:
                            hook_record = hook.step(self, updater)

                            if hook_record:
                                assert len(hook_record) == 1
                                for k, d in dict(hook_record).items():
                                    test_record.update(d)

                    self.data.record_values_for_stage(
                        **{'_test_' + k: v for k, v in test_record.items()})

                    if cfg.render_step > 0 and cfg.render_hook is not None:
                        print("Rendering...")
                        cfg.render_hook(updater)
                        print("Done rendering.")

                except BaseException:
                    print("Exception occurred while performing final testing/rendering: ")
                    traceback.print_exc()
            else:
                print("\n" + "-" * 10 + " Skipping final testing/rendering " + "-" * 10)

            # --------------- Finish up the stage -------------------

            self.data.end_stage(updater.n_updates)

            print("\n" + "-" * 10 + " Running end-of-stage hooks " + "-" * 10 + "\n")
            for hook in cfg.hooks:
                hook.end_stage(self, stage_idx)
            print()

        self.timestamp("Done stage {}".format(stage_idx))
        print("=" * 50)

        stage_idx += 1
        self.curriculum_complete.append(stage_config)

        if not (threshold_reached or cfg.power_through):
            print("Failed to reach stopping criteria threshold on stage {} "
                  "of the curriculum, terminating.".format(stage_idx))
            break
def trainable_variables(self, for_opt):
    return trainable_variables(self.f.scope, for_opt=for_opt)