def framework_initialize_stage(self, stack):
    """Create the TensorFlow graph and session for a stage and make them the defaults.

    Session options are read from `cfg`; the graph, the session and the
    session-as-default contexts are entered on `stack`, so they stay active
    until the caller closes `stack`.

    Parameters
    ----------
    stack:
        Object exposing `enter_context` (used here like a
        `contextlib.ExitStack`); receives every context opened below.
    """
    proto = tf.ConfigProto()
    for option in ('intra_op_parallelism_threads',
                   'inter_op_parallelism_threads',
                   'log_device_placement'):
        setattr(proto, option, cfg.get(option, 0))

    if cfg.use_gpu:
        gpu_options = proto.gpu_options

        # Only override the protobuf defaults when cfg supplies a truthy value.
        memory_fraction = getattr(cfg, 'per_process_gpu_memory_fraction', None)
        if memory_fraction:
            gpu_options.per_process_gpu_memory_fraction = memory_fraction

        allow_growth = getattr(cfg, 'gpu_allow_growth', None)
        if allow_growth:
            gpu_options.allow_growth = allow_growth

        _print("Using GPU if available.")
        _print("Using {}% of GPU memory.".format(
            100 * gpu_options.per_process_gpu_memory_fraction))
        _print("Allowing growth of GPU memory: {}".format(gpu_options.allow_growth))

    stage_graph = tf.Graph()
    session = tf.Session(graph=stage_graph, config=proto)

    # Listing devices HAS to happen after the session exists, otherwise
    # it allocates all GPU memory if using the GPU.
    _print("\nAvailable devices: ")
    from tensorflow.python.client import device_lib
    _print(device_lib.list_local_devices())

    if not cfg.use_gpu:
        _print("Not using GPU.")
        stack.enter_context(stage_graph.device("/cpu:0"))

    stack.enter_context(stage_graph.as_default())
    stack.enter_context(session)
    stack.enter_context(session.as_default())

    # Seed TensorFlow for this stage with a freshly generated seed.
    seed = gen_seed()
    _print("Setting tensorflow seed to generated seed: {}\n".format(seed))
    tf.set_random_seed(seed)

    tf.logging.set_verbosity(tf.logging.ERROR)
def _build_graph(self):
    """Build the data pipeline, network outputs, loss, train op and evaluator.

    Sets `self.data_manager`, `self.inp`, `self.tensors`,
    `self.recorded_tensors`, `self.loss` and `self.evaluator`. When
    `cfg.do_train` is set and `cfg.no_gradient` is not, it also sets
    `self.train_op` and `self.train_records`.

    Raises
    ------
    AssertionError
        If the names recorded here collide with the network's recorded
        tensors or with its evaluation functions.
    """
    self.data_manager = DataManager(
        self.env.datasets['train'],
        self.env.datasets['val'],
        self.env.datasets['test'],
        cfg.batch_size)
    self.data_manager.build_graph()

    data = self.data_manager.iterator.get_next()
    self.inp = data["image"]
    network_outputs = self.network(data, self.data_manager.is_training)

    network_tensors = network_outputs["tensors"]
    network_recorded_tensors = network_outputs["recorded_tensors"]
    network_losses = network_outputs["losses"]

    self.tensors = network_tensors

    self.recorded_tensors = recorded_tensors = dict(
        global_step=tf.train.get_or_create_global_step())

    # --- loss ---

    # Total loss is the sum of every named loss; each term is also recorded
    # individually under `loss_<name>`.
    self.loss = tf.constant(0., tf.float32)
    for name, tensor in network_losses.items():
        self.loss += tensor
        recorded_tensors['loss_' + name] = tensor
    recorded_tensors['loss'] = self.loss

    # --- train op ---

    if cfg.do_train and not cfg.get('no_gradient', False):
        tvars = self.trainable_variables(for_opt=True)

        self.train_op, self.train_records = build_gradient_train_op(
            self.loss, tvars, self.optimizer_spec, self.lr_schedule,
            self.max_grad_norm, self.noise_schedule,
            grad_n_record_groups=self.grad_n_record_groups)

    # Fall back to an empty mapping when the session carries no
    # `scheduled_values` attribute (previously `None.items()` raised here).
    sess = tf.get_default_session()
    scheduled_values = getattr(sess, 'scheduled_values', None) or {}
    for k, v in scheduled_values.items():
        # Avoid clobbering an existing record that has the same name.
        if k in recorded_tensors:
            recorded_tensors['scheduled_' + k] = v
        else:
            recorded_tensors[k] = v

    # --- recorded values ---

    intersection = recorded_tensors.keys() & network_recorded_tensors.keys()
    assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)
    recorded_tensors.update(network_recorded_tensors)

    # Normalise a falsy `eval_funcs` (e.g. None) to an empty dict *before*
    # it is dereferenced for the collision check.
    eval_funcs = self.network.eval_funcs or {}

    intersection = recorded_tensors.keys() & eval_funcs.keys()
    assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)

    # For running functions, during evaluation, that are not implemented in tensorflow
    self.evaluator = Evaluator(eval_funcs, network_tensors, self)
def _update(self, batch_size):
    """Run one training step in the default session and return its records.

    Returns a dict with a single key ``'train'``. Its value is empty when
    `cfg.no_gradient` is set; otherwise it is the evaluated recorded
    tensors merged with the train-op records. `batch_size` is not used here.
    """
    gradient_disabled = cfg.get('no_gradient', False)
    if gradient_disabled:
        return dict(train=dict())

    feeds = self.data_manager.do_train()
    session = tf.get_default_session()

    fetches = [self.train_op, self.recorded_tensors, self.train_records]
    results = session.run(fetches, feed_dict=feeds)

    # results[0] is the train op's output, which is not needed.
    step_record = results[1]
    step_record.update(results[2])

    return {'train': step_record}
