def _train_epoch(
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    ddpmodel: ModelWrapperForDDP,
    model_path: Path,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    global _train_epoch_waiting_time
    pre_num_add = assembler.buffer_num_add()
    pre_num_sample = assembler.buffer_num_sample()
    sync_s = 0.
    num_sync = 0
    t = time.time()
    # Throttle training if sampling outpaced buffer additions in the previous epoch.
    time.sleep(_train_epoch_waiting_time)
    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model
    for eid in range(optim_params.epoch_len):
        batch = assembler.sample(optim_params.batchsize)
        batch = utils.to_device(batch, train_device)
        loss = model.loss(lossmodel, batch["s"], batch["v"], batch["pi"], batch["pi_mask"], stat)
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), optim_params.grad_clip)
        optim.step()
        optim.zero_grad()
        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            assembler.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1
        stat["loss"].feed(loss.detach().item())
        stat["grad_norm"].feed(grad_norm)
    post_num_add = assembler.buffer_num_add()
    post_num_sample = assembler.buffer_num_sample()
    time_elapsed = time.time() - t
    delta_add = post_num_add - pre_num_add
    print("buffer add rate: %.2f / s" % (delta_add / time_elapsed))
    delta_sample = post_num_sample - pre_num_sample
    if delta_sample > 8 * delta_add:
        # Sampling outpaced additions by more than 8x: sleep longer before the
        # next epoch so the replay buffer can catch up.
        _train_epoch_waiting_time += time_elapsed
    else:
        # The add rate keeps up with sampling; no throttling needed.
        _train_epoch_waiting_time = 0
    print("buffer sample rate: %.2f / s" % (delta_sample / time_elapsed))
    print(
        f"syncing duration: {sync_s:.2f}s for {num_sync} syncs "
        f"({int(100 * sync_s / time_elapsed)}% of train time)"
    )
    stat.summary(epoch)
    stat.reset()
def _save_load_mobile_module(self, script_module: torch.jit.ScriptModule):
    buffer = io.BytesIO(
        script_module._save_to_buffer_for_lite_interpreter(
            _save_mobile_debug_info=True))
    buffer.seek(0)
    mobile_module = _load_for_lite_interpreter(buffer)
    return mobile_module
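# Illustrative usage sketch of the lite-interpreter round trip performed above,
# outside the test class. The toy module is made up for illustration;
# _load_for_lite_interpreter is imported explicitly here from torch.jit.mobile.
import io
import torch
from torch.jit.mobile import _load_for_lite_interpreter

class AddTen(torch.nn.Module):
    def forward(self, x):
        return x + 10

scripted = torch.jit.script(AddTen())
# Save for the lite interpreter (with debug info), then load it back from memory.
buf = io.BytesIO(scripted._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
buf.seek(0)
mobile = _load_for_lite_interpreter(buf)
print(mobile(torch.tensor([1.0])))  # tensor([11.])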
def _get_bundled_inputs_preserved_attributes(
        script_module: torch.jit.ScriptModule,
        preserved_methods: List[str]) -> List[str]:
    bundled_inputs_attributes = []
    # Has bundled inputs for forward
    if hasattr(script_module, 'get_all_bundled_inputs'):
        bundled_inputs_attributes.append('get_all_bundled_inputs')
        bundled_inputs_attributes.append('get_num_bundled_inputs')
        bundled_inputs_attributes.append('run_on_bundled_input')
    # Bundled inputs in module after the change that introduced bundled inputs for multiple functions
    if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
        bundled_inputs_attributes.append('get_bundled_inputs_functions_and_info')
        all_info = script_module.get_bundled_inputs_functions_and_info()
        for function_name in all_info:
            if function_name not in preserved_methods:
                bundled_inputs_attributes.append(function_name)
            bundled_inputs_attributes.append("get_all_bundled_inputs_for_" + function_name)
            bundled_inputs_attributes.append("_bundled_inputs_deflated_" + function_name)
    return bundled_inputs_attributes
def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
    """
    Args:
        script_module: An instance of torch script module with type of ScriptModule

    Returns:
        lint_map: A list of dictionaries, each describing one module lint
    """
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise TypeError(
            'Got {}, but ScriptModule is expected.'.format(type(script_module)))
    lint_list = []

    if not hasattr(script_module, "_generate_bundled_inputs"):
        lint_list.append({"name": LintCode.BUNDLED_INPUT.name,
                          "message": "No bundled input, please add bundled inputs before "
                          "saving the module using torch.utils.bundled_inputs.augment_model_with_bundled_inputs."})

    for name, param in script_module.named_parameters():
        if param.requires_grad:
            lint_list.append({"name": LintCode.REQUIRES_GRAD.name,
                              "message": "Param {} requires grad, please set torch.no_grad() "
                              "to reduce memory usage and improve computation speed during "
                              "inference phase.".format(name)})

    op_names = torch.jit.export_opnames(script_module)
    for op_name in op_names:
        if "dropout" in op_name:
            lint_list.append({"name": LintCode.DROPOUT.name,
                              "message": "Operator {} exists, remember to call eval() before "
                              "saving the module.".format(op_name)})
        if "batch_norm" in op_name:
            lint_list.append({"name": LintCode.BATCHNORM.name,
                              "message": "Operator {} exists, remember to call eval() before "
                              "saving the module and call torch.utils.mobile_optimizer.optimize_for_mobile "
                              "to drop the batch_norm operator.".format(op_name)})
    return lint_list
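# Illustrative usage sketch, assuming this lint generator is the one exposed as
# torch.utils.mobile_optimizer.generate_mobile_module_lints. The toy module is
# made up so that it triggers several of the lints above.
import torch
from torch.utils.mobile_optimizer import generate_mobile_module_lints

class LintyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(4)
        self.drop = torch.nn.Dropout()

    def forward(self, x):
        return self.drop(self.bn(x))

scripted = torch.jit.script(LintyNet())
for lint in generate_mobile_module_lints(scripted):
    # Expect BUNDLED_INPUT, REQUIRES_GRAD, DROPOUT and BATCHNORM findings here.
    print(lint["name"], "-", lint["message"])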
def local_energy(
    state: torch.jit.ScriptModule,
    hamiltonian: _C.Heisenberg,
    spins: np.ndarray,
    log_values: Optional[np.ndarray] = None,
    batch_size: int = 128,
) -> np.ndarray:
    r"""Computes local estimators ⟨σ|H|ψ⟩/⟨σ|ψ⟩ for all σ.

    :param state: wavefunction ``ψ``. ``state`` should be a function mapping
        ``R^{batch_size x in_features}`` to ``R^{batch_size x 2}``. Columns of
        the output are interpreted as real and imaginary parts of ``log(⟨σ|ψ⟩)``.
    :param hamiltonian: Hamiltonian ``H``.
    :param spins: spin configurations ``σ``. Should be a non-empty NumPy array
        of :py:class:`CompactSpin`.
    :param log_values: pre-computed ``log(⟨σ|ψ⟩)``. Should be a NumPy array of
        ``complex64``.
    :param batch_size: batch size to use for forward propagation through
        ``state``.

    :return: local energies ⟨σ|H|ψ⟩/⟨σ|ψ⟩ as a NumPy array of ``complex64``.
    """
    with torch.no_grad():
        with torch.jit.optimized_execution(True):
            if log_values is None:
                log_values = _forward_with_batches(state, spins, batch_size)
                log_values = log_values.numpy().view(np.complex64)

            # Since torch.jit.ScriptModules can't be directly passed to C++
            # code as torch::jit::script::Modules, we first save ψ to a
            # temporary file and then load it back in C++ code.
            with tempfile.NamedTemporaryFile(delete=False) as f:
                filename = f.name
            try:
                state.save(filename)
                log_H_values = (
                    _C.PolynomialState(
                        _C.Polynomial(hamiltonian, [0.0]),
                        filename,
                        (batch_size, len(_C.unsafe_get(spins, 0))),
                    )(spins)
                    .numpy()
                    .view(np.complex64)
                )
            finally:
                os.remove(filename)
            return np.exp(log_H_values - log_values).squeeze(axis=1)
def create_optimizer(
    model: torch.jit.ScriptModule,
    optim_params: OptimParams,
    optim_state_dict: Optional[dict] = None,
) -> torch.optim.Optimizer:
    optim = torch.optim.Adam(
        model.parameters(), lr=optim_params.lr, eps=optim_params.eps
    )
    if optim_state_dict is not None:
        optim.load_state_dict(optim_state_dict)
    return optim
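# Illustrative usage sketch. The real OptimParams dataclass lives in the
# project; a SimpleNamespace stands in for it here purely for illustration.
from types import SimpleNamespace
import torch

optim_params_example = SimpleNamespace(lr=1e-3, eps=1.5e-4)
example_model = torch.jit.script(torch.nn.Linear(8, 2))

optim = create_optimizer(example_model, optim_params_example)
state = optim.state_dict()  # e.g. taken from a previously saved checkpoint
optim_restored = create_optimizer(example_model, optim_params_example,
                                  optim_state_dict=state)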
def save_checkpoint(
    command_history: CommandHistory,
    epoch: int,
    model: torch.jit.ScriptModule,
    optim: torch.optim.Optimizer,
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    executor: ThreadPoolExecutor = None,
) -> None:
    checkpoint_dir = execution_params.checkpoint_dir
    save_uncompressed = execution_params.save_uncompressed
    checkpoint_name = f"checkpoint_{epoch}"
    checkpoint = {
        "command_history": command_history,
        "epoch": epoch,
        # Copy tensors to CPU so the checkpoint can be written from another
        # thread while training continues on the original device.
        "model_state_dict": {
            k: v.cpu().clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v)
            for k, v in model.state_dict().items()
        },
        "optim_state_dict": {
            k: v.cpu().clone() if isinstance(v, torch.Tensor) else copy.deepcopy(v)
            for k, v in optim.state_dict().items()
        },
        "game_params": game_params,
        "model_params": model_params,
        "optim_params": optim_params,
        "simulation_params": simulation_params,
        "execution_params": execution_params,
    }

    def saveit():
        nonlocal save_uncompressed
        nonlocal checkpoint
        nonlocal checkpoint_dir
        if save_uncompressed:
            torch.save(checkpoint, checkpoint_dir / f"{checkpoint_name}.pt")
        else:
            # with zipfile.ZipFile(Path(checkpoint_dir) / f"{checkpoint_name}.zip", "w", allowZip64=True) as z:
            #     with z.open(f"{checkpoint_name}.pt", "w", force_zip64=True) as f:
            #         torch.save(checkpoint, f)
            with gzip.open(checkpoint_dir / f"{checkpoint_name}.pt.gz", "wb") as f:
                torch.save(checkpoint, f)

    if executor is not None:
        return executor.submit(saveit)
    else:
        saveit()
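# Standalone sketch of the fire-and-wait pattern this executor argument enables:
# the checkpoint write overlaps with training, and the caller blocks on the
# returned future before submitting the next one. Names and values are made up.
from concurrent.futures import ThreadPoolExecutor
import torch

def write_state(state, path):
    # Stands in for the gzip/uncompressed branches of saveit() above.
    torch.save(state, path)

executor = ThreadPoolExecutor(max_workers=1)
future = executor.submit(write_state, {"epoch": 1}, "/tmp/checkpoint_1.pt")
# ... training continues here while the file is written ...
future.result()  # wait before submitting the next checkpoint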
def bufferize(worker: AbstractWorker, script_module: torch.jit.ScriptModule) -> ScriptModulePB:
    """
    This method serializes a torch.jit.ScriptModule using ScriptModulePB.

    Args:
        script_module (torch.jit.ScriptModule): input jit.ScriptModule to be serialized.

    Returns:
        protobuf_script (ScriptModulePB): serialized jit.ScriptModule.
    """
    protobuf_script = ScriptModulePB()
    protobuf_script.obj = script_module.save_to_buffer()
    return protobuf_script
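# Hedged sketch of the reverse direction (not shown in this snippet): the raw
# bytes stored on the protobuf can be restored with torch.jit.load on an
# in-memory buffer. The function name below is hypothetical.
import io
import torch

def _unbufferize_script_module_sketch(worker, protobuf_script):
    # Hypothetical counterpart to bufferize(): rebuild the ScriptModule
    # from the serialized bytes held in the protobuf message.
    return torch.jit.load(io.BytesIO(protobuf_script.obj))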
def create_optimizer(
    model: torch.jit.ScriptModule,
    optim_params: OptimParams,
    optim_state_dict: Optional[dict] = None,
) -> torch.optim.Optimizer:
    optim = torch.optim.Adam(
        model.parameters(), lr=optim_params.lr, eps=optim_params.eps
    )
    if optim_state_dict is not None and not optim_params.reset_optimizer_state:
        try:
            optim.load_state_dict(optim_state_dict)
        except ValueError:
            print("Optimizer state not compatible... skipping.")
    return optim
def _get_bundled_inputs_attributes_and_methods(
        script_module: torch.jit.ScriptModule) -> Tuple[List[str], List[str]]:
    methods: List[str] = []
    attributes: List[str] = []
    # Has bundled inputs for forward
    if hasattr(script_module, 'get_all_bundled_inputs'):
        methods.append('get_all_bundled_inputs')
        methods.append('get_num_bundled_inputs')
        methods.append('run_on_bundled_input')

    if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
        methods.append('get_bundled_inputs_functions_and_info')
        all_info = script_module.get_bundled_inputs_functions_and_info()
        for function_name in all_info:
            methods.append("get_all_bundled_inputs_for_" + function_name)
            methods.append("_generate_bundled_inputs_for_" + function_name)
            attributes.append("_bundled_inputs_deflated_" + function_name)
    return (methods, attributes)
def _get_bundled_inputs_preserved_attributes(script_module: torch.jit.ScriptModule,
                                             preserved_methods: List[str]) -> List[str]:
    # Technically it is possible that a module only bundles inputs for functions
    # besides forward, in which case these won't exist. Haven't seen a reason for
    # that to be a valid use case yet, so not accounting for it.
    bundled_inputs_attributes = [
        'get_all_bundled_inputs',
        'get_num_bundled_inputs',
        'run_on_bundled_input',
    ]
    if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
        bundled_inputs_attributes.append('get_bundled_inputs_functions_and_info')
        all_info = script_module.get_bundled_inputs_functions_and_info()
        for function_name in all_info:
            if function_name not in preserved_methods:
                bundled_inputs_attributes.append(function_name)
            bundled_inputs_attributes.append("get_all_bundled_inputs_for_" + function_name)
            bundled_inputs_attributes.append("_bundled_inputs_deflated_" + function_name)
    return bundled_inputs_attributes
def save_checkpoint(
    command_history: CommandHistory,
    epoch: int,
    model: torch.jit.ScriptModule,
    optim: torch.optim.Optimizer,
    assembler: tube.ChannelAssembler,
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
) -> None:
    checkpoint_dir = execution_params.checkpoint_dir
    save_uncompressed = execution_params.save_uncompressed
    do_not_save_replay_buffer = execution_params.do_not_save_replay_buffer
    checkpoint_name = f"checkpoint_{epoch}"
    checkpoint = {
        "command_history": command_history,
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optim_state_dict": optim.state_dict(),
        "game_params": game_params,
        "model_params": model_params,
        "optim_params": optim_params,
        "simulation_params": simulation_params,
        "execution_params": execution_params,
    }
    if not do_not_save_replay_buffer:
        checkpoint.update({"replay_buffer": assembler.buffer})
    if save_uncompressed:
        torch.save(checkpoint, checkpoint_dir / f"{checkpoint_name}.pt")
    else:
        # with zipfile.ZipFile(Path(checkpoint_dir) / f"{checkpoint_name}.zip", "w", allowZip64=True) as z:
        #     with z.open(f"{checkpoint_name}.pt", "w", force_zip64=True) as f:
        #         torch.save(checkpoint, f)
        # Write to a temporary name first, then rename, so an interrupted write
        # never leaves a truncated checkpoint under the final name.
        with gzip.open(checkpoint_dir / f"temp_{checkpoint_name}.pt.gz", "wb") as f:
            torch.save(checkpoint, f)
        os.rename(checkpoint_dir / f"temp_{checkpoint_name}.pt.gz",
                  checkpoint_dir / f"{checkpoint_name}.pt.gz")
def _get_bundled_inputs_attributes_and_methods(script_module: torch.jit.ScriptModule) -> Tuple[List[str], List[str]]:
    methods: List[str] = []
    attributes: List[str] = []

    # Has bundled inputs for forward
    if hasattr(script_module, 'get_all_bundled_inputs'):
        methods.append('get_all_bundled_inputs')
        methods.append('get_num_bundled_inputs')
        methods.append('run_on_bundled_input')

    if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
        methods.append('get_bundled_inputs_functions_and_info')
        all_info = script_module.get_bundled_inputs_functions_and_info()
        for function_name in all_info:
            methods.append("get_all_bundled_inputs_for_" + function_name)
            methods.append("_generate_bundled_inputs_for_" + function_name)
            attributes.append("_bundled_inputs_deflated_" + function_name)

            bundled_inputs_fn = getattr(
                script_module,
                f"get_all_bundled_inputs_for_{function_name}"
            )
            num_bundled_inputs: int = len(bundled_inputs_fn())

            # Check inflate helper functions for each function, argument and bundled input
            func = getattr(script_module, function_name, None)
            for arg_idx in range(len(func.schema.arguments) - 1):
                for input_idx in range(num_bundled_inputs):
                    helper_fn_name = _get_inflate_helper_fn_name(
                        arg_idx=arg_idx,
                        input_idx=input_idx,
                        function_name=function_name
                    )
                    # if the arg has an InflatableArg with fmt_fn, add the helper function name
                    if hasattr(script_module, helper_fn_name):
                        methods.append(helper_fn_name)

    return (methods, attributes)
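# Sketch of how helpers like the ones above are typically exercised: bundled-input
# methods added by torch.utils.bundled_inputs are collected so that
# torch.utils.mobile_optimizer.optimize_for_mobile can preserve them. This assumes
# a PyTorch build where optimize_for_mobile keeps bundled-input methods
# automatically; the toy module is made up for illustration.
import torch
import torch.utils.bundled_inputs
from torch.utils.mobile_optimizer import optimize_for_mobile

class Times2(torch.nn.Module):
    def forward(self, x):
        return x * 2

m = torch.jit.script(Times2()).eval()
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(m, inputs=[(torch.ones(3),)])

# The preserved-attribute/method lists computed above keep the
# get_all_bundled_inputs* methods alive through optimization and freezing.
opt = optimize_for_mobile(m)
print(opt.get_all_bundled_inputs())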
def augment_many_model_functions_with_bundled_inputs(
    model: torch.jit.ScriptModule,
    inputs: Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]],
    _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
    info: Optional[Dict[Callable, List[str]]] = None,  # Optional argument to provide info about the function or its inputs
    skip_size_check=False,
) -> None:
    """Add bundled sample inputs to a model for an arbitrary list of public functions.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str, List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then the following functions are also defined:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`, but slightly
            easier to call from C++.

    Inputs can be specified in one of two ways:

        - The model can define `_generate_bundled_inputs_for_<function_name>`.
          If the user chooses this method, inputs[<function>] should map to None.

        - The `inputs` argument to this function can be a dictionary mapping functions to a
          list of inputs, of the same form that will be returned by
          `get_all_bundled_inputs_for_<function_name>`. The type of the inputs is
          `List[Tuple[Any, ...]]`. The outer list corresponds to a list of inputs; the inner
          tuple is the list of args that together make up one input. For functions that take
          one arg, this will be a tuple of length one. The `Any, ...` is the actual data that
          makes up the args, e.g. a tensor.

    Info is an optional parameter that maps functions to a list of strings providing extra
    information about that function's bundled inputs. This could be descriptions, expected
    outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    if not inputs:
        raise Exception("Please provide inputs for at least 1 function")

    if hasattr(model, "get_all_bundled_inputs") or hasattr(model, "get_bundled_inputs_functions_and_info"):
        raise Exception(
            "Models can only be augmented with bundled inputs once. "
            "This model seems to have already been augmented with "
            "bundled inputs. Please start afresh with one that "
            "doesn't have bundled inputs.",
        )

    get_bundled_inputs_functions_and_info_template = ""

    for function, input_list in inputs.items():
        if hasattr(function, "__name__"):
            function_name = function.__name__
        else:
            if hasattr(function, "name"):
                function_name = function.name  # type: ignore[attr-defined]
            else:
                raise Exception(
                    'At least one of your functions has no name attribute. '
                    'Please ensure all have one, e.g. m.foo.name = "foo".')

        if input_list is not None and not isinstance(input_list, Sequence):
            raise TypeError(
                "Error: inputs for function {0} is not a Sequence".format(function_name))

        function_arg_types = [arg.type for arg in function.schema.arguments[1:]]  # type: ignore[attr-defined]
        deflated_inputs_type: ListType = ListType(TupleType(function_arg_types))
        model._c._register_attribute(
            "_bundled_inputs_deflated_{name}".format(name=function_name), deflated_inputs_type, [])

        if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            if input_list is not None:
                raise Exception(
                    "inputs[{name}] is not None, but _generate_bundled_inputs_for_{name} is already defined".format(
                        name=function_name))
            # Model author already defined _generate_bundled_inputs_for_<function_name>.
        elif input_list is None or len(input_list) == 0:
            raise Exception(
                "inputs for {name} must be specified if _generate_bundled_inputs_for_{name} is not already defined".format(
                    name=function_name,
                ))
        else:
            # Iterate over the inputs and args in each input.
            # Accumulate `deflated_inputs` as (possibly) compressed values
            # and `parts` to be joined into the expression that unpacks them.
            deflated_inputs = []
            parts = []
            for inp_idx, args in enumerate(input_list):
                if not isinstance(args, Tuple) and not isinstance(args, List):  # type: ignore[arg-type]
                    raise TypeError(
                        "Error: bundled input for function {0} idx: {1} is not a Tuple or a List".format(
                            function_name, inp_idx))
                deflated_args = []
                parts.append("(")
                for arg_idx, arg in enumerate(args):
                    inflate_helper_fn_name = _get_inflate_helper_fn_name(arg_idx, inp_idx, function_name)
                    deflated, inflater, helper_definition = _inflate_expr(
                        arg,
                        f"deflated[{inp_idx}][{arg_idx}]",
                        inflate_helper_fn_name,
                        skip_size_check=skip_size_check,
                    )
                    deflated_args.append(deflated)
                    parts.append(f" {inflater},")
                    if helper_definition:
                        model.define(textwrap.dedent(helper_definition))
                deflated_inputs.append(tuple(deflated_args))
                parts.append("),")
            parts.append("")
            expr = "\n".join(parts)

            # Back-channel return this expr for debugging.
            if _receive_inflate_expr is not None:
                _receive_inflate_expr.append(expr)
            setattr(model, "_bundled_inputs_deflated_{name}".format(name=function_name), deflated_inputs)
            definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_{name}(self):
                    deflated = self._bundled_inputs_deflated_{name}
                    return [
                {expr}
                    ]
                """).format(expr=expr, name=function_name)
            model.define(definition)

        # Define get_all_bundled_inputs_for_<function_name> that caches the generated inputs.
        model.define(textwrap.dedent("""
            def get_all_bundled_inputs_for_{name}(self):
                all_inputs = self._generate_bundled_inputs_for_{name}()
                assert all_inputs is not None
                return all_inputs
            """).format(name=function_name))

        # Add to the high level helper methods
        inputs_info = repr(info[function]) if info and function in info else '[]'
        get_bundled_inputs_functions_and_info_template += """
            temp_dict : Dict[str,List[str]] = {{}}
            info: List[str] = {info}

            temp_dict['info'] = info
            temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{name}']
            all_inputs['{name}'] = temp_dict
            """.format(
            name=function_name,
            info=inputs_info,
        )

        # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided
        if function_name == 'forward':
            model.define(textwrap.dedent("""
                def get_all_bundled_inputs(self):
                    return self.get_all_bundled_inputs_for_forward()
                """))
            model.define(textwrap.dedent("""
                def get_num_bundled_inputs(self):
                    return len(self.get_all_bundled_inputs_for_forward())
                """))

    # Define some high level helper methods that act on all bundled inputs
    model.define(textwrap.dedent("""
        def get_bundled_inputs_functions_and_info(self):
            all_inputs : Dict[str, Dict[str,List[str]]] = {{}}
            {template}
            return all_inputs
        """.format(template=get_bundled_inputs_functions_and_info_template)))
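# Usage sketch for the multi-function variant above, assuming it is the one
# shipped as torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs.
# The toy module and sample inputs are made up for illustration.
import torch
import torch.utils.bundled_inputs

class Pair(torch.nn.Module):
    def forward(self, x):
        return x + 1

    @torch.jit.export
    def double(self, x):
        return x * 2

m = torch.jit.script(Pair())
torch.utils.bundled_inputs.augment_many_model_functions_with_bundled_inputs(
    m,
    inputs={
        m.forward: [(torch.zeros(5),)],
        m.double: [(torch.ones(2),)],
    },
    info={m.forward: ["all-zero example"], m.double: ["all-one example"]},
)

print(m.get_bundled_inputs_functions_and_info())
for inp in m.get_all_bundled_inputs_for_double():
    m.double(*inp)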
def _simplify_script_module(worker: AbstractWorker, obj: torch.jit.ScriptModule) -> Tuple:
    """Strategy to serialize a script module using Torch.jit"""
    return (obj.save_to_buffer(),)
def _bufferize_script_module(
        worker: AbstractWorker,
        script_module: torch.jit.ScriptModule) -> ScriptModulePB:
    protobuf_script = ScriptModulePB()
    protobuf_script.obj = script_module.save_to_buffer()
    return protobuf_script
def augment_many_model_functions_with_bundled_inputs(
    model: torch.jit.ScriptModule,
    inputs: Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]],
    _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
    info: Optional[Dict[Callable, List[str]]] = None,  # Optional argument to provide info about the function or its inputs
) -> None:
    """Add bundled sample inputs to a model for an arbitrary list of public functions.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str, List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then the following functions are also defined:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`, but slightly
            easier to call from C++.

        `run_on_bundled_input(idx: int) -> Any`
            Run the model on bundled input number `idx`

    Inputs can be specified in one of two ways:

        - The model can define `_generate_bundled_inputs_for_<function_name>`.
          `get_all_bundled_inputs` will simply call this method and cache the value.
          If the user chooses this method, inputs[<function>] should map to None.

        - The `inputs` argument to this function can be a dictionary mapping functions to a
          list of tuples, of the same form that will be returned by
          `get_all_bundled_inputs_for_<function_name>`.

    It is highly recommended (though not enforced) that if multiple functions have the same
    input style, you create separate bundled inputs for each function. Reusing the same input
    and bundling it to multiple functions can cause issues with other torch.jit functionality
    like freeze.

    Info is an optional parameter that maps functions to a list of strings providing extra
    information about that function's bundled inputs. This could be descriptions, expected
    outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    get_bundled_inputs_functions_and_info_template = ""

    for function, input_list in inputs.items():
        function_name = function.__name__
        function_arg_types = [arg.type for arg in function.schema.arguments[1:]]  # type: ignore
        deflated_inputs_type: ListType = ListType(TupleType(function_arg_types))
        inflated_inputs_type: OptionalType[ListType] = OptionalType(deflated_inputs_type)
        model._c._register_attribute(
            "_bundled_inputs_deflated_{name}".format(name=function_name), deflated_inputs_type, [])
        model._c._register_attribute(
            "_bundled_inputs_inflated_{name}".format(name=function_name), inflated_inputs_type, None)

        if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            if input_list is not None:
                raise Exception(
                    "inputs[{name}] is not None, but _generate_bundled_inputs_for_{name} is already defined".format(
                        name=function_name))
            # Model author already defined _generate_bundled_inputs_for_<function_name>.
        elif input_list is None or len(input_list) == 0:
            raise Exception(
                "inputs for {name} must be specified if _generate_bundled_inputs_for_{name} is not already defined".format(
                    name=function_name,
                ))
        else:
            # Iterate over the inputs and args in each input.
            # Accumulate `deflated_inputs` as (possibly) compressed values
            # and `parts` to be joined into the expression that unpacks them.
            deflated_inputs = []
            parts = []
            for inp_idx, args in enumerate(input_list):
                deflated_args = []
                parts.append("(")
                for arg_idx, arg in enumerate(args):
                    deflated, inflater = _inflate_expr(arg, f"deflated[{inp_idx}][{arg_idx}]")
                    deflated_args.append(deflated)
                    parts.append(f" {inflater},")
                deflated_inputs.append(tuple(deflated_args))
                parts.append("),")
            parts.append("")
            expr = "\n".join(parts)

            # Back-channel return this expr for debugging.
            if _receive_inflate_expr is not None:
                _receive_inflate_expr.append(expr)
            model._bundled_inputs_deflated = deflated_inputs
            setattr(model, "_bundled_inputs_deflated_{name}".format(name=function_name), deflated_inputs)
            definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_{name}(self):
                    deflated = self._bundled_inputs_deflated_{name}
                    return [
                {expr}
                    ]
                """).format(expr=expr, name=function_name)
            model.define(definition)

        # Define get_all_bundled_inputs_for_<function_name> that caches the generated inputs.
        model.define(textwrap.dedent("""
            def get_all_bundled_inputs_for_{name}(self):
                if self._bundled_inputs_inflated_{name} is None:
                    self._bundled_inputs_inflated_{name} = self._generate_bundled_inputs_for_{name}()
                all_inputs = self._bundled_inputs_inflated_{name}
                assert all_inputs is not None
                return all_inputs
            """).format(name=function_name))

        # Add to the high level helper methods
        inputs_info = repr(info[function]) if info and function in info else '[]'
        get_bundled_inputs_functions_and_info_template += """
            temp_dict : Dict[str,List[str]] = {{}}
            info: List[str] = {info}

            temp_dict['info'] = info
            temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{name}']
            all_inputs['{name}'] = temp_dict
            """.format(
            name=function_name,
            info=inputs_info,
        )

        # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided
        if function_name == 'forward':
            model.define(textwrap.dedent("""
                def get_all_bundled_inputs(self):
                    return self.get_all_bundled_inputs_for_forward()
                """))
            model.define(textwrap.dedent("""
                def get_num_bundled_inputs(self):
                    return len(self.get_all_bundled_inputs_for_forward())
                """))
            model.define(textwrap.dedent("""
                def run_on_bundled_input(self, idx: int):
                    return self(*self.get_all_bundled_inputs()[idx])
                """))

    # Define some high level helper methods that act on all bundled inputs
    model.define(textwrap.dedent("""
        def get_bundled_inputs_functions_and_info(self):
            all_inputs : Dict[str, Dict[str,List[str]]] = {{}}
            {template}
            return all_inputs
        """.format(template=get_bundled_inputs_functions_and_info_template)))
def augment_model_with_bundled_inputs(
        model: torch.jit.ScriptModule,
        inputs: Optional[List[Tuple[Any, ...]]] = None,
        _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
) -> None:
    """Add bundled sample inputs to a model.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`, but slightly
            easier to call from C++.

        `run_on_bundled_input(idx: int) -> Any`
            Run the model on bundled input number `idx`

    Inputs can be specified in one of two ways:

        - The model can define `_generate_bundled_inputs`.
          `get_all_bundled_inputs` will simply call this method and cache the value.

        - The `inputs` argument to this function can be a list of tuples,
          of the same form that will be returned by `get_all_bundled_inputs`.

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    forward_arg_types = [arg.type for arg in model.forward.schema.arguments[1:]]
    deflated_inputs_type = torch._C.ListType(torch._C.TupleType(forward_arg_types))
    inflated_inputs_type = torch._C.OptionalType(deflated_inputs_type)
    model._c._register_attribute("_bundled_inputs_deflated", deflated_inputs_type, [])
    model._c._register_attribute("_bundled_inputs_inflated", inflated_inputs_type, None)

    if hasattr(model, "_generate_bundled_inputs"):
        if inputs is not None:
            raise Exception(
                "inputs is not None, but _generate_bundled_inputs is already defined")
        # Model author already defined _generate_bundled_inputs.
    elif inputs is None:
        raise Exception(
            "inputs must be specified if _generate_bundled_inputs is not already defined")
    else:
        # Iterate over the inputs and args in each input.
        # Accumulate `deflated_inputs` as (possibly) compressed values
        # and `parts` to be joined into the expression that unpacks them.
        deflated_inputs = []
        parts = []
        for inp_idx, args in enumerate(inputs):
            deflated_args = []
            parts.append("(")
            for arg_idx, arg in enumerate(args):
                deflated, inflater = _inflate_expr(arg, f"deflated[{inp_idx}][{arg_idx}]")
                deflated_args.append(deflated)
                parts.append(f" {inflater},")
            deflated_inputs.append(tuple(deflated_args))
            parts.append("),")
        parts.append("")
        expr = "\n".join(parts)

        # Back-channel return this expr for debugging.
        if _receive_inflate_expr is not None:
            _receive_inflate_expr.append(expr)

        model._bundled_inputs_deflated = deflated_inputs
        definition = textwrap.dedent("""
            def _generate_bundled_inputs(self):
                deflated = self._bundled_inputs_deflated
                return [
            {}
                ]
            """).format(expr)
        model.define(definition)

    # Define get_all_bundled_inputs that caches the generated inputs.
    model.define(textwrap.dedent("""
        def get_all_bundled_inputs(self):
            if self._bundled_inputs_inflated is None:
                self._bundled_inputs_inflated = self._generate_bundled_inputs()
            all_inputs = self._bundled_inputs_inflated
            assert all_inputs is not None
            return all_inputs
        """))

    # Define some helper methods.
    model.define(textwrap.dedent("""
        def get_num_bundled_inputs(self):
            return len(self.get_all_bundled_inputs())
        """))
    model.define(textwrap.dedent("""
        def run_on_bundled_input(self, idx: int):
            return self(*self.get_all_bundled_inputs()[idx])
        """))
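# Short usage sketch of the forward-only variant above, assuming a PyTorch build
# where it is exposed as torch.utils.bundled_inputs.augment_model_with_bundled_inputs.
# The toy module and input shapes are made up for illustration.
import torch
import torch.utils.bundled_inputs

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

m = torch.jit.script(AddOne())
# Bundle two sample inputs; large constant tensors such as torch.zeros(1000)
# are stored in deflated form and re-inflated on demand.
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
    m, inputs=[(torch.zeros(1000),), (torch.ones(4),)])

print(m.get_num_bundled_inputs())  # 2
for inp in m.get_all_bundled_inputs():
    m(*inp)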
def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    model: torch.jit.ScriptModule,
    device: torch.device,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    model_manager: polygames.ModelManager,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:
    info = zutils.get_game_info(game_params)
    c, h, w = info["feature_size"][:3]
    rc, rh, rw = info["raw_feature_size"][:3]
    c_prime, h_prime, w_prime = info["action_size"][:3]
    predicts = (2 if game_params.predict_end_state else 0) + game_params.predict_n_states

    batchsizes = {
        "s": [c, h, w],
        "v": [3 if getattr(model, "logit_value", False) else 1],
        "pred_v": [1],
        "pi": [c_prime, h_prime, w_prime],
        "pi_mask": [c_prime, h_prime, w_prime],
    }
    if game_params.player == "forward":
        batchsizes["action_pi"] = [c_prime, h_prime, w_prime]
    if predicts > 0:
        batchsizes["predict_pi"] = [rc * predicts, rh, rw]
        batchsizes["predict_pi_mask"] = [rc * predicts, rh, rw]
    if getattr(model, "rnn_state_shape", None) is not None:
        batchsizes["rnn_state_mask"] = [1]
    if execution_params.rnn_seqlen > 0:
        for k, v in batchsizes.items():
            batchsizes[k] = [execution_params.rnn_seqlen, *v]
    if getattr(model, "rnn_state_shape", None) is not None:
        batchsizes["rnn_initial_state"] = model.rnn_state_shape

    rank = 0
    if ddpmodel:
        rank = torch.distributed.get_rank()

    executor = ThreadPoolExecutor(max_workers=1)
    savefuture = None
    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    max_time = execution_params.max_time
    init_epoch = epoch
    while max_time is None or time.time() < start_time + max_time:
        if epoch - init_epoch >= optim_params.num_epoch:
            break
        epoch += 1
        if rank == 0 and epoch % execution_params.saving_period == 0:
            model_manager.add_tournament_model("e%d" % (epoch), model.state_dict())
            savestart = time.time()
            if savefuture is not None:
                savefuture.result()
            savefuture = utils.save_checkpoint(
                command_history=command_history,
                epoch=epoch,
                model=model,
                optim=optim,
                game_params=game_params,
                model_params=model_params,
                optim_params=optim_params,
                simulation_params=simulation_params,
                execution_params=execution_params,
                executor=executor)
            print("checkpoint saved in %gs" % (time.time() - savestart))
        _train_epoch(
            model=model,
            device=device,
            ddpmodel=ddpmodel,
            batchsizes=batchsizes,
            optim=optim,
            model_manager=model_manager,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=simulation_params.sync_period,
        )
        # resource usage stats
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())
        # train result
        print(
            ">>>train: epoch: %d, %s"
            % (epoch, utils.Result(get_train_reward()).log()),
            flush=True,
        )

    if savefuture is not None:
        savefuture.result()
    # checkpoint last state
    utils.save_checkpoint(
        command_history=command_history,
        epoch=epoch,
        model=model,
        optim=optim,
        game_params=game_params,
        model_params=model_params,
        optim_params=optim_params,
        simulation_params=simulation_params,
        execution_params=execution_params,
    )
def _train_epoch(
    model: torch.jit.ScriptModule,
    device: torch.device,
    ddpmodel: ModelWrapperForDDP,
    batchsizes,
    optim: torch.optim.Optimizer,
    model_manager: polygames.ModelManager,
    stat: utils.MultiCounter,
    epoch: int,
    optim_params: OptimParams,
    sync_period: int,
) -> None:
    global _pre_num_add
    global _pre_num_sample
    global _running_add_rate
    global _running_sample_rate
    global _last_train_time
    global _remote_replay_buffer_inited

    if _pre_num_add is None:
        pre_num_add = model_manager.buffer_num_add()
        pre_num_sample = model_manager.buffer_num_sample()
    else:
        pre_num_add = _pre_num_add
        pre_num_sample = _pre_num_sample
    sync_s = 0.
    num_sync = 0

    train_start_time = time.time()

    if pre_num_sample > 0:
        print("sample/add ratio ", float(pre_num_sample) / pre_num_add)

    if _last_train_time == 0:
        _last_train_time = time.time()

    batchsize = optim_params.batchsize

    lossmodel = DDPWrapperForModel(ddpmodel) if ddpmodel is not None else model
    lossmodel.train()

    world_size = 0
    rank = 0
    if ddpmodel is not None:
        print("DDP is active")
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        print("World size %d, rank %d. Waiting for all processes" % (world_size, rank))
        torch.distributed.barrier()
        print("Synchronizing model")
        # Broadcast parameters and buffers from rank 0 so all workers start identical.
        for p in ddpmodel.parameters():
            torch.distributed.broadcast(p.data, 0)
        for p in ddpmodel.buffers():
            torch.distributed.broadcast(p.data, 0)
        print("Synchronized, start training")

    has_predict = False
    cpubatch = {}
    for k, v in batchsizes.items():
        sizes = v.copy()
        sizes.insert(0, batchsize)
        cpubatch[k] = torch.empty(sizes)
        if k == "predict_pi":
            has_predict = True

    for eid in range(optim_params.epoch_len):
        # Wait while the replay buffer is being consumed faster than it is filled.
        while _running_add_rate * 1.25 < _running_sample_rate:
            print("add rate insufficient, waiting")
            time.sleep(5)
            t = time.time()
            time_elapsed = t - _last_train_time
            _last_train_time = t
            alpha = pow(0.99, time_elapsed)
            post_num_add = model_manager.buffer_num_add()
            post_num_sample = model_manager.buffer_num_sample()
            delta_add = post_num_add - pre_num_add
            delta_sample = post_num_sample - pre_num_sample
            _running_add_rate = _running_add_rate * alpha + (delta_add / time_elapsed) * (1 - alpha)
            _running_sample_rate = _running_sample_rate * alpha + (delta_sample / time_elapsed) * (1 - alpha)
            pre_num_add = post_num_add
            pre_num_sample = post_num_sample
            print("running add rate: %.2f / s" % (_running_add_rate))
            print("running sample rate: %.2f / s" % (_running_sample_rate))
            print("current add rate: %.2f / s" % (delta_add / time_elapsed))
            print("current sample rate: %.2f / s" % (delta_sample / time_elapsed))

        if world_size > 0:
            # Rank 0 samples one batch per worker and scatters them to all ranks.
            batchlist = None
            if rank == 0:
                batchlist = {}
                for k in cpubatch.keys():
                    batchlist[k] = []
                for i in range(world_size):
                    for k, v in model_manager.sample(batchsize).items():
                        batchlist[k].append(v)
            for k, v in cpubatch.items():
                torch.distributed.scatter(v, batchlist[k] if rank == 0 else None)
            batch = utils.to_device(cpubatch, device)
        else:
            batch = model_manager.sample(batchsize)
            batch = utils.to_device(batch, device)

        for k, v in batch.items():
            batch[k] = v.detach()
        loss, v_err, pi_err, predict_err = model_loss.mcts_loss(model, lossmodel, batch)
        loss.backward()

        grad_norm = nn.utils.clip_grad_norm_(lossmodel.parameters(), optim_params.grad_clip)
        optim.step()
        optim.zero_grad()

        stat["v_err"].feed(v_err.item())
        stat["pi_err"].feed(pi_err.item())
        if has_predict:
            stat["predict_err"].feed(predict_err.item())
        stat["loss"].feed(loss.item())
        stat["grad_norm"].feed(grad_norm)

        if (epoch * optim_params.epoch_len + eid + 1) % sync_period == 0:
            sync_t0 = time.time()
            model_manager.update_model(model.state_dict())
            sync_s += time.time() - sync_t0
            num_sync += 1

    t = time.time()
    time_elapsed = t - _last_train_time
    _last_train_time = t
    alpha = pow(0.99, time_elapsed)
    post_num_add = model_manager.buffer_num_add()
    post_num_sample = model_manager.buffer_num_sample()
    delta_add = post_num_add - pre_num_add
    delta_sample = post_num_sample - pre_num_sample
    _running_add_rate = _running_add_rate * alpha + (delta_add / time_elapsed) * (1 - alpha)
    _running_sample_rate = _running_sample_rate * alpha + (delta_sample / time_elapsed) * (1 - alpha)
    pre_num_add = post_num_add
    pre_num_sample = post_num_sample

    total_time_elapsed = time.time() - train_start_time
    print("running add rate: %.2f / s" % (_running_add_rate))
    print("running sample rate: %.2f / s" % (_running_sample_rate))
    print("current add rate: %.2f / s" % (delta_add / time_elapsed))
    print("current sample rate: %.2f / s" % (delta_sample / time_elapsed))
    print(
        f"syncing duration: {sync_s:.2f}s for {num_sync} syncs "
        f"({int(100 * sync_s / total_time_elapsed)}% of train time)"
    )

    _pre_num_add = pre_num_add
    _pre_num_sample = pre_num_sample
    stat.summary(epoch)
    stat.reset()
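# Standalone sketch of the time-weighted exponential moving average used for the
# running add/sample rates above: the old estimate decays by a factor of 0.99 per
# elapsed second, so irregular gaps between updates are weighted consistently.
# The counter values and function name are made up for illustration.
import time

running_rate = 0.0
last_time = time.time()
prev_count = 0

def update_rate(current_count):
    global running_rate, last_time, prev_count
    now = time.time()
    elapsed = max(now - last_time, 1e-6)
    alpha = pow(0.99, elapsed)
    instantaneous = (current_count - prev_count) / elapsed
    running_rate = running_rate * alpha + instantaneous * (1 - alpha)
    last_time, prev_count = now, current_count
    return running_rate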
def _simplify_script_module(obj: torch.jit.ScriptModule) -> bytes:
    """Strategy to serialize a script module using Torch.jit"""
    return obj.save_to_buffer()
def train_model(
    command_history: utils.CommandHistory,
    start_time: float,
    train_device: torch.device,
    model: torch.jit.ScriptModule,
    model_path: Path,
    ddpmodel,
    optim: torch.optim.Optimizer,
    context: tube.Context,
    assembler: tube.ChannelAssembler,
    get_train_reward: Callable[[], List[int]],
    game_params: GameParams,
    model_params: ModelParams,
    optim_params: OptimParams,
    simulation_params: SimulationParams,
    execution_params: ExecutionParams,
    epoch: int = 0,
) -> None:
    stat = utils.MultiCounter(execution_params.checkpoint_dir)
    max_time = execution_params.max_time
    init_epoch = epoch
    while max_time is None or time.time() < start_time + max_time:
        if epoch - init_epoch >= optim_params.num_epoch:
            break
        epoch += 1
        if not (epoch - init_epoch) % execution_params.saving_period:
            assembler.add_tournament_model("e%d" % (epoch), model.state_dict())
            utils.save_checkpoint(
                command_history=command_history,
                epoch=epoch,
                model=model,
                optim=optim,
                assembler=assembler,
                game_params=game_params,
                model_params=model_params,
                optim_params=optim_params,
                simulation_params=simulation_params,
                execution_params=execution_params,
            )
        _train_epoch(
            train_device=train_device,
            model=model,
            ddpmodel=ddpmodel,
            model_path=model_path,
            optim=optim,
            assembler=assembler,
            stat=stat,
            epoch=epoch,
            optim_params=optim_params,
            sync_period=simulation_params.sync_period,
        )
        # resource usage stats
        print("Resource usage:")
        print(utils.get_res_usage_str())
        print("Context stats:")
        print(context.get_stats_str())
        # train result
        print(
            ">>>train: epoch: %d, %s"
            % (epoch, utils.Result(get_train_reward()).log()),
            flush=True,
        )
    # checkpoint last state
    utils.save_checkpoint(
        command_history=command_history,
        epoch=epoch,
        model=model,
        optim=optim,
        assembler=assembler,
        game_params=game_params,
        model_params=model_params,
        optim_params=optim_params,
        simulation_params=simulation_params,
        execution_params=execution_params,
    )