Example #1
def load_or_train(train_config, var_scope, path, target_var_scope=None, sess=None):
    """ Attempts to load variables into ``var_scope`` from checkpoint stored at ``path``.

    If said checkpoint is not found, trains a model using the function
    ``train`` and stores the resulting variables for future use.

    Returns True iff model was successfully loaded, False otherwise.

    If `target_var_scope` is not None, look for the variables under that scope name in the file
    that we load from, instead of `var_scope`.

    """
    sess = sess or tf.get_default_session()

    to_be_loaded = trainable_variables(var_scope, for_opt=False)
    if target_var_scope is not None:
        _tbl = {}
        for var in to_be_loaded:
            assert var.name.startswith(var_scope.name)
            bare_name = var.name[len(var_scope.name):]
            while bare_name.startswith('/'):
                bare_name = bare_name[1:]
            name_in_file = target_var_scope + '/' + bare_name
            _tbl[name_in_file] = var
        to_be_loaded = _tbl
    else:
        to_be_loaded = {v.name: v for v in to_be_loaded}

    saver = tf.train.Saver(to_be_loaded)

    if path is not None:
        os.makedirs(os.path.dirname(path), exist_ok=True)

    success = False
    try:
        saver.restore(sess, path)
        success = True
    except tf.errors.NotFoundError:
        with ExitStack() as stack:
            stack.enter_context(ClearConfig())
            stack.enter_context(train_config.copy(save_path=path))

            output = training_loop(var_scope.name)

            stem = os.path.splitext(path)[0]
            shutil.copyfile(output.path_for('stdout'), stem + '.stdout')
            shutil.copyfile(output.path_for('stderr'), stem + '.stderr')

        saver.restore(sess, path)
    return success
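A minimal usage sketch for ``load_or_train``, assuming a TF1-style graph and session plus the surrounding framework's ``Config`` objects; ``build_encoder``, ``pretrain_config``, and the checkpoint path below are hypothetical placeholders:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope("encoder") as encoder_scope:
        build_encoder()  # hypothetical: creates the variables that live under "encoder"

    sess = tf.Session(graph=graph)

    # Restore pre-trained weights if a checkpoint exists at `path`; otherwise train
    # under `pretrain_config` and cache the resulting weights at `path` for next time.
    loaded = load_or_train(
        train_config=pretrain_config,  # hypothetical Config carrying the training settings
        var_scope=encoder_scope,
        path="/tmp/checkpoints/encoder_pretrain.chk",
        sess=sess,
    )
    if not loaded:
        print("No existing checkpoint; the model was trained and saved instead.")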
Example #2
    def framework_load_weights(self):
        for var_scope, path in self.get_load_paths():
            _print("Loading var scope \"{}\" from {}.".format(var_scope, path))

            start = time.time()
            variables = {
                v.name: v
                for v in trainable_variables(var_scope, for_opt=False)
            }
            if not variables:
                _print("No variables to load in scope {}.".format(
                    str(var_scope)))
                continue

            saver = tf.train.Saver(variables)
            saver.restore(tf.get_default_session(), path)

            _print("Done loading var scope, took {} seconds.".format(time.time() - start))
Example #3
    def trainable_variables(self, for_opt):
        return trainable_variables(self.controller.scope.name, for_opt=for_opt)
Example #4
    def _run(self):
        print(cfg.to_string())

        threshold_reached = True
        self.global_step = 0
        self.n_global_experiences = 0
        self.curriculum_remaining = self.curriculum + []
        self.curriculum_complete = []

        stage_idx = 0
        while self.curriculum_remaining:
            print("\n" + "=" * 50)
            self.timestamp("Starting stage {}".format(stage_idx))
            print("\n")

            if cfg.start_tensorboard:
                restart_tensorboard(self.experiment_store.path, cfg.tbport, cfg.reload_interval)

            stage_config = self.curriculum_remaining.pop(0)
            stage_config = Config(stage_config)

            self.data.start_stage(stage_idx, stage_config)

            with ExitStack() as stack:

                # --------------- Stage set-up -------------------

                print("\n" + "-" * 10 + " Stage set-up " + "-" * 10)

                print("\nNew config values for this stage are: \n{}\n".format(pformat(stage_config)))
                stack.enter_context(stage_config)

                stage_prepare_func = cfg.get("stage_prepare_func", None)
                if callable(stage_prepare_func):
                    stage_prepare_func()  # Modify the stage config in arbitrary ways before starting stage

                self.mpi_context.start_stage()

                # Configure and create session and graph for stage.
                session_config = tf.ConfigProto()
                session_config.intra_op_parallelism_threads = cfg.get('intra_op_parallelism_threads', 0)
                session_config.inter_op_parallelism_threads = cfg.get('inter_op_parallelism_threads', 0)
                session_config.log_device_placement = cfg.get('log_device_placement', 0)

                if cfg.use_gpu:
                    per_process_gpu_memory_fraction = getattr(cfg, 'per_process_gpu_memory_fraction', None)
                    if per_process_gpu_memory_fraction:
                        session_config.gpu_options.per_process_gpu_memory_fraction = per_process_gpu_memory_fraction

                    gpu_allow_growth = getattr(cfg, 'gpu_allow_growth', None)
                    if gpu_allow_growth:
                        session_config.gpu_options.allow_growth = gpu_allow_growth

                if cfg.use_gpu:
                    print("Using GPU if available.")
                    print("Using {}% of GPU memory.".format(
                        100 * session_config.gpu_options.per_process_gpu_memory_fraction))
                    print("Allowing growth of GPU memory: {}".format(session_config.gpu_options.allow_growth))

                graph = tf.Graph()
                sess = tf.Session(graph=graph, config=session_config)

                # This HAS to come after the creation of the session, otherwise
                # it allocates all GPU memory if using the GPU.
                print("\nAvailable devices: ")
                from tensorflow.python.client import device_lib
                print(device_lib.list_local_devices())

                if not cfg.use_gpu:
                    print("Not using GPU.")
                    stack.enter_context(graph.device("/cpu:0"))

                stack.enter_context(graph.as_default())
                stack.enter_context(sess)
                stack.enter_context(sess.as_default())

                # Set the seed for the stage. Notice we generate a new tf seed for each stage.
                tf_seed = gen_seed()
                print("Setting tensorflow seed to generated seed: {}\n".format(tf_seed))
                tf.set_random_seed(tf_seed)

                # Set limit on CPU RAM for the stage
                cpu_ram_limit_mb = cfg.get("cpu_ram_limit_mb", None)
                if cpu_ram_limit_mb is not None:
                    stack.enter_context(memory_limit(cfg.cpu_ram_limit_mb))

                print("Building env...\n")

                # Maybe build env
                if stage_idx == 0 or not cfg.preserve_env:
                    if getattr(self, 'env', None):
                        self.env.close()

                    self.env = cfg.build_env()

                if hasattr(self.env, "print_memory_footprint"):
                    self.env.print_memory_footprint()

                print("\nDone building env.\n")
                print("Building updater...\n")

                import warnings
                with warnings.catch_warnings():
                    warnings.simplefilter('once')

                    if cfg.n_procs > 1:
                        updater = cfg.get_updater(self.env, mpi_context=self.mpi_context)
                    else:
                        updater = cfg.get_updater(self.env)

                    updater.stage_idx = stage_idx
                    updater.exp_dir = self.exp_dir

                    updater.build_graph()
                    print("\nDone building updater.\n")

                walk_variable_scopes(max_depth=3)

                # Maybe initialize network weights.
                # Let a *path_specification* be one of three things:
                #     1. An integer specifying a stage to load the best hypothesis from.
                #     2. A string of format: "stage_idx,kind" where `stage_idx` specifies a stage to load from
                #        and `kind` is either "final" or "best", specifying whether to load final or best
                #        hypothesis from that stage.
                #     3. A path on the filesystem that gives a prefix for a tensorflow checkpoint file to load from.
                #
                # Then cfg.load_path can either be a path_specification itself, in which case all variables
                # in the network will be loaded from that path_specification, or a dictionary mapping from
                # variable scope names to path specifications, in which case all variables in each supplied
                # variable scope name will be loaded from the path_specification paired with that scope name.
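                # A hypothetical illustration (not part of the original code) of the
                # accepted forms described above:
                #     cfg.load_path = 2                           # best hypothesis from stage 2
                #     cfg.load_path = "2,final"                   # final hypothesis from stage 2
                #     cfg.load_path = "/tmp/ckpts/my_model"       # checkpoint file prefix
                #     cfg.load_path = {"encoder": 2,              # per-scope specifications
                #                      "decoder": "/tmp/ckpts/decoder"}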
                load_path = cfg.load_path
                if load_path is not None:
                    if isinstance(load_path, (str, int)):
                        load_path = {"": load_path}

                    load_path = dict(load_path)

                    # Sort in increasing order, so that if one variable scope lies within another
                    # scope, the outer scope gets loaded before the inner scope, rather than
                    # having the outer scope wipe out the inner scope.
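                    # (Hypothetical example: sorted() places "network" before "network/encoder",
                    # so the weights for "network/encoder" are restored after, i.e. on top of,
                    # those restored for "network".)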
                    items = sorted(load_path.items())

                    for var_scope, path in items:
                        variables = {v.name: v for v in trainable_variables(var_scope, for_opt=False)}
                        if not variables:
                            print("No variables to load in scope {}.".format(str(var_scope)))
                            continue

                        saver = tf.train.Saver(variables)

                        load_stage, kind = None, None

                        if isinstance(path, int):
                            load_stage = path
                            kind = "best"
                        elif isinstance(path, str):
                            try:
                                split = path.split(',')
                                load_stage = int(split[0])
                                kind = split[1] if len(split) > 1 else 'best'
                                assert kind in 'best final'.split(), "path={}".format(path)
                            except Exception:
                                load_stage, kind = None, None

                        if load_stage is not None:
                            if stage_idx == 0:
                                print(
                                    "Not loading var scope \"{}\" from stage {}, "
                                    "currently in stage 0.".format(var_scope, load_stage))
                                continue
                            else:
                                key = kind + '_path'
                                completed_history = self.data.history[:-1]
                                path = completed_history[load_stage][key]

                        path = os.path.realpath(path)

                        print("Loading var scope \"{}\" from {}.".format(var_scope, path))
                        saver.restore(tf.get_default_session(), path)
                else:
                    print("Using a fresh set of weights, not loading anything.")

                tf.train.get_or_create_global_step()
                sess.run(uninitialized_variables_initializer())
                sess.run(tf.assert_variables_initialized())

                for hook in cfg.hooks:
                    assert isinstance(hook, Hook)
                    hook.start_stage(self, updater, stage_idx)

                threshold_reached = False
                reason = None

                try:
                    # --------------- Run stage -------------------

                    start = time.time()
                    phys_memory_before = memory_usage(physical=True)
                    gpu_memory_before = gpu_memory_usage()

                    threshold_reached, reason = self._run_stage(stage_idx, updater)

                except KeyboardInterrupt:
                    reason = "User interrupt"

                except NotImplementedError as e:
                    # There is a bug in pdb_postmortem that prevents instances of `NotImplementedError`
                    # from being handled properly, so replace it with an instance of `Exception`.
                    if cfg.robust:
                        traceback.print_exc()
                        reason = "Exception occurred ({})".format(repr(e))
                    else:
                        raise Exception("NotImplemented") from e

                except Exception as e:
                    reason = "Exception occurred ({})".format(repr(e))
                    if cfg.robust:
                        traceback.print_exc()
                    else:
                        raise

                except Alarm:
                    reason = "Time limit exceeded"
                    raise

                finally:
                    phys_memory_after = memory_usage(physical=True)
                    gpu_memory_after = gpu_memory_usage()

                    self.data.record_values_for_stage(
                        stage_duration=time.time()-start,
                        phys_memory_before_mb=phys_memory_before,
                        phys_memory_delta_mb=phys_memory_after - phys_memory_before,
                        gpu_memory_before_mb=gpu_memory_before,
                        gpu_memory_delta_mb=gpu_memory_after - gpu_memory_before
                    )

                    self.data.record_values_for_stage(reason=reason)

                    print("\n" + "-" * 10 + " Optimization complete " + "-" * 10)
                    print("\nReason: {}.\n".format(reason))

                    final_path = self.data.path_for('weights/final_for_stage_{}'.format(stage_idx))
                    final_path = cfg.get('save_path', final_path)
                    final_path = updater.save(tf.get_default_session(), final_path)
                    self.data.record_values_for_stage(final_path=final_path)

                    # --------------- Maybe render performance of best hypothesis -------------------

                    do_final_testing = (
                        "Exception occurred" not in reason
                        and reason != "Time limit exceeded"
                        and 'best_path' in self.data.current_stage_record)

                    if do_final_testing:
                        try:
                            print("\n" + "-" * 10 + " Final testing/rendering " + "-" * 10)

                            print("Best hypothesis for this stage was found on "
                                  "step (l: {best_local_step}, g: {best_global_step}) "
                                  "with stopping criteria ({sc_name}) of {best_stopping_criteria}.".format(
                                      sc_name=self.stopping_criteria_name, **self.data.current_stage_record))

                            best_path = self.data.current_stage_record['best_path']
                            print("Loading best hypothesis for this stage "
                                  "from file {}...".format(best_path))
                            updater.restore(sess, best_path)

                            test_record = updater.evaluate(cfg.batch_size, mode="test")

                            for hook in cfg.hooks:
                                if hook.call_per_timestep and hook.final:
                                    hook_record = hook.step(self, updater)

                                    if hook_record:
                                        assert len(hook_record) == 1
                                        for k, d in dict(hook_record).items():
                                            test_record.update(d)

                            self.data.record_values_for_stage(
                                **{'_test_' + k: v for k, v in test_record.items()})

                            if cfg.render_step > 0 and cfg.render_hook is not None:
                                print("Rendering...")
                                cfg.render_hook(updater)
                                print("Done rendering.")

                        except BaseException:
                            print("Exception occurred while performing final testing/rendering: ")
                            traceback.print_exc()

                    else:
                        print("\n" + "-" * 10 + " Skipping final testing/rendering " + "-" * 10)

                    # --------------- Finish up the stage -------------------

                    self.data.end_stage(updater.n_updates)

                    print("\n" + "-" * 10 + " Running end-of-stage hooks " + "-" * 10 + "\n")
                    for hook in cfg.hooks:
                        hook.end_stage(self, stage_idx)

                    print()
                    self.timestamp("Done stage {}".format(stage_idx))
                    print("=" * 50)

                    stage_idx += 1
                    self.curriculum_complete.append(stage_config)

                if not (threshold_reached or cfg.power_through):
                    print("Failed to reach stopping criteria threshold on stage {} "
                          "of the curriculum, terminating.".format(stage_idx))
                    break
Example #5
    def trainable_variables(self, for_opt):
        return trainable_variables(self.f.scope, for_opt=for_opt)