def find_atari_dir(atari_game, after_warp):
    """Return the recording directory for `atari_game`.

    Looks in `cfg.atari_data_dir` (falling back to `<cfg.data_dir>/atari_data`)
    for sub-directories whose name starts with
    ``atari_data_env=<game>NoFrameskip-v4.datetime=``, picks the one that
    sorts last, and returns its ``after_warp_recording`` or
    ``before_warp_recording`` child depending on `after_warp`.

    Raises
    ------
    Exception
        If no matching sub-directory exists (the sorted directory listing
        is pretty-printed first).
    """
    root = cfg.get('atari_data_dir', None) or os.path.join(cfg.data_dir, "atari_data")
    print(f"Using `{root}` as atari data dir.")

    entries = os.listdir(root)

    prefix = "atari_data_env={}.datetime=".format("{}NoFrameskip-v4".format(atari_game))
    candidates = sorted(entry for entry in entries if entry.startswith(prefix))

    if not candidates:
        pprint(sorted(entries))
        raise Exception("No data found for game {}".format(atari_game))

    if after_warp:
        recording = "after" + "_warp_recording"
    else:
        recording = "before" + "_warp_recording"

    # Names embed a datetime after the prefix, so the last sorted entry is chosen.
    return os.path.join(root, candidates[-1], recording)
def framework_load_weights(self):
    """
    Adapted from the tensorflow version, roughly treats a pytorch module as equivalent
    to a tensorflow variable scope.

    Most general form a dictionary entry is: {"<dest_module_path>": "<source_module_path>:<file_path>"}
    Maps tensors located at module path `source_module_path` in file `file_path` to module path
    `dest_module_path` in the current model.

    Only the first ':' separates the module path from the file path, so file
    paths that themselves contain ':' are accepted. Without a ':' the whole
    entry is the file path and the source module path equals the destination.

    Keys matching any prefix in `cfg.omit_modules_from_loading` are skipped.

    Raises
    ------
    AssertionError
        If the checkpoint holds no tensor under the source module path.
    Exception
        If no loaded (renamed) key matches a key of the current model.
    """
    omit_modules = cfg.get('omit_modules_from_loading', [])

    for dest_module_path, path in self.get_load_paths():
        _print("Loading submodule \"{}\" from {}.".format(dest_module_path, path))

        if ":" in path:
            # maxsplit=1: a bare split raised ValueError on unpack whenever
            # the file path contained a further ':'.
            source_module_path, source_path = path.split(':', 1)
        else:
            source_path = path
            source_module_path = dest_module_path

        start = time.time()

        device = get_pytorch_device()

        loaded_state_dict = torch.load(source_path, map_location=device)['model']

        if source_module_path:
            # Keep only the tensors that live under the source module,
            # preserving the mapping type of the loaded state dict.
            source_module_path_with_sep = source_module_path + '.'

            loaded_state_dict = type(loaded_state_dict)(
                {k: v for k, v in loaded_state_dict.items() if k.startswith(source_module_path_with_sep)}
            )

            assert loaded_state_dict, (
                f"File contains no tensors with prefix `{source_module_path_with_sep}` (file: {source_path})"
            )

        if dest_module_path != source_module_path:
            # Rename variables from the loaded state dict by replacing `source_module_path` with `dest_module_path`.

            _source_module_path = source_module_path + '.' if source_module_path else source_module_path
            _dest_module_path = dest_module_path + '.' if dest_module_path else dest_module_path

            loaded_state_dict = {
                k.replace(_source_module_path, _dest_module_path, 1): v
                for k, v in loaded_state_dict.items()
            }

        module = self.updater.model

        state_dict = module.state_dict()

        intersection = set(state_dict.keys()) & set(loaded_state_dict.keys())

        if not intersection:
            raise Exception(
                f"Loading variables with spec ({dest_module_path}, {path}) "
                f"would have no effect (no variables found)."
            )
        # Drop loaded entries the current model has no slot for.
        loaded_state_dict = {k: loaded_state_dict[k] for k in intersection}

        if omit_modules:
            omitted_variables = {
                k: v for k, v in loaded_state_dict.items()
                if any(k.startswith(o) for o in omit_modules)
            }

            print("Omitting the following variables from loading:")
            describe_structure(omitted_variables)

            loaded_state_dict = {
                k: v for k, v in loaded_state_dict.items()
                if k not in omitted_variables
            }

        _print("Loading variables:")
        describe_structure(loaded_state_dict)

        # Overlay the loaded tensors on the model's full state so the strict
        # load below still sees every expected key.
        state_dict.update(loaded_state_dict)

        module.load_state_dict(state_dict, strict=True)

        _print("Done loading weights for module {}, took {} seconds.".format(dest_module_path, time.time() - start))
