def _jit_safe_call(model, *named_args):
    """
    Fork an asynchronous call of a model through ``t.jit._fork``,
    discarding arguments the model does not accept.

    Because ``t.jit._fork`` does not support keyword arguments, every
    accepted argument is placed at its position in the model's argument
    list. Tensors given directly as values of ``named_args`` are moved to
    the model's input device; tensors nested inside containers are not.
    Positional arguments that are not supplied keep their declared
    default value.

    Args:
        model: Model to be called. If it lacks ``input_device`` or
            ``output_device``, it is wrapped with ``static_module_wrapper``
            using the single device returned by ``determine_device``.
            The (possibly wrapped) model must provide ``arg_spec`` and
            ``model_type``; ``arg_spec`` is presumably the
            ``inspect.getfullargspec`` result of ``forward`` including
            ``self`` (its first positional entry is skipped) -- TODO confirm.
        *named_args: Dictionaries mapping argument names to values. Keys
            the model does not accept are ignored; if a key appears in
            several dictionaries, the last one wins.

    Returns:
        The object returned by ``t.jit._fork(model, *args)``.

    Raises:
        RuntimeError: If the i/o device cannot be determined because the
            model spans multiple devices, or if a required argument of the
            model is not supplied by ``named_args``.
    """
    if (not hasattr(model, "input_device")
            or not hasattr(model, "output_device")):
        # try to automatically determine the input & output
        # device of the model
        model_type = type(model)
        device = determine_device(model)
        if len(device) > 1:
            raise RuntimeError(
                "Failed to automatically determine i/o device "
                "of your model: {}\n"
                "Detected multiple devices: {}\n"
                "You need to manually specify i/o device of "
                "your model.\n"
                "Wrap your model of type nn.Module with one "
                "of: \n"
                "1. static_module_wrapper "
                "from machin.model.nets.base \n"
                "2. dynamic_module_wrapper "
                "from machin.model.nets.base \n"
                "Or construct your own module & model with: \n"
                "NeuralNetworkModule from machin.model.nets.base".format(
                    model_type, device))
        # assume that i/o devices are the same as parameter device
        # print a warning
        default_logger.warning(
            "You have not specified the i/o device of "
            "your model {}, automatically determined and"
            " set to: {}\n"
            "The framework is not responsible for any "
            "un-matching device issues caused by this "
            "operation.".format(model_type, device[0]))
        model = static_module_wrapper(model, device[0], device[0])

    input_device = model.input_device
    arg_spec = model.arg_spec
    # skip ``self``, the first positional parameter
    positional = arg_spec.args[1:]
    args = positional + arg_spec.kwonlyargs
    # ``defaults`` belongs to the *last* positional parameters only, so it
    # must be aligned against ``positional``, not against ``args`` (which
    # also contains keyword-only names).
    pos_defaults = arg_spec.defaults or ()
    defaults = dict(zip(positional[len(positional) - len(pos_defaults):],
                        pos_defaults))
    kw_defaults = arg_spec.kwonlydefaults or {}
    required_args = set(args) - set(defaults) - set(kw_defaults)
    model_type = model.model_type

    # t.jit._fork does not support keyword args
    # fill arguments in by their positions. Unsupplied positional args are
    # pre-filled with their declared default instead of None, so passing
    # them positionally does not override the default.
    # NOTE(review): keyword-only args are also passed positionally here,
    # which Python rejects for genuine keyword-only parameters.
    position = {name: idx for idx, name in enumerate(args)}
    args_list = [defaults.get(name) for name in args]
    supplied = set()
    for na in named_args:
        for k, v in na.items():
            idx = position.get(k)
            if idx is None:
                # not an argument of the model: discard
                continue
            supplied.add(k)
            args_list[idx] = v.to(input_device) if t.is_tensor(v) else v

    required_not_filled = required_args - supplied
    if required_not_filled:
        raise RuntimeError("\n"
                           "The signature of the forward function "
                           "of Model {} is {}\n"
                           "Missing required arguments: {}, "
                           "check your storage functions.".format(
                               model_type, required_args,
                               required_not_filled))

    return t.jit._fork(model, *args_list)
