Example #1
    # Requires torch imported as t, plus the machin helpers
    # determine_device, static_module_wrapper and default_logger
    # (static_module_wrapper lives in machin.model.nets.base, per the
    # error message below).
    def _jit_safe_call(model, *named_args):
        if (not hasattr(model, "input_device")
                or not hasattr(model, "output_device")):
            # try to automatically determine the input & output
            # device of the model
            model_type = type(model)
            device = determine_device(model)
            if len(device) > 1:
                raise RuntimeError(
                    "Failed to automatically determine i/o device "
                    "of your model: {}\n"
                    "Detected multiple devices: {}\n"
                    "You need to manually specify i/o device of "
                    "your model.\n"
                    "Wrap your model of type nn.Module with one "
                    "of: \n"
                    "1. static_module_wrapper "
                    "from machin.model.nets.base \n"
                    "1. dynamic_module_wrapper "
                    "from machin.model.nets.base \n"
                    "Or construct your own module & model with: \n"
                    "NeuralNetworkModule from machin.model.nets.base".format(
                        model_type, device))
            else:
                # assume that i/o devices are the same as parameter device
                # print a warning
                default_logger.warning(
                    "You have not specified the i/o device of "
                    "your model {}, automatically determined and"
                    " set to: {}\n"
                    "The framework is not responsible for any "
                    "mismatched device issues caused by this "
                    "operation.".format(model_type, device[0]))
                model = static_module_wrapper(model, device[0], device[0])
        input_device = model.input_device
        # model.arg_spec is stored in __init__; drop "self" from args
        args = model.arg_spec.args[1:] + model.arg_spec.kwonlyargs
        if model.arg_spec.defaults is not None:
            # defaults apply to the trailing positional args only
            args_with_defaults = model.arg_spec.args[
                -len(model.arg_spec.defaults):]
        else:
            args_with_defaults = []
        required_args = (set(args) - set(args_with_defaults) -
                         set(model.arg_spec.kwonlydefaults.keys()
                             if model.arg_spec.kwonlydefaults is not None
                             else []))
        model_type = model.model_type
        # t.jit._fork does not support keyword args,
        # so fill arguments in by their positions.
        # note: optional args that are not supplied are passed as None
        # rather than their declared defaults.
        args_list = [None for _ in args]
        args_filled = [False for _ in args]

        for na in named_args:
            for k, v in na.items():
                if k in args:
                    idx = args.index(k)
                    args_filled[idx] = True
                    if t.is_tensor(v):
                        # move input tensors to the model's input device
                        args_list[idx] = v.to(input_device)
                    else:
                        args_list[idx] = v

        if not all(args_filled):
            not_filled = [
                arg for filled, arg in zip(args_filled, args) if not filled
            ]
            required_not_filled = set(not_filled).intersection(required_args)
            if len(required_not_filled) > 0:
                raise RuntimeError("\n"
                                   "The signature of the forward function "
                                   "of Model {} is {}\n"
                                   "Missing required arguments: {}, "
                                   "check your storage functions.".format(
                                       model_type, required_args,
                                       required_not_filled))

        return t.jit._fork(model, *args_list)
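
A minimal usage sketch for the example above, assuming model has already been
wrapped (e.g. with static_module_wrapper) so that its input_device,
output_device and arg_spec attributes are set, and that model.forward takes a
single tensor argument named "state" (a hypothetical name):

# hypothetical call; the key must match a parameter of model.forward
future = _jit_safe_call(model, {"state": t.ones(1, 4)})
# t.jit._fork returns a future; t.jit._wait blocks for the result
result = t.jit._wait(future)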
Example #2
import inspect

import torch
from torch import nn

# determine_device, static_module_wrapper and default_logger are helpers
# from the machin library; static_module_wrapper lives in
# machin.model.nets.base (the path named in the error message below).


def safe_call(model, *named_args):
    """
    Call a model and discard unnecessary arguments. safe_call will automatically
    move tensors in named_args to the input device of the model

    Any input tensor in named_args must not be contained inside any container,
    such as list, dict, tuple, etc. Because they will be automatically moved
    to the input device of the specified model.

    Args:
        model: Model to be called, must be a wrapped nn.Module or an instance of
               NeuralNetworkModule.
        named_args: A dictionary of argument, key is argument's name, value is
                    argument's value.

    Returns:
        Whatever returned by your module. If result is not a tuple, always
        wrap results inside a tuple
    """
    org_model = None
    if isinstance(model, (nn.parallel.DistributedDataParallel,
                          nn.parallel.DataParallel)):
        org_model = model
        model = model.module
    if (not hasattr(model, "input_device")
            or not hasattr(model, "output_device")):
        # try to automatically determine the input & output device of the model
        model_type = type(model)
        device = determine_device(model)
        if len(device) > 1:
            raise RuntimeError(
                "Failed to automatically determine i/o device "
                "of your model: {}\n"
                "Detected multiple devices: {}\n"
                "You need to manually specify i/o device of "
                "your model.\n"
                "Wrap your model of type nn.Module with one "
                "of: \n"
                "1. static_module_wrapper "
                "from machin.model.nets.base \n"
                "1. dynamic_module_wrapper "
                "from machin.model.nets.base \n"
                "Or construct your own module & model with: \n"
                "NeuralNetworkModule from machin.model.nets.base".format(
                    model_type, device))
        else:
            # assume that i/o devices are the same as parameter device
            # print a warning
            default_logger.warning(
                "You have not specified the i/o device of "
                "your model {}, automatically determined and"
                " set to: {}\n"
                "The framework is not responsible for any "
                "un-matching device issues caused by this "
                "operation.".format(model_type, device[0]))
            model = static_module_wrapper(model, device[0], device[0])

    input_device = model.input_device
    arg_spec = inspect.getfullargspec(model.forward)
    # exclude "self" from arg_spec.args
    args = arg_spec.args[1:] + arg_spec.kwonlyargs
    if arg_spec.defaults is not None:
        # defaults apply to the trailing positional args only
        args_with_defaults = arg_spec.args[-len(arg_spec.defaults):]
    else:
        args_with_defaults = []
    required_args = (set(args) - set(args_with_defaults) -
                     set(arg_spec.kwonlydefaults.keys()
                         if arg_spec.kwonlydefaults is not None
                         else []))
    args_dict = {}

    # fill in args
    for na in named_args:
        for k, v in na.items():
            if k in args:
                if torch.is_tensor(v):
                    args_dict[k] = v.to(input_device)
                else:
                    args_dict[k] = v

    # check for necessary args
    missing = required_args - set(args_dict.keys())
    if len(missing) > 0:
        raise RuntimeError("\n"
                           "The signature of the forward function of Model {} "
                           "is {}\n"
                           "Missing required arguments: {}, "
                           "check your storage functions.".format(
                               type(model), required_args, missing))

    if org_model is not None:
        result = org_model(**args_dict)
    else:
        result = model(**args_dict)

    if isinstance(result, tuple):
        return result
    else:
        return (result, )
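
A minimal usage sketch for safe_call, assuming static_module_wrapper is
imported from machin.model.nets.base (the path named in the error messages
above); the QNet model and its argument names are hypothetical:

class QNet(nn.Module):
    def forward(self, state, action=None):
        return state * 2

# wrap the plain nn.Module so its i/o devices are set explicitly
model = static_module_wrapper(QNet(), "cpu", "cpu")

# "noise" is not a parameter of forward and is silently discarded;
# "state" is moved to model.input_device; the result is always a tuple
(out,) = safe_call(model, {"state": torch.ones(1, 4), "noise": 1})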