def update(self, batch_size, step):
    """Run one optimisation step of `self.model` on the next training batch.

    Performs the forward pass, sums the returned losses, back-propagates,
    optionally clips gradients, steps the optimizer (and scheduler when one
    is set), then merges the result of `self._update` and the model's
    scheduled values into the step record.

    Parameters
    ----------
    batch_size:
        Passed to `self._update` and added to `self._n_experiences`.
    step:
        Current step; passed to the model and used to decide when to print
        timings (every 100 steps), profile, and display gradient norms.

    Returns
    -------
    The recorded-tensors structure, with every `torch.Tensor` leaf replaced
    by its mean.
    """
    # NOTE(review): nesting below was reconstructed from a whitespace-mangled
    # source; the extent of the anomaly-detection context and the placement of
    # `clipped_grad_norm` should be confirmed against the original file.
    print_time = step % 100 == 0

    self.model.train()

    data = AttrDict(next(self.train_iterator))

    self.model.update_global_step(step)

    detect_grad_anomalies = cfg.get('detect_grad_anomalies', False)
    with torch.autograd.set_detect_anomaly(detect_grad_anomalies):

        # A positive `pytorch_profile_step` profiles the forward pass on
        # every multiple of that step instead of timing it.
        profile_step = cfg.get('pytorch_profile_step', 0)
        if profile_step > 0 and step % profile_step == 0:
            with torch.autograd.profiler.profile(use_cuda=True) as prof:
                tensors, data, recorded_tensors, losses = self.model(data, step)
            print(prof)
        else:
            with timed_block('forward', print_time):
                tensors, data, recorded_tensors, losses = self.model(data, step)

        # --- loss ---

        losses = AttrDict(losses)

        # Sum every (flattened) loss term; each term is also recorded
        # individually under `loss_<name>`.
        loss = 0.0
        for name, tensor in losses.flatten().items():
            loss += tensor
            recorded_tensors['loss_' + name] = tensor
        recorded_tensors['loss'] = loss

        with timed_block('zero_grad', print_time):
            # Apparently this is faster, according to https://www.youtube.com/watch?v=9mS1fIYj1So, 10:37
            for param in self.model.parameters():
                param.grad = None
            # self.optimizer.zero_grad()

        with timed_block('loss backward', print_time):
            loss.backward()

    with timed_block('process grad', print_time):
        if self.grad_norm_recorder is not None:
            self.grad_norm_recorder.update()

            if step % self.print_grad_norm_step == 0:
                self.grad_norm_recorder.display()

        parameters = list(self.model.parameters())
        # Gradient norm before any clipping is applied.
        pure_grad_norm = grad_norm(parameters)

        if self.max_grad_norm is not None and self.max_grad_norm > 0.0:
            torch.nn.utils.clip_grad_norm_(parameters, self.max_grad_norm)

        # Gradient norm after the (optional) clipping step.
        clipped_grad_norm = grad_norm(parameters)

    with timed_block('optimizer step', print_time):
        self.optimizer.step()

    if self.scheduler is not None:
        self.scheduler.step()

    # Only a dict result from `_update` is merged into the record.
    update_result = self._update(batch_size)

    if isinstance(update_result, dict):
        recorded_tensors.update(update_result)

    self._n_experiences += batch_size

    recorded_tensors.update(
        grad_norm_pure=pure_grad_norm,
        grad_norm_clipped=clipped_grad_norm
    )

    scheduled_values = self.model.get_scheduled_values()
    recorded_tensors.update(scheduled_values)

    # Anything that is not a dict is treated as a leaf; tensor leaves are
    # reduced to their mean.
    recorded_tensors = map_structure(
        lambda t: t.mean() if isinstance(t, torch.Tensor) else t,
        recorded_tensors, is_leaf=lambda rec: not isinstance(rec, dict))

    return recorded_tensors
