Example No. 1
def _train_epoch(
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    ddpmodel: ModelWrapperForDDP,
    model_path: Path,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    global _train_epoch_waiting_time
    pre_num_add = assembler.buffer_num_add()
    pre_num_sample = assembler.buffer_num_sample()
    sync_s = 0.
    num_sync = 0
    t = time.time()
    time.sleep(_train_epoch_waiting_time)
    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model
    for eid in range(optim_params.epoch_len):
        batch = assembler.sample(optim_params.batchsize)
        batch = utils.to_device(batch, train_device)
        loss = model.loss(lossmodel, batch["s"], batch["v"], batch["pi"],
                          batch["pi_mask"], stat)
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(),
                                             optim_params.grad_clip)
        optim.step()
        optim.zero_grad()

        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            assembler.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1

        stat["loss"].feed(loss.detach().item())
        stat["grad_norm"].feed(grad_norm)

    post_num_add = assembler.buffer_num_add()
    post_num_sample = assembler.buffer_num_sample()
    time_elapsed = time.time() - t
    delta_add = post_num_add - pre_num_add
    print("buffer add rate: %.2f / s" % (delta_add / time_elapsed))
    delta_sample = post_num_sample - pre_num_sample
    if delta_sample > 8 * delta_add:  # Sampling outpaced adding by more than 8x: throttle training by waiting longer next epoch.
        _train_epoch_waiting_time += time_elapsed
    else:
        _train_epoch_waiting_time = 0
    print("buffer sample rate: %.2f / s" % (delta_sample / time_elapsed))
    print(
        f"syncing duration: {sync_s:2f}s for {num_sync} syncs ({int(100 * sync_s / time_elapsed)}% of train time)"
    )

    stat.summary(epoch)
    stat.reset()
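
The waiting-time rule above throttles training whenever the replay buffer is sampled much faster than it is filled. A minimal sketch of that rule in isolation (the function and argument names here are illustrative, not part of the original code):

def next_waiting_time(current_wait: float, delta_add: int, delta_sample: int,
                      epoch_seconds: float) -> float:
    # Samples were drawn more than 8x faster than new positions were added:
    # sleep longer before the next epoch; otherwise clear the throttle.
    if delta_sample > 8 * delta_add:
        return current_wait + epoch_seconds
    return 0.0

print(next_waiting_time(0.0, delta_add=100, delta_sample=1000, epoch_seconds=30.0))  # 30.0
print(next_waiting_time(30.0, delta_add=100, delta_sample=500, epoch_seconds=30.0))  # 0.0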
Example No. 2
    def _save_load_mobile_module(self, script_module: torch.jit.ScriptModule):
        buffer = io.BytesIO(
            script_module._save_to_buffer_for_lite_interpreter(
                _save_mobile_debug_info=True))
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)
        return mobile_module
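
A standalone usage sketch of the same round trip, assuming the private lite-interpreter helpers used above (_save_to_buffer_for_lite_interpreter and _load_for_lite_interpreter from torch.jit.mobile) are available in this PyTorch build:

import io

import torch
from torch.jit.mobile import _load_for_lite_interpreter


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x + 1


scripted = torch.jit.script(TinyModel())
# Serialize for the lite interpreter into an in-memory buffer, then load it back.
buffer = io.BytesIO(scripted._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
print(mobile_module(torch.zeros(2)))  # tensor([1., 1.])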
Example No. 3
def _get_bundled_inputs_preserved_attributes(
        script_module: torch.jit.ScriptModule,
        preserved_methods: List[str]) -> List[str]:

    bundled_inputs_attributes = []
    # Has bundled inputs for forward
    if hasattr(script_module, 'get_all_bundled_inputs'):
        bundled_inputs_attributes.append('get_all_bundled_inputs')
        bundled_inputs_attributes.append('get_num_bundled_inputs')
        bundled_inputs_attributes.append('run_on_bundled_input')

    # Bundled inputs in module after the change that introduced bundled inputs for multiple functions
    if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
        bundled_inputs_attributes.append(
            'get_bundled_inputs_functions_and_info')
        all_info = script_module.get_bundled_inputs_functions_and_info()
        for function_name in all_info:
            if function_name not in preserved_methods:
                bundled_inputs_attributes.append(function_name)
            bundled_inputs_attributes.append("get_all_bundled_inputs_for_" +
                                             function_name)
            bundled_inputs_attributes.append("_bundled_inputs_deflated_" +
                                             function_name)

    return bundled_inputs_attributes
Example No. 4
def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
    """
    Args:
        script_module: A TorchScript module (an instance of torch.jit.ScriptModule)

    Returns:
        lint_list: A list of dictionaries, each describing a module lint
    """
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise TypeError(
            'Got {}, but ScriptModule is expected.'.format(type(script_module)))

    lint_list = []

    if not hasattr(script_module, "_generate_bundled_inputs"):
        lint_list.append({"name": LintCode.BUNDLED_INPUT.name, "message": "No bundled input, please add bundled inputs before "
                          "saving the module using torch.utils.bundled_inputs.augment_model_with_bundled_inputs."})

    for name, param in script_module.named_parameters():
        if param.requires_grad:
            lint_list.append({"name": LintCode.REQUIRES_GRAD.name, "message": "Param {} requires grad, "
                             "please set torch.no_grad() to reduce memory usage and improve computation speed during "
                              "inference phase.".format(name)})

    op_names = torch.jit.export_opnames(script_module)
    for op_name in op_names:
        if "dropout" in op_name:
            lint_list.append({"name": LintCode.DROPOUT.name, "message": "Operator {} exists, remember to call eval() before "
                              "saving the module.".format(op_name)})
        if "batch_norm" in op_name:
            lint_list.append({"name": LintCode.BATCHNORM.name, "message": "Operator {} exists, remember to call eval() before "
                              "saving the module and call torch.utils.mobile_optimizer.optimize_for_mobile to drop batch_norm "
                              "operator.".format(op_name)})

    return lint_list
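
A brief usage sketch, assuming this is the generate_mobile_module_lints shipped in torch.utils.mobile_optimizer:

import torch
from torch.utils.mobile_optimizer import generate_mobile_module_lints

# A module with trainable parameters, a dropout op, and no bundled inputs
# should trigger BUNDLED_INPUT, REQUIRES_GRAD, and DROPOUT lints.
model = torch.jit.script(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout()))
for lint in generate_mobile_module_lints(model):
    print(lint["name"], "-", lint["message"])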
Example No. 5
def local_energy(
    state: torch.jit.ScriptModule,
    hamiltonian: _C.Heisenberg,
    spins: np.ndarray,
    log_values: Optional[np.ndarray] = None,
    batch_size: int = 128,
) -> np.ndarray:
    r"""Computes local estimators ⟨σ|H|ψ⟩/⟨σ|ψ⟩ for all σ.

    :param state: wavefunction ``ψ``. ``state`` should be a function
        mapping ``R^{batch_size x in_features}`` to ``R^{batch_size x 2}``.
        Columns of the output are interpreted as real and imaginary parts of
        ``log(⟨σ|ψ⟩)``.
    :param hamiltonian: Hamiltonian ``H``.
    :param spins: spin configurations ``σ``. Should be a non-empty NumPy array
        of :py:class:`CompactSpin`.
    :param log_values: pre-computed ``log(⟨σ|ψ⟩)``. Should be a NumPy array of
        ``complex64``.
    :param batch_size: batch size to use for forward propagation through
        ``state``.

    :return: local energies ⟨σ|H|ψ⟩/⟨σ|ψ⟩ as a NumPy array of ``complex64``.
    """
    with torch.no_grad():
        with torch.jit.optimized_execution(True):
            if log_values is None:
                log_values = _forward_with_batches(state, spins, batch_size)
                log_values = log_values.numpy().view(np.complex64)

            # Since torch.jit.ScriptModules can't be directly passed to C++
            # code as torch::jit::script::Modules, we first save ψ to a
            # temporary file and then load it back in C++ code.
            with tempfile.NamedTemporaryFile(delete=False) as f:
                filename = f.name
            try:
                state.save(filename)
                log_H_values = (_C.PolynomialState(
                    _C.Polynomial(hamiltonian, [0.0]),
                    filename,
                    (batch_size, len(_C.unsafe_get(spins, 0))),
                )(spins).numpy().view(np.complex64))
            finally:
                os.remove(filename)
            return np.exp(log_H_values - log_values).squeeze(axis=1)
Example No. 6
def create_optimizer(
    model: torch.jit.ScriptModule,
    optim_params: OptimParams,
    optim_state_dict: Optional[dict] = None,
) -> torch.optim.Optimizer:
    optim = torch.optim.Adam(
        model.parameters(), lr=optim_params.lr, eps=optim_params.eps
    )
    if optim_state_dict is not None:
        optim.load_state_dict(optim_state_dict)
    return optim
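
A usage sketch for the helper above; OptimParams here is a minimal stand-in dataclass (the project's real definition is not shown), and create_optimizer is assumed to be in scope with its imports:

from dataclasses import dataclass

import torch


@dataclass
class OptimParams:
    lr: float = 1e-3
    eps: float = 1e-8


model = torch.jit.script(torch.nn.Linear(4, 2))
optim = create_optimizer(model, OptimParams())
print(type(optim).__name__)  # Adam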
Example No. 7
def save_checkpoint(
    command_history: CommandHistory,
    epoch: int,
    model: torch.jit.ScriptModule,
    optim: torch.optim.Optimizer,
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    executor: Optional[ThreadPoolExecutor] = None,
) -> Optional[Future]:
    checkpoint_dir = execution_params.checkpoint_dir
    save_uncompressed = execution_params.save_uncompressed
    checkpoint_name = f"checkpoint_{epoch}"
    checkpoint = {
        "command_history": command_history,
        "epoch": epoch,
        "model_state_dict": {
            k: v.cpu().clone()
            if isinstance(v, torch.Tensor) else copy.deepcopy(v)
            for k, v in model.state_dict().items()
        },
        "optim_state_dict": {
            k: v.cpu().clone()
            if isinstance(v, torch.Tensor) else copy.deepcopy(v)
            for k, v in optim.state_dict().items()
        },
        "game_params": game_params,
        "model_params": model_params,
        "optim_params": optim_params,
        "simulation_params": simulation_params,
        "execution_params": execution_params,
    }

    def saveit():
        nonlocal save_uncompressed
        nonlocal checkpoint
        nonlocal checkpoint_dir
        if save_uncompressed:
            torch.save(checkpoint, checkpoint_dir / f"{checkpoint_name}.pt")
        else:
            # with zipfile.ZipFile(Path(checkpoint_dir) / f"{checkpoint_name}.zip", "w", allowZip64=True) as z:
            #    with z.open(f"{checkpoint_name}.pt", "w", force_zip64=True) as f:
            #        torch.save(checkpoint, f)
            with gzip.open(checkpoint_dir / f"{checkpoint_name}.pt.gz",
                           "wb") as f:
                torch.save(checkpoint, f)

    if executor is not None:
        return executor.submit(saveit)
    else:
        saveit()
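
The matching load path is not shown in the snippet; a hedged sketch of one (the helper name load_checkpoint is ours, not the project's):

import gzip
from pathlib import Path

import torch


def load_checkpoint(checkpoint_dir: Path, epoch: int) -> dict:
    # Read back a checkpoint written by save_checkpoint above (compressed variant).
    with gzip.open(checkpoint_dir / f"checkpoint_{epoch}.pt.gz", "rb") as f:
        return torch.load(f, map_location="cpu")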
Example No. 8
    def bufferize(worker: AbstractWorker, script_module: torch.jit.ScriptModule) -> ScriptModulePB:
        """
            This method serializes a torch.jit.ScriptModule using ScriptModulePB.

            Args:
                script_module (torch.jit.ScriptModule): input jit.ScriptModule to be serialized.

            Returns:
                protobuf_script (ScriptModulePB): serialized jit.ScriptModule.
        """
        protobuf_script = ScriptModulePB()
        protobuf_script.obj = script_module.save_to_buffer()
        return protobuf_script
Example No. 9
def create_optimizer(
    model: torch.jit.ScriptModule,
    optim_params: OptimParams,
    optim_state_dict: Optional[dict] = None,
) -> torch.optim.Optimizer:
    optim = torch.optim.Adam(model.parameters(),
                             lr=optim_params.lr,
                             eps=optim_params.eps)
    if optim_state_dict is not None and not optim_params.reset_optimizer_state:
        try:
            optim.load_state_dict(optim_state_dict)
        except ValueError:
            print("Optimizer state not compatible... skipping.")
    return optim
Example No. 10
def _get_bundled_inputs_attributes_and_methods(
        script_module: torch.jit.ScriptModule) -> Tuple[List[str], List[str]]:
    methods: List[str] = []
    attributes: List[str] = []

    # Has bundled inputs for forward
    if hasattr(script_module, 'get_all_bundled_inputs'):
        methods.append('get_all_bundled_inputs')
        methods.append('get_num_bundled_inputs')
        methods.append('run_on_bundled_input')

    if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
        methods.append('get_bundled_inputs_functions_and_info')
        all_info = script_module.get_bundled_inputs_functions_and_info()
        for function_name in all_info:
            methods.append("get_all_bundled_inputs_for_" + function_name)
            methods.append("_generate_bundled_inputs_for_" + function_name)
            attributes.append("_bundled_inputs_deflated_" + function_name)
    return (methods, attributes)
Example No. 11
def _get_bundled_inputs_preserved_attributes(script_module: torch.jit.ScriptModule, preserved_methods: List[str]) -> List[str]:

    # Technically, if a module only bundles inputs for functions other than forward, these won't exist.
    # We haven't seen a valid use case for that yet, so it isn't handled here.
    bundled_inputs_attributes = [
        'get_all_bundled_inputs',
        'get_num_bundled_inputs',
        'run_on_bundled_input',
    ]
    if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
        bundled_inputs_attributes.append('get_bundled_inputs_functions_and_info')
        all_info = script_module.get_bundled_inputs_functions_and_info()
        for function_name in all_info:
            if function_name not in preserved_methods:
                bundled_inputs_attributes.append(function_name)
            bundled_inputs_attributes.append("get_all_bundled_inputs_for_" + function_name)
            bundled_inputs_attributes.append("_bundled_inputs_deflated_" + function_name)

    return bundled_inputs_attributes
Example No. 12
def save_checkpoint(
    command_history: CommandHistory,
    epoch: int,
    model: torch.jit.ScriptModule,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
) -> None:
    checkpoint_dir = execution_params.checkpoint_dir
    save_uncompressed = execution_params.save_uncompressed
    do_not_save_replay_buffer = execution_params.do_not_save_replay_buffer
    checkpoint_name = f"checkpoint_{epoch}"
    checkpoint = {
        "command_history": command_history,
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optim_state_dict": optim.state_dict(),
        "game_params": game_params,
        "model_params": model_params,
        "optim_params": optim_params,
        "simulation_params": simulation_params,
        "execution_params": execution_params,
    }
    if not do_not_save_replay_buffer:
        checkpoint.update({"replay_buffer": assembler.buffer})

    if save_uncompressed:
        torch.save(checkpoint, checkpoint_dir / f"{checkpoint_name}.pt")
    else:
        # with zipfile.ZipFile(Path(checkpoint_dir) / f"{checkpoint_name}.zip", "w", allowZip64=True) as z:
        #    with z.open(f"{checkpoint_name}.pt", "w", force_zip64=True) as f:
        #        torch.save(checkpoint, f)
        with gzip.open(checkpoint_dir / f"temp_{checkpoint_name}.pt.gz",
                       "wb") as f:
            torch.save(checkpoint, f)
        os.rename(checkpoint_dir / f"temp_{checkpoint_name}.pt.gz",
                  checkpoint_dir / f"{checkpoint_name}.pt.gz")
Example No. 13
def _get_bundled_inputs_attributes_and_methods(script_module: torch.jit.ScriptModule) -> Tuple[List[str], List[str]]:
    methods: List[str] = []
    attributes: List[str] = []

    # Has bundled inputs for forward
    if hasattr(script_module, 'get_all_bundled_inputs'):
        methods.append('get_all_bundled_inputs')
        methods.append('get_num_bundled_inputs')
        methods.append('run_on_bundled_input')

    if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
        methods.append('get_bundled_inputs_functions_and_info')
        all_info = script_module.get_bundled_inputs_functions_and_info()
        for function_name in all_info:
            methods.append("get_all_bundled_inputs_for_" + function_name)
            methods.append("_generate_bundled_inputs_for_" + function_name)
            attributes.append("_bundled_inputs_deflated_" + function_name)

            bundled_inputs_fn = getattr(
                script_module,
                f"get_all_bundled_inputs_for_{function_name}"
            )
            num_bundled_inputs: int = len(bundled_inputs_fn())

            # Check inflate helper functions for each function, argument and bundled input
            func = getattr(script_module, function_name, None)
            for arg_idx in range(len(func.schema.arguments) - 1):
                for input_idx in range(num_bundled_inputs):
                    helper_fn_name = _get_inflate_helper_fn_name(
                        arg_idx=arg_idx,
                        input_idx=input_idx,
                        function_name=function_name
                    )
                    # if the arg has an InflatableArg with fmt_fn, add the helper function name
                    if hasattr(script_module, helper_fn_name):
                        methods.append(helper_fn_name)

    return (methods, attributes)
Example No. 14
def augment_many_model_functions_with_bundled_inputs(
    model: torch.jit.ScriptModule,
    inputs: Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]],
    _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
    info: Optional[Dict[Callable, List[
        str]]] = None,  # Optional argument to provide info about the function or its inputs
    skip_size_check=False,
) -> None:
    """Add bundled sample inputs to a model for an arbitrary list of public functions.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str: List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then these following functions are also defined:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_<function_name>`.
        If the user chooses this method, inputs[<function>] should map to None.

      - The `inputs` argument to this function can be a dictionary mapping functions to a
        list of inputs, of the same form that will be returned by get_all_bundled_inputs_for_<function_name>.
        The type of the inputs is List[Tuple[Any, ...]]. The outer list holds the individual
        inputs; each inner tuple holds the args that together make up one input.
        For inputs of functions that take one arg, this will be a tuple of length one. The Any, ...
        is the actual data that makes up the args, e.g. a tensor.

    Info is an optional parameter that maps functions to a list of strings providing extra information about that
    function's bundled inputs. This could be descriptions, expected outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    if not inputs:
        raise Exception("Please provide inputs for at least 1 function")

    if hasattr(model, "get_all_bundled_inputs") or hasattr(
            model, "get_bundled_inputs_functions_and_info"):
        raise Exception(
            "Models can only be augmented with bundled inputs once. "
            "This Model seems to have already been augmented with "
            "bundled inputs. Please start afresh with one that "
            "doesn't have bundled inputs.", )

    get_bundled_inputs_functions_and_info_template = ""

    for function, input_list in inputs.items():
        if hasattr(function, "__name__"):
            function_name = function.__name__
        else:
            if hasattr(function, "name"):
                function_name = function.name  # type: ignore[attr-defined]
            else:
                raise Exception(
                    'At least one of your functions has no "name" attribute; please ensure all have one, e.g. m.foo.name = "foo"'
                )

        if input_list is not None and not isinstance(input_list, Sequence):
            raise TypeError(
                "Error inputs for function {0} is not a Sequence".format(
                    function_name))

        function_arg_types = [
            arg.type for arg in function.schema.arguments[1:]
        ]  # type: ignore[attr-defined]
        deflated_inputs_type: ListType = ListType(
            TupleType(function_arg_types))
        model._c._register_attribute(
            "_bundled_inputs_deflated_{name}".format(name=function_name),
            deflated_inputs_type, [])

        if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            if input_list is not None:
                raise Exception(
                    "inputs[{name}] is not None, but _generate_bundled_inputs_for_{name} is already defined"
                    .format(name=function_name))
            # Model author already defined _generate_bundled_inputs_for_<function_name>.
        elif input_list is None or len(input_list) == 0:
            raise Exception(
                "inputs for {name} must be specified if _generate_bundled_inputs_for_{name} is not already defined"
                .format(name=function_name, ))
        else:
            # Iterate over the inputs and args in each input.
            # Accumulate `deflated_inputs` as (possibly) compressed values
            # and `parts` to be joined into the expression that unpacks them.
            deflated_inputs = []
            parts = []
            for inp_idx, args in enumerate(input_list):
                if not isinstance(args, Tuple) and not isinstance(
                        args, List):  # type: ignore[arg-type]
                    raise TypeError(
                        "Error bundled input for function {0} idx: {1} is not a Tuple or a List"
                        .format(function_name, inp_idx))
                deflated_args = []
                parts.append("(")
                for arg_idx, arg in enumerate(args):
                    inflate_helper_fn_name = _get_inflate_helper_fn_name(
                        arg_idx, inp_idx, function_name)
                    deflated, inflater, helper_definition = _inflate_expr(
                        arg,
                        f"deflated[{inp_idx}][{arg_idx}]",
                        inflate_helper_fn_name,
                        skip_size_check=skip_size_check,
                    )
                    deflated_args.append(deflated)
                    parts.append(f"    {inflater},")
                    if helper_definition:
                        model.define(textwrap.dedent(helper_definition))
                deflated_inputs.append(tuple(deflated_args))
                parts.append("),")
            parts.append("")
            expr = "\n".join(parts)

            # Back-channel return this expr for debugging.
            if _receive_inflate_expr is not None:
                _receive_inflate_expr.append(expr)
            setattr(
                model,
                "_bundled_inputs_deflated_{name}".format(name=function_name),
                deflated_inputs)
            definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_{name}(self):
                    deflated = self._bundled_inputs_deflated_{name}
                    return [
                {expr}
                    ]
                """).format(expr=expr, name=function_name)
            model.define(definition)

        # Define get_all_bundled_inputs_for_<function_name> that caches the generated inputs.
        model.define(
            textwrap.dedent("""
            def get_all_bundled_inputs_for_{name}(self):
                all_inputs = self._generate_bundled_inputs_for_{name}()
                assert all_inputs is not None
                return all_inputs
            """).format(name=function_name))

        # Add to the high level helper methods
        inputs_info = repr(
            info[function]) if info and function in info else '[]'
        get_bundled_inputs_functions_and_info_template += """
            temp_dict : Dict[str,List[str]] = {{}}
            info: List[str] = {info}

            temp_dict['info'] = info
            temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{name}']
            all_inputs['{name}'] = temp_dict
            """.format(
            name=function_name,
            info=inputs_info,
        )

        # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided
        if function_name == 'forward':
            model.define(
                textwrap.dedent("""
                def get_all_bundled_inputs(self):
                    return self.get_all_bundled_inputs_for_forward()
                """))
            model.define(
                textwrap.dedent("""
                def get_num_bundled_inputs(self):
                    return len(self.get_all_bundled_inputs_for_forward())
                """))

    # Define some high level helper methods that act on all bundled inputs
    model.define(
        textwrap.dedent("""
        def get_bundled_inputs_functions_and_info(self):
            all_inputs : Dict[str, Dict[str,List[str]]] = {{}}
            {template}
            return all_inputs
        """.format(template=get_bundled_inputs_functions_and_info_template)))
Example No. 15
def _simplify_script_module(worker: AbstractWorker, obj: torch.jit.ScriptModule) -> Tuple:
    """Strategy to serialize a script module using Torch.jit"""
    return (obj.save_to_buffer(),)
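
For completeness, a minimal sketch of the complementary deserialization ("detail") step, using the documented torch.jit.save / torch.jit.load pair on an in-memory buffer, which round-trips the same serialized form as the save_to_buffer() call above:

import io

import torch

scripted = torch.jit.script(torch.nn.Linear(2, 2))
buf = io.BytesIO()
torch.jit.save(scripted, buf)      # same serialized form as scripted.save_to_buffer()
buf.seek(0)
restored = torch.jit.load(buf)     # how a worker would rebuild the ScriptModule
print(restored(torch.zeros(1, 2)).shape)  # torch.Size([1, 2])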
Example No. 16
def _bufferize_script_module(
        worker: AbstractWorker,
        script_module: torch.jit.ScriptModule) -> ScriptModulePB:
    protobuf_script = ScriptModulePB()
    protobuf_script.obj = script_module.save_to_buffer()
    return protobuf_script
Example No. 17
def augment_many_model_functions_with_bundled_inputs(
    model: torch.jit.ScriptModule,
    inputs: Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]],
    _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
    info: Optional[Dict[Callable, List[
        str]]] = None,  # Optional argument to provide info about the function or its inputs
) -> None:
    """Add bundled sample inputs to a model for an arbitrary list of public functions.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str: List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then these following functions are also defined:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

        `run_on_bundled_input(idx: int) -> Any`
            Run the model on bundled input number `idx`

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_<function_name>`;
        get_all_bundled_inputs_for_<function_name> will simply call this method
        and cache the value. If the user chooses this method, inputs[<function>]
        should map to None.
      - The `inputs` argument to this function can be a dictionary mapping functions to a
        list of tuples, of the same form that will be returned by get_all_bundled_inputs_for_<function_name>.

      It is highly recommended (though not enforced) that, if multiple functions have the same input style,
      you create separate bundled inputs for each function. Reusing the same input object across multiple
      functions can cause issues with other torch.jit functionality such as freeze.

    Info is an optional parameter that maps functions to a list of strings providing extra information about that
    function's bundled inputs. This could be descriptions, expected outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    get_bundled_inputs_functions_and_info_template = ""

    for function, input_list in inputs.items():
        function_name = function.__name__

        function_arg_types = [
            arg.type for arg in function.schema.arguments[1:]
        ]  # type: ignore
        deflated_inputs_type: ListType = ListType(
            TupleType(function_arg_types))
        inflated_inputs_type: OptionalType[ListType] = OptionalType(
            deflated_inputs_type)
        model._c._register_attribute(
            "_bundled_inputs_deflated_{name}".format(name=function_name),
            deflated_inputs_type, [])
        model._c._register_attribute(
            "_bundled_inputs_inflated_{name}".format(name=function_name),
            inflated_inputs_type, None)

        if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            if input_list is not None:
                raise Exception(
                    "inputs[{name}] is not None, but _generate_bundled_inputs_for_{name} is already defined"
                    .format(name=function_name))
            # Model author already defined _generate_bundled_inputs_for_<function_name>.
        elif input_list is None or len(input_list) == 0:
            raise Exception(
                "inputs for {name} must be specified if _generate_bundled_inputs_for_{name} is not already defined"
                .format(name=function_name, ))
        else:
            # Iterate over the inputs and args in each input.
            # Accumulate `deflated_inputs` as (possibly) compressed values
            # and `parts` to be joined into the expression that unpacks them.
            deflated_inputs = []
            parts = []
            for inp_idx, args in enumerate(input_list):
                deflated_args = []
                parts.append("(")
                for arg_idx, arg in enumerate(args):
                    deflated, inflater = _inflate_expr(
                        arg, f"deflated[{inp_idx}][{arg_idx}]")
                    deflated_args.append(deflated)
                    parts.append(f"    {inflater},")
                deflated_inputs.append(tuple(deflated_args))
                parts.append("),")
            parts.append("")
            expr = "\n".join(parts)
            # Back-channel return this expr for debugging.
            if _receive_inflate_expr is not None:
                _receive_inflate_expr.append(expr)
            model._bundled_inputs_deflated = deflated_inputs
            setattr(
                model,
                "_bundled_inputs_deflated_{name}".format(name=function_name),
                deflated_inputs)
            definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_{name}(self):
                    deflated = self._bundled_inputs_deflated_{name}
                    return [
                {expr}
                    ]
                """).format(expr=expr, name=function_name)
            model.define(definition)

        # Define get_all_bundled_inputs_for_<function_name> that caches the generated inputs.
        model.define(
            textwrap.dedent("""
            def get_all_bundled_inputs_for_{name}(self):
                if self._bundled_inputs_inflated_{name} is None:
                    self._bundled_inputs_inflated_{name} = self._generate_bundled_inputs_for_{name}()
                all_inputs = self._bundled_inputs_inflated_{name}
                assert all_inputs is not None
                return all_inputs
            """).format(name=function_name))

        # Add to the high level helper methods
        inputs_info = repr(
            info[function]) if info and function in info else '[]'
        get_bundled_inputs_functions_and_info_template += """
            temp_dict : Dict[str,List[str]] = {{}}
            info: List[str] = {info}

            temp_dict['info'] = info
            temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{name}']
            all_inputs['{name}'] = temp_dict
            """.format(
            name=function_name,
            info=inputs_info,
        )

        # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided
        if function_name == 'forward':
            model.define(
                textwrap.dedent("""
                def get_all_bundled_inputs(self):
                    return self.get_all_bundled_inputs_for_forward()
                """))
            model.define(
                textwrap.dedent("""
                def get_num_bundled_inputs(self):
                    return len(self.get_all_bundled_inputs_for_forward())
                """))
            model.define(
                textwrap.dedent("""
                def run_on_bundled_input(self, idx: int):
                    return self(*self.get_all_bundled_inputs()[idx])
                """))

    # Define some high level helper methods that act on all bundled inputs
    model.define(
        textwrap.dedent("""
        def get_bundled_inputs_functions_and_info(self):
            all_inputs : Dict[str, Dict[str,List[str]]] = {{}}
            {template}
            return all_inputs
        """.format(template=get_bundled_inputs_functions_and_info_template)))
Example No. 18
def augment_model_with_bundled_inputs(
        model: torch.jit.ScriptModule,
        inputs: Optional[List[Tuple[Any, ...]]] = None,
        _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
) -> None:
    """Add bundled sample inputs to a model.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

      `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
        Returns a list of tuples suitable for passing to the model like
        `for inp in model.get_all_bundled_inputs(): model(*inp)`

      `get_num_bundled_inputs() -> int`
        Equivalent to `len(model.get_all_bundled_inputs())`,
        but slightly easier to call from C++.

      `run_on_bundled_input(idx: int) -> Any`
        Run the model on bundled input number `idx`

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs`
        get_all_bundled_inputs will simply call this method
        and cache the value.
      - The `inputs` argument to this function can be a list of tuples,
        of the same form that will be returned by get_all_bundled_inputs.
        This function will attempt to optimize arguments so that (e.g.)
        arguments like `torch.zeros(1000)` will be represented compactly.
        Only top-level arguments will be optimized.
        Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    forward_arg_types = [arg.type for arg in model.forward.schema.arguments[1:]]
    deflated_inputs_type = torch._C.ListType(torch._C.TupleType(forward_arg_types))
    inflated_inputs_type = torch._C.OptionalType(deflated_inputs_type)
    model._c._register_attribute("_bundled_inputs_deflated", deflated_inputs_type, [])
    model._c._register_attribute("_bundled_inputs_inflated", inflated_inputs_type, None)

    if hasattr(model, "_generate_bundled_inputs"):
        if inputs is not None:
            raise Exception(
                "inputs is not None, but _generate_bundled_inputs is already defined")
        # Model author already defined _generate_bundled_inputs.
    elif inputs is None:
        raise Exception(
            "inputs must be specified if _generate_bundled_inputs is not already defined")
    else:
        # Iterate over the inputs and args in each input.
        # Accumulate `deflated_inputs` as (possibly) compressed values
        # and `parts` to be joined into the expression that unpacks them.
        deflated_inputs = []
        parts = []
        for inp_idx, args in enumerate(inputs):
            deflated_args = []
            parts.append("(")
            for arg_idx, arg in enumerate(args):
                deflated, inflater = _inflate_expr(arg, f"deflated[{inp_idx}][{arg_idx}]")
                deflated_args.append(deflated)
                parts.append(f"    {inflater},")
            deflated_inputs.append(tuple(deflated_args))
            parts.append("),")
        parts.append("")
        expr = "\n".join(parts)
        # Back-channel return this expr for debugging.
        if _receive_inflate_expr is not None:
            _receive_inflate_expr.append(expr)
        model._bundled_inputs_deflated = deflated_inputs
        definition = textwrap.dedent("""
            def _generate_bundled_inputs(self):
                deflated = self._bundled_inputs_deflated
                return [
            {}
                ]
            """).format(expr)
        model.define(definition)

    # Define get_all_bundled_inputs that caches the generated inputs.
    model.define(textwrap.dedent("""
        def get_all_bundled_inputs(self):
            if self._bundled_inputs_inflated is None:
                self._bundled_inputs_inflated = self._generate_bundled_inputs()
            all_inputs = self._bundled_inputs_inflated
            assert all_inputs is not None
            return all_inputs
        """))

    # Define some helper methods.
    model.define(textwrap.dedent("""
        def get_num_bundled_inputs(self):
            return len(self.get_all_bundled_inputs())
        """))
    model.define(textwrap.dedent("""
        def run_on_bundled_input(self, idx: int):
            return self(*self.get_all_bundled_inputs()[idx])
        """))
Example No. 19
def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    model: torch.jit.ScriptModule,
    device: torch.device,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    model_manager: polygames.ModelManager,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:

    info = zutils.get_game_info(game_params)
    c, h, w = info["feature_size"][:3]
    rc, rh, rw = info["raw_feature_size"][:3]
    c_prime, h_prime, w_prime = info["action_size"][:3]

    predicts = (2 if game_params.predict_end_state else
                0) + game_params.predict_n_states

    batchsizes = {
        "s": [c, h, w],
        "v": [3 if getattr(model, "logit_value", False) else 1],
        "pred_v": [1],
        "pi": [c_prime, h_prime, w_prime],
        "pi_mask": [c_prime, h_prime, w_prime]
    }

    if game_params.player == "forward":
        batchsizes["action_pi"] = [c_prime, h_prime, w_prime]

    if predicts > 0:
        batchsizes["predict_pi"] = [rc * predicts, rh, rw]
        batchsizes["predict_pi_mask"] = [rc * predicts, rh, rw]

    if getattr(model, "rnn_state_shape", None) is not None:
        batchsizes["rnn_state_mask"] = [1]

    if execution_params.rnn_seqlen > 0:
        for k, v in batchsizes.items():
            batchsizes[k] = [execution_params.rnn_seqlen, *v]

    if getattr(model, "rnn_state_shape", None) is not None:
        batchsizes["rnn_initial_state"] = model.rnn_state_shape

    rank = 0
    if ddpmodel:
        rank = torch.distributed.get_rank()

    executor = ThreadPoolExecutor(max_workers=1)
    savefuture = None

    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    max_time = execution_params.max_time
    init_epoch = epoch
    while max_time is None or time.time() < start_time + max_time:
        if epoch - init_epoch >= optim_params.num_epoch:
            break
        epoch += 1
        if rank == 0 and epoch % execution_params.saving_period == 0:
            model_manager.add_tournament_model("e%d" % (epoch),
                                               model.state_dict())
            savestart = time.time()
            if savefuture is not None:
                savefuture.result()
            savefuture = utils.save_checkpoint(
                command_history=command_history,
                epoch=epoch,
                model=model,
                optim=optim,
                game_params=game_params,
                model_params=model_params,
                optim_params=optim_params,
                simulation_params=simulation_params,
                execution_params=execution_params,
                executor=executor)
            print("checkpoint saved in %gs" % (time.time() - savestart))
        _train_epoch(
            model=model,
            device=device,
            ddpmodel=ddpmodel,
            batchsizes=batchsizes,
            optim=optim,
            model_manager=model_manager,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=simulation_params.sync_period,
        )
        # resource usage stats
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())
        # train result
        print(
            ">>>train: epoch: %d, %s" %
            (epoch, utils.Result(get_train_reward()).log()),
            flush=True,
        )
    if savefuture is not None:
        savefuture.result()
    # checkpoint last state
    utils.save_checkpoint(
        command_history=command_history,
        epoch=epoch,
        model=model,
        optim=optim,
        game_params=game_params,
        model_params=model_params,
        optim_params=optim_params,
        simulation_params=simulation_params,
        execution_params=execution_params,
    )
Example No. 20
def _train_epoch(
    model: torch.jit.ScriptModule,
    device: torch.device,
    ddpmodel: ModelWrapperForDDP,
    batchsizes,
    optim: torch.optim.Optimizer,
    model_manager: polygames.ModelManager,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    global _pre_num_add
    global _pre_num_sample
    global _running_add_rate
    global _running_sample_rate
    global _last_train_time
    global _remote_replay_buffer_inited
    if _pre_num_add is None:
        pre_num_add = model_manager.buffer_num_add()
        pre_num_sample = model_manager.buffer_num_sample()
    else:
        pre_num_add = _pre_num_add
        pre_num_sample = _pre_num_sample
    sync_s = 0.
    num_sync = 0

    train_start_time = time.time()

    if pre_num_sample > 0:
        print("sample/add ratio ", float(pre_num_sample) / pre_num_add)

    if _last_train_time == 0:
        _last_train_time = time.time()

    batchsize = optim_params.batchsize

    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model

    lossmodel.train()

    world_size = 0
    rank = 0
    if ddpmodel is not None:
        print("DDP is active")
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        print("World size %d, rank %d. Waiting for all processes" %
              (world_size, rank))
        torch.distributed.barrier()
        print("Synchronizing model")
        for p in ddpmodel.parameters():
            torch.distributed.broadcast(p.data, 0)
        for p in ddpmodel.buffers():
            torch.distributed.broadcast(p.data, 0)
        print("Synchronized, start training")

    has_predict = False
    cpubatch = {}
    for k, v in batchsizes.items():
        sizes = v.copy()
        sizes.insert(0, batchsize)
        cpubatch[k] = torch.empty(sizes)
        if k == "predict_pi":
            has_predict = True

    for eid in range(optim_params.epoch_len):
        while _running_add_rate * 1.25 < _running_sample_rate:
            print("add rate insufficient, waiting")
            time.sleep(5)
            t = time.time()
            time_elapsed = t - _last_train_time
            _last_train_time = t
            alpha = pow(0.99, time_elapsed)
            post_num_add = model_manager.buffer_num_add()
            post_num_sample = model_manager.buffer_num_sample()
            delta_add = post_num_add - pre_num_add
            delta_sample = post_num_sample - pre_num_sample
            _running_add_rate = _running_add_rate * alpha + (
                delta_add / time_elapsed) * (1 - alpha)
            _running_sample_rate = _running_sample_rate * alpha + (
                delta_sample / time_elapsed) * (1 - alpha)
            pre_num_add = post_num_add
            pre_num_sample = post_num_sample
            print("running add rate: %.2f / s" % (_running_add_rate))
            print("running sample rate: %.2f / s" % (_running_sample_rate))
            print("current add rate: %.2f / s" % (delta_add / time_elapsed))
            print("current sample rate: %.2f / s" %
                  (delta_sample / time_elapsed))

        if world_size > 0:
            batchlist = None
            if rank == 0:
                batchlist = {}
                for k in cpubatch.keys():
                    batchlist[k] = []
                for i in range(world_size):
                    for k, v in model_manager.sample(batchsize).items():
                        batchlist[k].append(v)
            for k, v in cpubatch.items():
                torch.distributed.scatter(v,
                                          batchlist[k] if rank == 0 else None)
            batch = utils.to_device(cpubatch, device)
        else:
            batch = model_manager.sample(batchsize)
            batch = utils.to_device(batch, device)
        for k, v in batch.items():
            batch[k] = v.detach()
        loss, v_err, pi_err, predict_err = model_loss.mcts_loss(
            model, lossmodel, batch)
        loss.backward()

        grad_norm = nn.utils.clip_grad_norm_(lossmodel.parameters(),
                                             optim_params.grad_clip)
        optim.step()
        optim.zero_grad()

        stat["v_err"].feed(v_err.item())
        stat["pi_err"].feed(pi_err.item())
        if has_predict:
            stat["predict_err"].feed(predict_err.item())
        stat["loss"].feed(loss.item())
        stat["grad_norm"].feed(grad_norm)

        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            model_manager.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1

        t = time.time()
        time_elapsed = t - _last_train_time
        _last_train_time = t
        alpha = pow(0.99, time_elapsed)
        post_num_add = model_manager.buffer_num_add()
        post_num_sample = model_manager.buffer_num_sample()
        delta_add = post_num_add - pre_num_add
        delta_sample = post_num_sample - pre_num_sample
        _running_add_rate = _running_add_rate * alpha + (
            delta_add / time_elapsed) * (1 - alpha)
        _running_sample_rate = _running_sample_rate * alpha + (
            delta_sample / time_elapsed) * (1 - alpha)
        pre_num_add = post_num_add
        pre_num_sample = post_num_sample

    total_time_elapsed = time.time() - train_start_time

    print("running add rate: %.2f / s" % (_running_add_rate))
    print("running sample rate: %.2f / s" % (_running_sample_rate))
    print("current add rate: %.2f / s" % (delta_add / time_elapsed))
    print("current sample rate: %.2f / s" % (delta_sample / time_elapsed))
    print(
        f"syncing duration: {sync_s:2f}s for {num_sync} syncs ({int(100 * sync_s / total_time_elapsed)}% of train time)"
    )

    _pre_num_add = pre_num_add
    _pre_num_sample = pre_num_sample

    stat.summary(epoch)
    stat.reset()
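
The running add/sample rates above are exponential moving averages whose decay is scaled by the elapsed wall time. A minimal sketch of that update in isolation (names are illustrative):

def update_rate(running_rate: float, delta: int, dt: float) -> float:
    # Decay the old estimate by 0.99 per second of elapsed time and blend in
    # the rate observed over the last interval (delta events in dt seconds).
    alpha = 0.99 ** dt
    return running_rate * alpha + (delta / dt) * (1 - alpha)


print(update_rate(10.0, delta=50, dt=2.0))  # moves the 10/s estimate toward 25/s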
Example No. 21
def _simplify_script_module(obj: torch.jit.ScriptModule) -> bytes:
    """Strategy to serialize a script module using Torch.jit"""
    return obj.save_to_buffer()
Example No. 22
def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    model_path: Path,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    assembler: tube.ChannelAssembler,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:
    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    max_time = execution_params.max_time
    init_epoch = epoch
    while max_time is None or time.time() < start_time + max_time:
        if epoch - init_epoch >= optim_params.num_epoch:
            break
        epoch += 1
        if not (epoch - init_epoch) % execution_params.saving_period:
            assembler.add_tournament_model("e%d" % (epoch), model.state_dict())
            utils.save_checkpoint(
                command_history=command_history,
                epoch=epoch,
                model=model,
                optim=optim,
                assembler=assembler,
                game_params=game_params,
                model_params=model_params,
                optim_params=optim_params,
                simulation_params=simulation_params,
                execution_params=execution_params,
            )
        _train_epoch(
            train_device=train_device,
            model=model,
            ddpmodel=ddpmodel,
            model_path=model_path,
            optim=optim,
            assembler=assembler,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=simulation_params.sync_period,
        )
        # resource usage stats
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())
        # train result
        print(
            ">>>train: epoch: %d, %s" %
            (epoch, utils.Result(get_train_reward()).log()),
            flush=True,
        )
    # checkpoint last state
    utils.save_checkpoint(
        command_history=command_history,
        epoch=epoch,
        model=model,
        optim=optim,
        assembler=assembler,
        game_params=game_params,
        model_params=model_params,
        optim_params=optim_params,
        simulation_params=simulation_params,
        execution_params=execution_params,
    )