Example #1
    def get_load_paths(self):

        load_path = cfg.load_path
        _print("\nMaybe loading weights, load_path={} ...".format(load_path))

        if load_path:
            if isinstance(load_path, (str, int)):
                load_path = {"": load_path}

            load_path = dict(load_path)

            # Sort in increasing order, so that if one variable scope lies within another scope,
            # the outer scope gets loaded before the inner scope, rather than having the outer scope
            # wipe out the inner scope.
            items = sorted(load_path.items())
            return items

        else:
            _print("`load_path` is null, using a fresh set of weights.")
            return []
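
Note: `cfg.load_path` may be a bare string or int (shorthand for loading everything from a single source) or a dict mapping variable scopes to sources. A minimal sketch of the mapping, with hypothetical paths (`cfg` is the global config object of the surrounding dps codebase):

    # A bare string (or int) is wrapped as {"": value}: load everything from one source.
    cfg.load_path = "/tmp/weights/best.ckpt"
    # get_load_paths() -> [("", "/tmp/weights/best.ckpt")]

    # A dict maps variable scopes to sources; the sorted order guarantees the
    # outer scope "encoder" is restored before the inner scope "encoder/rnn",
    # so the inner, more specific weights are not wiped out.
    cfg.load_path = {
        "encoder/rnn": "/tmp/weights/rnn_finetuned.ckpt",
        "encoder": "/tmp/weights/encoder.ckpt",
    }
    # get_load_paths() -> [("encoder", ...), ("encoder/rnn", ...)]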
Example #2
File: train.py Project: lqiang2003cn/dps
    def framework_load_weights(self):
        for var_scope, path in self.get_load_paths():
            _print("Loading var scope \"{}\" from {}.".format(var_scope, path))

            start = time.time()
            variables = {
                v.name: v
                for v in trainable_variables(var_scope, for_opt=False)
            }
            if not variables:
                _print("No variables to load in scope {}.".format(
                    str(var_scope)))
                continue

            saver = tf.train.Saver(variables)
            saver.restore(tf.get_default_session(), path)

            _print("Done loading var scope, took {} seconds.".format(time.time() - start))
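
For reference, a minimal standalone sketch of the partial-restore pattern this method relies on (TF1-style API; the scope name and checkpoint path are hypothetical):

    import tensorflow as tf

    # Build a Saver over only the variables in one scope; restore() then
    # touches just those variables and leaves the rest of the graph alone.
    scope_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="encoder")
    saver = tf.train.Saver({v.name: v for v in scope_vars})
    saver.restore(tf.get_default_session(), "/tmp/weights/encoder.ckpt")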
Example #3
    def framework_load_weights(self):
        """
        Adapted from the tensorflow version; roughly treats a pytorch module as equivalent
        to a tensorflow variable scope.

        The most general form of a dictionary entry is: {"<dest_module_path>": "<source_module_path>:<file_path>"}
        This maps tensors located at module path `source_module_path` in file `file_path` to module path
        `dest_module_path` in the current model.

        """
        omit_modules = cfg.get('omit_modules_from_loading', [])

        for dest_module_path, path in self.get_load_paths():
            _print("Loading submodule \"{}\" from {}.".format(dest_module_path, path))

            if ":" in path:
                source_module_path, source_path = path.split(':', 1)
            else:
                source_path = path
                source_module_path = dest_module_path

            start = time.time()

            device = get_pytorch_device()

            loaded_state_dict = torch.load(source_path, map_location=device)['model']

            if source_module_path:
                source_module_path_with_sep = source_module_path + '.'

                loaded_state_dict = type(loaded_state_dict)(
                    {k: v for k, v in loaded_state_dict.items() if k.startswith(source_module_path_with_sep)}
                )

                assert loaded_state_dict, (
                    f"File contains no tensors with prefix `{source_module_path_with_sep}` (file: {source_path})"
                )

            if dest_module_path != source_module_path:
                # Rename variables from the loaded state dict by replacing `source_module_path` with `dest_module_path`.

                _source_module_path = source_module_path + '.' if source_module_path else source_module_path
                _dest_module_path = dest_module_path + '.' if dest_module_path else dest_module_path

                loaded_state_dict = {
                    k.replace(_source_module_path, _dest_module_path, 1): v
                    for k, v in loaded_state_dict.items()
                }

            module = self.updater.model

            state_dict = module.state_dict()

            intersection = set(state_dict.keys()) & set(loaded_state_dict.keys())

            if not intersection:
                raise Exception(
                    f"Loading variables with spec ({dest_module_path}, {path}) "
                    f"would have no effect (no variables found)."
                )
            loaded_state_dict = {k: loaded_state_dict[k] for k in intersection}

            if omit_modules:
                omitted_variables = {
                    k: v for k, v in loaded_state_dict.items()
                    if any(k.startswith(o) for o in omit_modules)
                }

                print("Omitting the following variables from loading:")
                describe_structure(omitted_variables)

                loaded_state_dict = {
                    k: v for k, v in loaded_state_dict.items()
                    if k not in omitted_variables
                }

            _print("Loading variables:")
            describe_structure(loaded_state_dict)

            state_dict.update(loaded_state_dict)

            module.load_state_dict(state_dict, strict=True)

            _print("Done loading weights for module {}, took {} seconds.".format(dest_module_path, time.time() - start))
Example #4
    def framework_initialize_stage(self, stack):
        # Set the seed for the stage.
        torch_seed = gen_seed()
        _print("Setting pytorch seed to generated seed: {}\n".format(torch_seed))
        torch.manual_seed(torch_seed)

        torch.backends.cudnn.enabled = True

        torch.backends.cudnn.benchmark = cfg.pytorch_cudnn_benchmark
        torch.backends.cudnn.deterministic = cfg.pytorch_cudnn_deterministic

        if cfg.use_gpu:
            _print("Trying to use GPU...")
            try:
                device = torch.cuda.current_device()
                use_gpu = True
            except AssertionError:
                tb.print_exc()
                use_gpu = False
        else:
            use_gpu = False

        if use_gpu:
            _print("Using GPU.")

            _print("Device count: {}".format(torch.cuda.device_count()))
            _print("Device idx: {}".format(device))
            _print("Device name: {}".format(torch.cuda.get_device_name(device)))
            _print("Device capability: {}".format(torch.cuda.get_device_capability(device)))

            set_pytorch_device('cuda')
        else:
            _print("Not using GPU.")
            set_pytorch_device('cpu')

        torch.set_printoptions(profile='full')
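
The two cudnn flags set above trade speed against reproducibility; a minimal sketch of the usual pairings (the values here are illustrative, not taken from the dps config):

    import torch

    # Reproducible runs: fixed algorithm selection, deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Fast runs: let cudnn benchmark and cache the fastest algorithm per input
    # shape (best when shapes are static; results may vary slightly across runs).
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False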
Example #5
File: train.py Project: lqiang2003cn/dps
    def framework_initialize_stage(self, stack):
        # Configure and create session and graph for stage.
        session_config = tf.ConfigProto()
        session_config.intra_op_parallelism_threads = cfg.get('intra_op_parallelism_threads', 0)
        session_config.inter_op_parallelism_threads = cfg.get('inter_op_parallelism_threads', 0)
        session_config.log_device_placement = cfg.get('log_device_placement', 0)

        if cfg.use_gpu:
            per_process_gpu_memory_fraction = getattr(cfg, 'per_process_gpu_memory_fraction', None)
            if per_process_gpu_memory_fraction:
                session_config.gpu_options.per_process_gpu_memory_fraction = per_process_gpu_memory_fraction

            gpu_allow_growth = getattr(cfg, 'gpu_allow_growth', None)
            if gpu_allow_growth:
                session_config.gpu_options.allow_growth = gpu_allow_growth

            _print("Using GPU if available.")
            _print("Using {}% of GPU memory.".format(
                100 *
                session_config.gpu_options.per_process_gpu_memory_fraction))
            _print("Allowing growth of GPU memory: {}".format(
                session_config.gpu_options.allow_growth))

        graph = tf.Graph()
        sess = tf.Session(graph=graph, config=session_config)

        # This HAS to come after the creation of the session, otherwise
        # it allocates all GPU memory if using the GPU.
        _print("\nAvailable devices: ")
        from tensorflow.python.client import device_lib
        _print(device_lib.list_local_devices())

        if not cfg.use_gpu:
            _print("Not using GPU.")
            stack.enter_context(graph.device("/cpu:0"))

        stack.enter_context(graph.as_default())
        stack.enter_context(sess)
        stack.enter_context(sess.as_default())

        # Set the seed for the stage.
        tf_seed = gen_seed()
        _print("Setting tensorflow seed to generated seed: {}\n".format(tf_seed))
        tf.set_random_seed(tf_seed)

        tf.logging.set_verbosity(tf.logging.ERROR)
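
For reference, a minimal sketch of the two GPU memory options this method reads from the config (TF1-style API; the fraction value is hypothetical):

    import tensorflow as tf

    config = tf.ConfigProto()
    # Cap this process at 40% of each visible GPU's memory up front...
    config.gpu_options.per_process_gpu_memory_fraction = 0.4
    # ...or instead start small and grow allocations on demand.
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)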