def _run_stage(self, stage_idx, updater):
    """ Run main training loop for a stage of the curriculum.

    Each iteration runs per-timestep hooks, optionally renders and displays
    a summary, optionally evaluates (which drives early stopping, saving of
    the best weights and the threshold check), performs one update when
    `cfg.do_train` is set, and finally stores the collected records.

    Parameters
    ----------
    stage_idx:
        Index of the stage; used in the best-weights path and passed to the
        data store.
    updater:
        Object providing `n_updates`, `n_experiences`, `stopping_criteria`,
        `evaluate`, `update` and `save`.

    Returns
    -------
    (threshold_reached, reason): whether the stopping-criteria threshold was
    reached, and a string describing why the loop ended.

    Raises
    ------
    Exception
        If the second element of the stopping-criteria spec contains neither
        "max" nor "min".

    """
    threshold_reached = False
    reason = "NotStarted"

    # Parse stopping criteria, set up early stopping
    # The spec is either a "name,direction" string or an indexable pair;
    # cfg takes precedence over the updater's own default.
    stopping_criteria = cfg.get("stopping_criteria", None)
    if not stopping_criteria:
        stopping_criteria = updater.stopping_criteria

    if isinstance(stopping_criteria, str):
        stopping_criteria = stopping_criteria.split(",")

    self.stopping_criteria_name = stopping_criteria[0]
    if "max" in stopping_criteria[1]:
        self.maximize_sc = True
    elif "min" in stopping_criteria[1]:
        self.maximize_sc = False
    else:
        raise Exception("Ambiguous stopping criteria specification: {}".format(stopping_criteria[1]))

    early_stop = EarlyStopHook(patience=cfg.patience, maximize=self.maximize_sc)

    # Start stage
    print("\n" + "-" * 10 + " Training begins " + "-" * 10)
    self.timestamp("")
    print()

    # Running totals and derived averages for hooks, evaluation and training.
    total_hooks_time = 0.0
    time_per_hook = 0.0

    total_eval_time = 0.0
    time_per_eval = 0.0

    total_train_time = 0.0
    time_per_example = 0.0
    time_per_update = 0.0

    n_eval = 0

    while True:
        # Check whether to keep training
        if updater.n_updates >= cfg.max_steps:
            reason = "Maximum number of steps-per-stage reached"
            break

        if updater.n_experiences >= cfg.max_experiences:
            reason = "Maximum number of experiences-per-stage reached"
            break

        # local_step counts updates within this stage; global_step counts
        # iterations across all stages.
        local_step = updater.n_updates
        global_step = self.global_step

        if local_step > 0 and local_step % cfg.checkpoint_step == 0:
            self.data.dump_data(local_step)

        evaluate = (local_step % cfg.eval_step) == 0
        display = (local_step % cfg.display_step) == 0
        # Rendering on step 0 only happens when `cfg.render_first` is set.
        render = (cfg.render_step > 0 and (local_step % cfg.render_step) == 0 and (local_step > 0 or cfg.render_first))

        # List of (mode, record) pairs gathered during this iteration.
        data_to_store = []

        # --------------- Run hooks -------------------

        hooks_start = time.time()

        for hook in cfg.hooks:
            if hook.call_per_timestep:
                # Run on step 0 only for hooks flagged `initial`, and
                # afterwards on every multiple of `hook.n`.
                run_hook = local_step == 0 and hook.initial
                run_hook |= local_step > 0 and local_step % hook.n == 0

                if run_hook:
                    hook_record = hook.step(self, updater, local_step)

                    if hook_record:
                        data_to_store.extend(dict(hook_record).items())

        hooks_duration = time.time() - hooks_start

        if render and cfg.render_hook is not None:
            print("Rendering...")
            cfg.render_hook(updater)
            print("Done rendering.")

        if display:
            print("Displaying...")
            self.data.summarize_current_stage(
                local_step, global_step, updater.n_experiences, self.n_global_experiences)
            print("\nMy PID: {}\n".format(os.getpid()))
            print("Physical memory use: {}mb".format(memory_usage(physical=True)))
            print("Virtual memory use: {}mb".format(memory_usage(physical=False)))

            print("Avg time per update: {}s".format(time_per_update))
            print("Avg time per eval: {}s".format(time_per_eval))
            print("Avg time for hooks: {}s".format(time_per_hook))

            if cfg.use_gpu:
                print(nvidia_smi())

        # --------------- Possibly evaluate -------------------

        if evaluate:
            print("Evaluating...")
            eval_start_time = time.time()
            val_record = updater.evaluate(cfg.batch_size, mode="val")
            eval_duration = time.time() - eval_start_time
            print("Done evaluating")

            val_record["duration"] = eval_duration

            n_eval += 1
            total_eval_time += eval_duration
            time_per_eval = total_eval_time / n_eval

            data_to_store.append(("val", val_record))

            if self.stopping_criteria_name not in val_record:
                print("Stopping criteria {} not in record returned "
                      "by updater, using 0.0.".format(self.stopping_criteria_name))

            # From here on `stopping_criteria` holds the measured value,
            # no longer the (name, direction) spec parsed above.
            stopping_criteria = val_record.get(self.stopping_criteria_name, 0.0)
            new_best, stop = early_stop.check(stopping_criteria, local_step, val_record)

            if new_best:
                print("Storing new best on step (l={}, g={}), "
                      "constituting (l={}, g={}) experiences, "
                      "with stopping criteria ({}) of {}.".format(
                          local_step, global_step,
                          updater.n_experiences, self.n_global_experiences,
                          self.stopping_criteria_name, stopping_criteria))

                # `cfg.save_path`, when present, overrides the default location.
                best_path = self.data.path_for(
                    'weights/best_of_stage_{}'.format(stage_idx))
                best_path = cfg.get('save_path', best_path)

                weight_start = time.time()
                best_path = updater.save(tf.get_default_session(), best_path)

                print("Done saving weights, took {} seconds".format(time.time() - weight_start))

                self.data.record_values_for_stage(
                    best_path=best_path, best_global_step=global_step)
                self.data.record_values_for_stage(
                    **{'best_' + k: v for k, v in early_stop.best.items()})

            if stop:
                print("Early stopping triggered.")
                reason = "Early stopping triggered"
                break

            if self.maximize_sc:
                threshold_reached = stopping_criteria >= cfg.threshold
            else:
                threshold_reached = stopping_criteria <= cfg.threshold

            if threshold_reached:
                reason = "Stopping criteria threshold reached"
                break

        # --------------- Perform an update -------------------

        if cfg.do_train:
            if local_step % 100 == 0:
                print("Running update step {}...".format(local_step))

            update_start_time = time.time()

            _old_n_experiences = updater.n_experiences

            update_record = updater.update(cfg.batch_size)

            update_duration = time.time() - update_start_time
            update_record["train"]["duration"] = update_duration

            if local_step % 100 == 0:
                print("Done update step.")

            # Memory statistics are only sampled every 100 steps.
            if local_step % 100 == 0:
                start = time.time()
                update_record["train"]["memory_physical_mb"] = memory_usage(physical=True)
                update_record["train"]["memory_virtual_mb"] = memory_usage(physical=False)
                update_record["train"]["memory_gpu_mb"] = gpu_memory_usage()
                print("Memory check duration: {}".format(time.time() - start))

            data_to_store.extend(dict(update_record).items())

            n_experiences_delta = updater.n_experiences - _old_n_experiences
            self.n_global_experiences += n_experiences_delta

            total_train_time += update_duration
            time_per_example = total_train_time / updater.n_experiences
            time_per_update = total_train_time / updater.n_updates

            # Hook time is only accumulated on iterations that train.
            total_hooks_time += hooks_duration
            time_per_hook = total_hooks_time / updater.n_updates

        # --------------- Store data -------------------

        # Merge all records gathered this iteration by mode.
        records = defaultdict(dict)
        for mode, r in data_to_store:
            records[mode].update(r)

        self.data.store_step_data_and_summaries(
            stage_idx, local_step, global_step, updater.n_experiences, self.n_global_experiences,
            **records)

        self.data.record_values_for_stage(
            time_per_example=time_per_example,
            time_per_update=time_per_update,
            time_per_eval=time_per_eval,
            time_per_hook=time_per_hook,
            n_steps=local_step,
            n_experiences=updater.n_experiences,
        )

        self.global_step += 1

        # If `do_train` is False, we do no training and evaluate
        # exactly once, so only one iteration is required.
        if not cfg.do_train:
            reason = "`do_train` set to False"
            break

    return threshold_reached, reason
