def _test_serialization(self, module, inputs):
    """Assert that ``module`` gives the same result after a save/load round trip.

    The module is written to a temporary file with ``torch.jit.save`` and read
    back with ``torch.jit.load``. Both the original and the reloaded module are
    then called with ``*inputs`` and their outputs are compared after
    ``to_dense()``.

    Args:
        module: Object accepted by ``torch.jit.save``; it must be callable
            with ``*inputs``.
        inputs: Iterable of positional arguments forwarded to both modules.
    """
    with TemporaryFileName() as path:
        torch.jit.save(module, path)
        restored = torch.jit.load(path)

        # Evaluate the original first, then the reloaded copy.
        # NOTE(review): the outputs are presumably sparse tensors, hence the
        # to_dense() before comparison — confirm against callers.
        expected = module(*inputs).to_dense()
        actual = restored(*inputs).to_dense()
        self.assertEqual(expected, actual)
def getExportImportCopy(self, m, also_test_file=True, map_location=None):
    """Return a copy of ``m`` produced by saving it and loading it back.

    ``m`` is first round-tripped through an in-memory ``io.BytesIO`` buffer.
    When ``also_test_file`` is true, that copy is additionally round-tripped
    through a temporary file and the file-loaded copy is returned instead.

    Args:
        m: Object accepted by ``torch.jit.save``.
        also_test_file: If false, return the copy loaded from the in-memory
            buffer without touching the filesystem.
        map_location: Forwarded unchanged to every ``torch.jit.load`` call.

    Returns:
        The object returned by the last ``torch.jit.load`` that was executed.
    """
    stream = io.BytesIO()
    torch.jit.save(m, stream)
    # Rewind so the load below starts reading at the beginning of the data.
    stream.seek(0)
    from_memory = torch.jit.load(stream, map_location=map_location)

    if also_test_file:
        with TemporaryFileName() as path:
            torch.jit.save(from_memory, path)
            return torch.jit.load(path, map_location=map_location)

    return from_memory
def compare_enabled_disabled(self, src):
    """
    Runs the script in `src` with PYTORCH_JIT enabled and disabled and
    compares their stdout for equality.

    Args:
        src: Python source text of the script to execute.

    Raises:
        subprocess.CalledProcessError: if either run of the script exits
            with a non-zero status.
    """
    # Write `src` out to a temporary file so our source inspection logic
    # works correctly.
    with TemporaryFileName() as fname:
        # Close the handle before launching any interpreter: text written
        # through a buffered file object is only guaranteed to be on disk
        # once it is flushed or closed. Running the script while the handle
        # is still open could execute an empty file and make the comparison
        # below pass vacuously.
        with open(fname, 'w') as f:
            f.write(src)

        # First run: inside the `_jit_disabled()` context.
        with _jit_disabled():
            out_disabled = subprocess.check_output([sys.executable, fname])

        # Second run: outside that context.
        out_enabled = subprocess.check_output([sys.executable, fname])

        self.assertEqual(out_disabled, out_enabled)
def getExportImportCopy(self, m, also_test_file=True, map_location=None):
    """Return a copy of ``m`` produced by exporting it and importing it back.

    A ``torch._C.Function`` is printed with ``_jit_python_print`` and
    re-imported into a fresh ``torch.jit.CompilationUnit``; the attribute
    named ``m.name`` on that unit is returned.

    Anything else is round-tripped through an in-memory ``io.BytesIO``
    buffer. When ``also_test_file`` is true, that copy is additionally saved
    to a temporary file and the file-loaded copy is returned instead.

    Args:
        m: A ``torch._C.Function`` or an object accepted by ``torch.jit.save``.
        also_test_file: If false, skip the temporary-file round trip. Not
            consulted on the ``torch._C.Function`` path.
        map_location: Forwarded unchanged to every ``torch.jit.load`` call.
            Not consulted on the ``torch._C.Function`` path.

    Returns:
        The re-imported function, or the object returned by the last
        ``torch.jit.load`` that was executed.
    """
    if isinstance(m, torch._C.Function):
        printed_src, printed_constants = _jit_python_print(m)
        unit = torch.jit.CompilationUnit()._import(printed_src, printed_constants)
        return getattr(unit, m.name)

    stream = io.BytesIO()
    torch.jit.save(m, stream)
    # Rewind so the load below starts reading at the beginning of the data.
    stream.seek(0)
    from_memory = torch.jit.load(stream, map_location=map_location)

    if also_test_file:
        with TemporaryFileName() as path:
            from_memory.save(path)
            return torch.jit.load(path, map_location=map_location)

    return from_memory
def test_save_load_with_extra_files(self):
    """Check that ``_extra_files`` survive saving and loading a script module.

    The entry ``'foo' -> 'bar'`` is written alongside the module through four
    paths (method ``save`` to a file, ``torch.jit.save`` to a file,
    ``save_to_buffer``, and ``torch.jit.save`` to a ``BytesIO``). After each
    one the module is loaded with a map whose ``'foo'`` value is empty, and
    the test asserts the load filled it in with ``'bar'``. Finally, a load
    that also requests the key ``'bar'`` is expected to raise ``RuntimeError``.
    """
    class MyMod(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            return a

    def reload_and_check(source, files):
        # Blank the entry first so a passing assertion proves that the load
        # itself populated it.
        files['foo'] = ''
        torch.jit.load(source, _extra_files=files)
        self.assertEqual('bar', files['foo'])

    files_to_write = torch._C.ExtraFilesMap()
    files_to_write['foo'] = 'bar'
    module = MyMod()

    # Save to file.
    with TemporaryFileName() as path:
        module.save(path, _extra_files=files_to_write)
        files_read = torch._C.ExtraFilesMap()
        reload_and_check(path, files_read)

        # Use torch.jit API (same path, same map object as above).
        torch.jit.save(module, path, _extra_files=files_to_write)
        reload_and_check(path, files_read)

    # Save to buffer.
    stream = io.BytesIO(module.save_to_buffer(_extra_files=files_to_write))
    files_read = torch._C.ExtraFilesMap()
    reload_and_check(stream, files_read)

    # Use torch.jit API
    stream = io.BytesIO()
    torch.jit.save(module, stream, _extra_files=files_to_write)
    stream.seek(0)
    files_read = torch._C.ExtraFilesMap()
    reload_and_check(stream, files_read)

    # Non-existent file 'bar'
    # NOTE(review): `stream` is not rewound after the load just above, so
    # confirm the RuntimeError here comes from the missing 'bar' entry rather
    # than from reading an already-consumed buffer.
    with self.assertRaises(RuntimeError):
        files_read['bar'] = ''
        torch.jit.load(stream, _extra_files=files_read)