def bufferize(worker: AbstractWorker, script_module: torch.jit.ScriptFunction) -> ScriptFunctionPB:
    """
    This method serializes a torch.jit.ScriptFunction into a ScriptFunctionPB.

    Args:
        worker (AbstractWorker): the worker on which the serialization is performed.
        script_module (torch.jit.ScriptFunction): input torch.jit.ScriptFunction
            to be serialized.

    Returns:
        protobuf_script (ScriptFunctionPB): serialized torch.jit.ScriptFunction.
    """
    protobuf_script = ScriptFunctionPB()
    protobuf_script.obj = script_module.save_to_buffer()
    return protobuf_script
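# For completeness, a minimal sketch of the inverse operation: turning the raw bytes
# stored in ScriptFunctionPB.obj back into a callable TorchScript object. The helper
# name is hypothetical; it only relies on torch.jit.load accepting a file-like object.
def _unbufferize_script_function(worker: AbstractWorker, protobuf_script: ScriptFunctionPB):
    import io

    # Rebuild the TorchScript function from the bytes produced by save_to_buffer().
    return torch.jit.load(io.BytesIO(protobuf_script.obj))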
def accelerator_transformerLayers_inputs(
    model: nn.Module,
    trace: torch.jit.ScriptFunction,
    export_options: ExportConfig,
    dataset_iterable: Iterable,
    module_path,
):
    import torch_glow

    # we use the padding control from the Export Config:
    if export_options is None:
        export_options = ExportConfig()

    if export_options.seq_padding_control is None:
        raise RuntimeError("seq padding control not specified")
    if export_options.batch_padding_control is None:
        raise RuntimeError("batch padding control not specified")

    batch_padding_control = export_options.batch_padding_control

    # Restrict seq_padding_control to valid ranges
    seq_padding_control = []
    max_seq_len = trace.get_max_seq_len()
    for pad in export_options.seq_padding_control:
        if pad < max_seq_len:
            seq_padding_control.append(pad)
    seq_padding_control.append(max_seq_len)

    # embedding_dim is resolved via module_path rather than the previously
    # hardcoded model.encoder.encoder.transformer.token_embedding.embedding_dim
    embedding_dim = accelerator.get_embedding_module_from_path(model, module_path)

    input_examples = []
    for seq_len in seq_padding_control:
        if seq_len <= 0:
            continue
        for batch_size in batch_padding_control:
            if batch_size <= 0:
                continue
            # TODO: data inputs are generated directly instead of drawing from
            # dataset_iterable; enhance later
            input1 = torch.randn([seq_len, batch_size, embedding_dim], dtype=torch.float32)
            input2 = torch.randn([batch_size, seq_len]).bool()
            input_specs = torch_glow.input_specs_from_tensors([input1, input2])
            input_examples.append(input_specs)

    return input_examples
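# A hedged usage sketch. It assumes ExportConfig exposes seq_padding_control and
# batch_padding_control as settable fields and that `trace` provides get_max_seq_len();
# the padding values and the helper name are hypothetical.
def _example_accelerator_inputs(model: nn.Module, trace, module_path):
    export_options = ExportConfig()
    export_options.seq_padding_control = [32, 64]     # pads below max_seq_len are kept
    export_options.batch_padding_control = [1, 8, 16]
    # One input-spec set is produced per positive (seq_len, batch_size) pair:
    # the kept seq pads plus max_seq_len, crossed with the three batch sizes.
    return accelerator_transformerLayers_inputs(
        model, trace, export_options, dataset_iterable=None, module_path=module_path
    )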
def get_seq_and_batch_padding_control(
    trace: torch.jit.ScriptFunction, export_options: ExportConfig
):
    # we use the padding control from the Export Config:
    if export_options is None:
        export_options = ExportConfig()

    if export_options.seq_padding_control is None:
        raise RuntimeError("seq padding control not specified")
    if export_options.batch_padding_control is None:
        raise RuntimeError("batch padding control not specified")

    batch_padding_control = export_options.batch_padding_control

    # Restrict seq_padding_control to valid ranges
    seq_padding_control = []
    max_seq_len = trace.get_max_seq_len()
    for pad in export_options.seq_padding_control:
        if pad < max_seq_len:
            seq_padding_control.append(pad)
    seq_padding_control.append(max_seq_len)

    return seq_padding_control, batch_padding_control
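# Note: the padding-control block at the top of accelerator_transformerLayers_inputs
# duplicates this helper; a possible (hypothetical) refactor would be to delegate to it:
#
#     seq_padding_control, batch_padding_control = get_seq_and_batch_padding_control(
#         trace, export_options
#     )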
def _bufferize_script_function(
    worker: AbstractWorker, script_module: torch.jit.ScriptFunction
) -> ScriptFunctionPB:
    """Serialize a torch.jit.ScriptFunction into a ScriptFunctionPB protobuf message."""
    protobuf_script = ScriptFunctionPB()
    protobuf_script.obj = script_module.save_to_buffer()
    return protobuf_script
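# An end-to-end round-trip sketch. It assumes ScriptFunctionPB is a standard protobuf
# message (so SerializeToString / ParseFromString are available) and that torch.jit.load
# accepts a file-like object; the helper name is hypothetical and `worker` is unused here.
def _example_script_function_roundtrip(worker: AbstractWorker):
    import io

    @torch.jit.script
    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y

    # Serialize the scripted function and encode the protobuf message for the wire.
    wire_bytes = _bufferize_script_function(worker, add).SerializeToString()

    # Decode the message and restore a callable TorchScript object from its bytes.
    restored_pb = ScriptFunctionPB()
    restored_pb.ParseFromString(wire_bytes)
    restored_fn = torch.jit.load(io.BytesIO(restored_pb.obj))
    return restored_fn(torch.ones(2), torch.ones(2))  # tensor([2., 2.])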