def _run(self):
    """Run every stage of the curriculum in order.

    For each stage this activates the stage config, creates a fresh
    TensorFlow graph and session, builds the environment and updater,
    optionally loads weights according to `cfg.load_path`, runs the stage
    via `_run_stage`, saves the final weights, optionally evaluates the
    best hypothesis on the test set, and runs end-of-stage hooks.

    The loop ends when the curriculum is exhausted, or after a stage that
    did not reach its threshold unless `cfg.power_through` is set.
    """
    print(cfg.to_string())

    threshold_reached = True
    self.global_step = 0
    self.n_global_experiences = 0
    self.curriculum_remaining = self.curriculum + []
    self.curriculum_complete = []

    stage_idx = 0
    while self.curriculum_remaining:
        print("\n" + "=" * 50)
        self.timestamp("Starting stage {}".format(stage_idx))
        print("\n")

        if cfg.start_tensorboard:
            restart_tensorboard(self.experiment_store.path, cfg.tbport, cfg.reload_interval)

        stage_config = self.curriculum_remaining.pop(0)
        stage_config = Config(stage_config)

        self.data.start_stage(stage_idx, stage_config)

        with ExitStack() as stack:

            # --------------- Stage set-up -------------------

            print("\n" + "-" * 10 + " Stage set-up " + "-" * 10)

            print("\nNew config values for this stage are: \n{}\n".format(pformat(stage_config)))

            stack.enter_context(stage_config)

            stage_prepare_func = cfg.get("stage_prepare_func", None)
            if callable(stage_prepare_func):
                stage_prepare_func()  # Modify the stage config in arbitrary ways before starting stage

            self.mpi_context.start_stage()

            # Configure and create session and graph for stage.
            session_config = tf.ConfigProto()
            session_config.intra_op_parallelism_threads = cfg.get('intra_op_parallelism_threads', 0)
            session_config.inter_op_parallelism_threads = cfg.get('inter_op_parallelism_threads', 0)
            session_config.log_device_placement = cfg.get('log_device_placement', 0)

            if cfg.use_gpu:
                # Only override the protobuf defaults when cfg supplies a truthy value.
                per_process_gpu_memory_fraction = getattr(cfg, 'per_process_gpu_memory_fraction', None)
                if per_process_gpu_memory_fraction:
                    session_config.gpu_options.per_process_gpu_memory_fraction = per_process_gpu_memory_fraction

                gpu_allow_growth = getattr(cfg, 'gpu_allow_growth', None)
                if gpu_allow_growth:
                    session_config.gpu_options.allow_growth = gpu_allow_growth

                print("Using GPU if available.")
                print("Using {}% of GPU memory.".format(
                    100 * session_config.gpu_options.per_process_gpu_memory_fraction))
                print("Allowing growth of GPU memory: {}".format(session_config.gpu_options.allow_growth))

            graph = tf.Graph()
            sess = tf.Session(graph=graph, config=session_config)

            # This HAS to come after the creation of the session, otherwise
            # it allocates all GPU memory if using the GPU.
            print("\nAvailable devices: ")
            from tensorflow.python.client import device_lib
            print(device_lib.list_local_devices())

            if not cfg.use_gpu:
                print("Not using GPU.")
                stack.enter_context(graph.device("/cpu:0"))

            stack.enter_context(graph.as_default())
            stack.enter_context(sess)
            stack.enter_context(sess.as_default())

            # Set the seed for the stage. Notice we generate a new tf seed for each stage.
            tf_seed = gen_seed()
            print("Setting tensorflow seed to generated seed: {}\n".format(tf_seed))
            tf.set_random_seed(tf_seed)

            # Set limit on CPU RAM for the stage
            cpu_ram_limit_mb = cfg.get("cpu_ram_limit_mb", None)
            if cpu_ram_limit_mb is not None:
                stack.enter_context(memory_limit(cfg.cpu_ram_limit_mb))

            print("Building env...\n")

            # Maybe build env
            if stage_idx == 0 or not cfg.preserve_env:
                if getattr(self, 'env', None):
                    self.env.close()

                self.env = cfg.build_env()

            if hasattr(self.env, "print_memory_footprint"):
                self.env.print_memory_footprint()

            print("\nDone building env.\n")
            print("Building updater...\n")

            import warnings
            with warnings.catch_warnings():
                warnings.simplefilter('once')

                if cfg.n_procs > 1:
                    updater = cfg.get_updater(self.env, mpi_context=self.mpi_context)
                else:
                    updater = cfg.get_updater(self.env)
                updater.stage_idx = stage_idx
                updater.exp_dir = self.exp_dir

                updater.build_graph()
                print("\nDone building updater.\n")

            walk_variable_scopes(max_depth=3)

            # Maybe initialize network weights.
            # Let a *path_specification* be one of three things:
            #     1. An integer specifying a stage to load the best hypothesis from.
            #     2. A string of format: "stage_idx,kind" where `stage_idx` specifies a stage to load from
            #        and `kind` is either "final" or "best", specifying whether to load final or best
            #        hypothesis from that stage. `kind` defaults to "best" when omitted.
            #     3. A path on the filesystem that gives a prefix for a tensorflow checkpoint file to load from.
            #
            # Then cfg.load_path can either be a path_specification itself, in which case all variables
            # in the network will be loaded from that path_specification, or a dictionary mapping from
            # variable scope names to path specifications, in which case all variables in each supplied
            # variable scope name will be loaded from the path_specification paired with that scope name.
            load_path = cfg.load_path
            if load_path is not None:
                if isinstance(load_path, (str, int)):
                    load_path = {"": load_path}

                load_path = dict(load_path)

                # Sort in increasing order, so that it if one variable scope lies within another scope,
                # the outer scope gets loaded before the inner scope, rather than having the outer scope
                # wipe out the inner scope.
                items = sorted(load_path.items())

                for var_scope, path in items:
                    variables = {v.name: v for v in trainable_variables(var_scope, for_opt=False)}
                    if not variables:
                        print("No variables to load in scope {}.".format(str(var_scope)))
                        continue

                    saver = tf.train.Saver(variables)

                    load_stage, kind = None, None

                    if isinstance(path, int):
                        load_stage = path
                        kind = "best"

                    elif isinstance(path, str):
                        # Anything that does not parse as "stage_idx[,kind]"
                        # falls through and is treated as a filesystem path.
                        try:
                            split = path.split(',')
                            load_stage = int(split[0])
                            # Use the explicit kind when one is given, else
                            # default to "best" (this condition was inverted).
                            kind = split[1] if len(split) > 1 else 'best'
                            if kind not in ('best', 'final'):
                                raise ValueError("path={}".format(path))
                        except Exception:
                            load_stage, kind = None, None

                    if load_stage is not None:
                        if stage_idx == 0:
                            print(
                                "Not loading var scope \"{}\" from stage {}, "
                                "currently in stage 0.".format(var_scope, load_stage))
                            continue
                        else:
                            # Look the checkpoint path up in the records of
                            # stages before the current one.
                            key = kind + '_path'
                            completed_history = self.data.history[:-1]
                            path = completed_history[load_stage][key]

                    path = os.path.realpath(path)

                    saver.restore(tf.get_default_session(), path)
                    print("Loading var scope \"{}\" from {}.".format(var_scope, path))
            else:
                print("Using a fresh set of weights, not loading anything.")

            tf.train.get_or_create_global_step()
            sess.run(uninitialized_variables_initializer())
            sess.run(tf.assert_variables_initialized())

            for hook in cfg.hooks:
                assert isinstance(hook, Hook)
                hook.start_stage(self, updater, stage_idx)

            threshold_reached = False
            reason = None

            try:
                # --------------- Run stage -------------------

                start = time.time()
                phys_memory_before = memory_usage(physical=True)
                gpu_memory_before = gpu_memory_usage()

                threshold_reached, reason = self._run_stage(stage_idx, updater)

            except KeyboardInterrupt:
                reason = "User interrupt"

            except NotImplementedError as e:
                # There is a bug in pdb_postmortem that prevents instances of `NotImplementedError`
                # from being handled properly, so replace it with an instance of `Exception`.
                if cfg.robust:
                    traceback.print_exc()
                    reason = "Exception occurred ({})".format(repr(e))
                else:
                    raise Exception("NotImplemented") from e

            except Exception as e:
                reason = "Exception occurred ({})".format(repr(e))
                if cfg.robust:
                    traceback.print_exc()
                else:
                    raise

            # NOTE(review): if `Alarm` derives from `Exception`, the handler
            # above catches it first and this clause never runs — confirm
            # `Alarm`'s base class before relying on it.
            except Alarm:
                reason = "Time limit exceeded"
                raise

            finally:
                phys_memory_after = memory_usage(physical=True)
                gpu_memory_after = gpu_memory_usage()

                self.data.record_values_for_stage(
                    stage_duration=time.time()-start,
                    phys_memory_before_mb=phys_memory_before,
                    phys_memory_delta_mb=phys_memory_after - phys_memory_before,
                    gpu_memory_before_mb=gpu_memory_before,
                    gpu_memory_delta_mb=gpu_memory_after - gpu_memory_before
                )

            self.data.record_values_for_stage(reason=reason)

            print("\n" + "-" * 10 + " Optimization complete " + "-" * 10)
            print("\nReason: {}.\n".format(reason))

            # `cfg.save_path`, when present, overrides the default location.
            final_path = self.data.path_for('weights/final_for_stage_{}'.format(stage_idx))
            final_path = cfg.get('save_path', final_path)
            final_path = updater.save(tf.get_default_session(), final_path)
            self.data.record_values_for_stage(final_path=final_path)

            # --------------- Maybe render performance of best hypothesis -------------------

            do_final_testing = (
                "Exception occurred" not in reason
                and reason != "Time limit exceeded"
                and 'best_path' in self.data.current_stage_record)

            if do_final_testing:
                try:
                    print("\n" + "-" * 10 + " Final testing/rendering " + "-" * 10)

                    print("Best hypothesis for this stage was found on "
                          "step (l: {best_local_step}, g: {best_global_step}) "
                          "with stopping criteria ({sc_name}) of {best_stopping_criteria}.".format(
                              sc_name=self.stopping_criteria_name, **self.data.current_stage_record))

                    best_path = self.data.current_stage_record['best_path']
                    print("Loading best hypothesis for this stage "
                          "from file {}...".format(best_path))
                    updater.restore(sess, best_path)

                    test_record = updater.evaluate(cfg.batch_size, mode="test")

                    for hook in cfg.hooks:
                        if hook.call_per_timestep and hook.final:
                            hook_record = hook.step(self, updater)

                            if hook_record:
                                assert len(hook_record) == 1
                                for d in dict(hook_record).values():
                                    test_record.update(d)

                    self.data.record_values_for_stage(
                        **{'_test_' + k: v for k, v in test_record.items()})

                    if cfg.render_step > 0 and cfg.render_hook is not None:
                        print("Rendering...")
                        cfg.render_hook(updater)
                        print("Done rendering.")

                # Final testing is best-effort: any failure is reported and
                # the stage is still finished normally.
                except BaseException:
                    print("Exception occurred while performing final testing/rendering: ")
                    traceback.print_exc()

            else:
                print("\n" + "-" * 10 + " Skipping final testing/rendering " + "-" * 10)

            # --------------- Finish up the stage -------------------

            self.data.end_stage(updater.n_updates)

            print("\n" + "-" * 10 + " Running end-of-stage hooks " + "-" * 10 + "\n")
            for hook in cfg.hooks:
                hook.end_stage(self, stage_idx)

            print()
            self.timestamp("Done stage {}".format(stage_idx))
            print("=" * 50)

            self.curriculum_complete.append(stage_config)

            # Checked before incrementing `stage_idx` so the message names
            # the stage that actually failed.
            if not (threshold_reached or cfg.power_through):
                print("Failed to reach stopping criteria threshold on stage {} "
                      "of the curriculum, terminating.".format(stage_idx))
                break

            stage_idx += 1