def safe_call(model, *named_args):
    """
    Call a model and discard unnecessary arguments.

    Only the entries of ``named_args`` whose names are parameters of
    ``model.forward`` are passed, as keyword arguments. Tensors given
    directly as values in ``named_args`` are automatically moved to the
    input device of the model. Tensors nested inside a container (list,
    dict, tuple, etc.) are NOT moved, so pass input tensors at the top
    level only.

    Args:
        model: Model to be called, must be a wrapped nn.Module or an
            instance of NeuralNetworkModule. ``DistributedDataParallel`` /
            ``DataParallel`` models are accepted: their ``.module`` is
            inspected, but the call goes through the parallel wrapper.
            If the model lacks ``input_device`` or ``output_device``, it is
            wrapped with ``static_module_wrapper`` using the single device
            returned by ``determine_device``.
        *named_args: Dictionaries mapping argument names to values. Keys
            the model does not accept are ignored; if a key appears in
            several dictionaries, the last one wins.

    Returns:
        Whatever is returned by your module. If the result is not a tuple,
        it is wrapped inside a 1-tuple.

    Raises:
        RuntimeError: If the i/o device cannot be determined because the
            model spans multiple devices, or if a required argument of
            ``forward`` is not supplied by ``named_args``.
    """
    org_model = None
    if isinstance(model, (nn.parallel.DistributedDataParallel,
                          nn.parallel.DataParallel)):
        # inspect the wrapped module, but call through the parallel wrapper
        org_model = model
        model = model.module

    if (not hasattr(model, "input_device")
            or not hasattr(model, "output_device")):
        # try to automatically determine the input & output device of the model
        model_type = type(model)
        device = determine_device(model)
        if len(device) > 1:
            raise RuntimeError(
                "Failed to automatically determine i/o device "
                "of your model: {}\n"
                "Detected multiple devices: {}\n"
                "You need to manually specify i/o device of "
                "your model.\n"
                "Wrap your model of type nn.Module with one "
                "of: \n"
                "1. static_module_wrapper "
                "from machin.model.nets.base \n"
                "2. dynamic_module_wrapper "
                "from machin.model.nets.base \n"
                "Or construct your own module & model with: \n"
                "NeuralNetworkModule from machin.model.nets.base".format(
                    model_type, device))
        # assume that i/o devices are the same as parameter device
        # print a warning
        default_logger.warning(
            "You have not specified the i/o device of "
            "your model {}, automatically determined and"
            " set to: {}\n"
            "The framework is not responsible for any "
            "un-matching device issues caused by this "
            "operation.".format(model_type, device[0]))
        model = static_module_wrapper(model, device[0], device[0])

    input_device = model.input_device
    arg_spec = inspect.getfullargspec(model.forward)
    # exclude self in arg_spec.args
    positional = arg_spec.args[1:]
    args = positional + arg_spec.kwonlyargs
    # ``defaults`` belongs to the *last* positional parameters only, so it
    # must be aligned against ``positional``, not against ``args`` (which
    # also contains keyword-only names).
    pos_defaults = arg_spec.defaults or ()
    args_with_defaults = positional[len(positional) - len(pos_defaults):]
    kwonly_with_defaults = arg_spec.kwonlydefaults or {}
    required_args = (set(args) - set(args_with_defaults)
                     - set(kwonly_with_defaults))

    # fill in args, discarding names the model does not accept
    accepted = set(args)
    args_dict = {}
    for na in named_args:
        for k, v in na.items():
            if k in accepted:
                args_dict[k] = (v.to(input_device)
                                if torch.is_tensor(v) else v)

    # check for necessary args
    missing = required_args - set(args_dict)
    if missing:
        raise RuntimeError("\n"
                           "The signature of the forward function of Model {} "
                           "is {}\n"
                           "Missing required arguments: {}, "
                           "check your storage functions.".format(
                               type(model), required_args, missing))

    caller = org_model if org_model is not None else model
    result = caller(**args_dict)
    return result if isinstance(result, tuple) else (result, )