def run(self, start_time):
    """ Run the training loop.

    Sets up the experiment directory, README, output redirection and numpy
    seed, then runs the curriculum under a time limit.

    Parameters
    ----------
    start_time: int
        Start time (in seconds since epoch) for measuring elapsed
        time for purposes of interrupting the training loop. When None,
        the current time is used.

    Returns
    -------
    The result of `self.data.freeze()`, taken after the run has finished
    (or failed).

    """
    self.start_time = time.time() if start_time is None else start_time

    self.timestamp("Entering TrainingLoop.run")

    # `prepare_func` may be a single callable or an iterable of callables;
    # each is invoked to modify the config in arbitrary ways before training.
    # Anything that is neither is silently ignored.
    prepare_spec = cfg.get("prepare_func", None)
    if callable(prepare_spec):
        pending = [prepare_spec]
    else:
        try:
            pending = list(prepare_spec)
        except (TypeError, ValueError):
            pending = []
    for candidate in pending:
        if callable(candidate):
            candidate()

    self.curriculum = cfg.curriculum + []

    # A missing or negative seed means "generate one".
    if cfg.seed is None or cfg.seed < 0:
        cfg.seed = gen_seed()

    # Create a directory to store the results of the training session.
    store_root = os.path.join(cfg.local_experiments_dir, cfg.env_name)
    self.experiment_store = ExperimentStore(store_root)
    exp_dir = self.experiment_store.new_experiment(
        self.exp_name, cfg.seed, add_date=1, force_fresh=1, update_latest=False)
    self.exp_dir = exp_dir
    cfg.path = exp_dir.path

    rule = "-" * 40
    title = "{}\nREADME.md - {}\n{}\n\n\n".format(rule, os.path.basename(exp_dir.path), rule)
    body = cfg.readme if cfg.readme else ""
    with open(exp_dir.path_for('README.md'), 'w') as readme_file:
        readme_file.write(title + body + "\n\n")

    self.data = _TrainingLoopData(exp_dir)
    self.data.setup()

    frozen_data = None

    with ExitStack() as stack:
        if cfg.pdb:
            stack.enter_context(pdb_postmortem())
            print("`pdb` is turned on, so forcing setting robust=False")
            cfg.robust = False

        for stream_name in ('stdout', 'stderr'):
            stack.enter_context(
                redirect_stream(stream_name, self.data.path_for(stream_name), tee=cfg.tee))

        print("\n\n" + "=" * 80)
        self.timestamp("Starting training run (name={})".format(self.exp_name))

        print("\nDirectory for this training run is {}.".format(exp_dir.path))

        stack.enter_context(NumpySeed(cfg.seed))

        print("\nSet numpy random seed to {}.\n".format(cfg.seed))

        limiter = time_limit(
            self.time_remaining, verbose=True,
            timeout_callback=lambda limiter: print("Training run exceeded its time limit."))

        self.mpi_context = MPI_MasterContext(cfg.get('n_procs', 1), exp_dir)

        try:
            with limiter:
                self._run()

        finally:
            # Summarise and freeze the data even when the run raised.
            self.data.summarize()

            self.timestamp("Done training run (name={})".format(self.exp_name))
            print("=" * 80)
            print("\n\n")

            frozen_data = self.data.freeze()

    self.timestamp("Leaving TrainingLoop.run")

    return frozen_data
def _build_graph(self):
    """Build the data pipeline, classifier outputs, loss, train op and evaluator.

    Sets `self.data_manager`, `self.tensors`, `self.recorded_tensors`,
    `self.loss` and `self.evaluator`. When `cfg.do_train` is set and
    `cfg.no_gradient` is not, it also sets `self.train_op` and
    `self.train_records`. An L2 penalty over trainable variables whose name
    contains 'weights' is added when `self.l2_weight` is positive.

    Raises
    ------
    AssertionError
        If the names recorded here collide with the network's recorded
        tensors or with its evaluation functions.
    """
    self.data_manager = DataManager(datasets=self.env.datasets)
    self.data_manager.build_graph()

    data = self.data_manager.iterator.get_next()

    network_inp = data[self.feature_name]

    # uint8 inputs are converted to float32 before reaching the network.
    if network_inp.dtype == tf.uint8:
        network_inp = tf.image.convert_image_dtype(network_inp, tf.float32)

    n_classes = self.env.datasets['train'].n_classes
    is_training = self.data_manager.is_training
    network_outputs = self.network(network_inp, data['label'], n_classes, is_training)

    network_tensors = network_outputs["tensors"]
    network_recorded_tensors = network_outputs["recorded_tensors"]
    network_losses = network_outputs["losses"]

    batch_size = tf.shape(network_inp)[0]
    float_is_training = tf.to_float(is_training)

    self.tensors = network_tensors
    self.tensors.update(data)
    self.tensors.update(
        network_inp=network_inp,
        is_training=is_training,
        float_is_training=float_is_training,
        batch_size=batch_size,
    )

    self.recorded_tensors = recorded_tensors = dict(
        global_step=tf.train.get_or_create_global_step(),
        batch_size=batch_size,
        is_training=float_is_training,
    )

    # Fetched once and shared by the L2 penalty and the train op.
    tvars = self.trainable_variables(for_opt=True)
    if self.l2_weight > 0.0:
        network_losses['l2'] = self.l2_weight * sum(
            tf.nn.l2_loss(v) for v in tvars if 'weights' in v.name)

    # --- loss ---

    # Total loss is the sum of every named loss; each term is also recorded
    # individually under `loss_<name>`.
    self.loss = tf.constant(0., tf.float32)
    for name, tensor in network_losses.items():
        self.loss += tensor
        recorded_tensors['loss_' + name] = tensor
    recorded_tensors['loss'] = self.loss

    # --- train op ---

    if cfg.do_train and not cfg.get('no_gradient', False):
        self.train_op, self.train_records = build_gradient_train_op(
            self.loss, tvars, self.optimizer_spec, self.lr_schedule,
            self.max_grad_norm, self.noise_schedule)

    sess = tf.get_default_session()
    for k, v in getattr(sess, 'scheduled_values', {}).items():
        # Avoid clobbering an existing record that has the same name.
        if k in recorded_tensors:
            recorded_tensors['scheduled_' + k] = v
        else:
            recorded_tensors[k] = v

    # --- recorded values ---

    intersection = recorded_tensors.keys() & network_recorded_tensors.keys()
    assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)
    recorded_tensors.update(network_recorded_tensors)

    # Normalise a falsy `eval_funcs` (e.g. None) to an empty dict *before*
    # it is dereferenced for the collision check.
    eval_funcs = self.network.eval_funcs or {}

    intersection = recorded_tensors.keys() & eval_funcs.keys()
    assert not intersection, "Key sets have non-zero intersection: {}".format(intersection)

    self.evaluator = Evaluator(eval_funcs, network_tensors, self)
def run_stage(mpi_context, env, stage_idx, exp_dir):
    """Set up one stage with a CPU-pinned graph and run the updater's worker code.

    Receives the stage config and seed from `mpi_context`, activates them,
    creates a graph and session, (re)builds the environment when needed,
    builds the updater and initialises its variables, then calls
    `updater.worker_code()`.

    Parameters
    ----------
    mpi_context:
        Supplies `(config, seed)` via `start_stage()` and is passed to the updater.
    env:
        Environment from the previous stage, or None.
    stage_idx:
        Index of the stage being run.
    exp_dir:
        Experiment directory object attached to the updater.

    Returns
    -------
    The environment used for this stage (reused or freshly built).
    """
    stage_cfg, numpy_seed = mpi_context.start_stage()

    with ExitStack() as stack:
        for context in (stage_cfg, NumpySeed(numpy_seed)):
            stack.enter_context(context)

        # Accept config for new stage
        print("\n" + "-" * 10 + " Stage set-up " + "-" * 10)
        print(cfg.to_string())

        # Configure and create session and graph for stage. The GPU options
        # are deliberately not applied here; the graph is always placed on
        # the CPU below.
        proto = tf.ConfigProto()
        for option in ('intra_op_parallelism_threads', 'inter_op_parallelism_threads'):
            setattr(proto, option, cfg.get(option, 0))

        stage_graph = tf.Graph()
        session = tf.Session(graph=stage_graph, config=proto)

        # This HAS to come after the creation of the session, otherwise
        # it allocates all GPU memory if using the GPU.
        print("\nAvailable devices:")
        from tensorflow.python.client import device_lib
        print(device_lib.list_local_devices())

        stack.enter_context(stage_graph.device("/cpu:0"))
        stack.enter_context(stage_graph.as_default())
        stack.enter_context(session)
        stack.enter_context(session.as_default())

        generated_seed = gen_seed()
        print("Setting tensorflow seed to generated seed: {}\n".format(generated_seed))
        tf.set_random_seed(generated_seed)

        # Set limit on CPU RAM for the stage
        if cfg.get("cpu_ram_limit_mb", None) is not None:
            stack.enter_context(memory_limit(cfg.cpu_ram_limit_mb))

        print("Building env...\n")

        # Rebuild the env on the first stage, or whenever it is not preserved.
        must_rebuild = stage_idx == 0 or not cfg.preserve_env
        if must_rebuild:
            if env is not None:
                env.close()
            env = cfg.build_env()

        if hasattr(env, "print_memory_footprint"):
            env.print_memory_footprint()

        print("\nDone building env.\n")
        print("Building updater...\n")

        worker = cfg.get_updater(env, mpi_context=mpi_context)
        worker.stage_idx = stage_idx
        worker.exp_dir = exp_dir

        worker.build_graph()
        print("\nDone building updater.\n")

        tf.train.get_or_create_global_step()
        session.run(uninitialized_variables_initializer())
        session.run(tf.assert_variables_initialized())

        worker.worker_code()

    # Local increment only; the caller's value is unaffected.
    stage_idx += 1

    return env