def test_model_runs(self):
    """Smoke-test ContactTracingTransformer across several constructor configurations.

    One batch is loaded from the dataset at ``self.DATASET_PATH``. That batch is
    pushed through each model variant, and the leading dimension of the
    ``latent_variable`` and ``encounter_variables`` outputs is checked against
    the requested batch size.
    """
    from loader import ContactDataset
    from torch.utils.data import DataLoader
    from models import ContactTracingTransformer
    from addict import Dict

    batch_size = 5
    dataset = ContactDataset(self.DATASET_PATH)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=ContactDataset.collate_fn
    )
    batch = next(iter(dataloader))

    # Each entry is a set of constructor kwargs; {} exercises the defaults.
    model_configs = [
        {},
        {"use_encounter_partner_id_embedding": False},
        {"use_learned_time_embedding": True},
        {"encounter_duration_embedding_mode": "sines"},
    ]
    for model_kwargs in model_configs:
        # subTest reports every failing configuration instead of stopping at
        # the first one, and labels the failure with the kwargs used.
        with self.subTest(**model_kwargs):
            ctt = ContactTracingTransformer(**model_kwargs)
            output = Dict(ctt(batch))
            self.assertEqual(output.latent_variable.shape[0], batch_size)
            self.assertEqual(output.encounter_variables.shape[0], batch_size)
def test_model_padding(self):
    """Check that zero-padding the encounter dimension does not change model outputs.

    Every batch entry whose key starts with ``"encounter"`` is padded with
    ``pad_size`` extra zero entries along dim 1, and so is ``"mask"``. The model
    is then run, in eval mode and without gradients, on both the original and
    the padded batch. Two checks follow. The encounter outputs, with the padded
    slots dropped, must match the unpadded run. The latent outputs must also
    match.
    """
    import torch
    from loader import ContactDataset
    from torch.utils.data import DataLoader
    from models import ContactTracingTransformer

    torch.random.manual_seed(42)
    batch_size = 5
    dataset = ContactDataset(self.DATASET_PATH)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=ContactDataset.collate_fn
    )
    batch = next(iter(dataloader))

    # Padding test -- pad everything that has to do with encounters, and check
    # whether it changes results.
    pad_size = 1

    def pad(tensor):
        """Append ``pad_size`` zero entries to ``tensor`` along dim 1.

        ``new_zeros`` inherits both dtype and device from ``tensor``, so this
        also works for GPU tensors. It works for any tensor with two or more
        dimensions.
        """
        zeros = tensor.new_zeros((tensor.shape[0], pad_size, *tensor.shape[2:]))
        return torch.cat([tensor, zeros], dim=1)

    padded_batch = {
        key: (pad(tensor) if key.startswith("encounter") else tensor)
        for key, tensor in batch.items()
    }
    # "mask" does not start with "encounter", so it is padded explicitly. The
    # zero padding presumably marks the appended slots as invalid — TODO confirm.
    padded_batch["mask"] = pad(padded_batch["mask"])

    # Make the model and set it to eval, so training-mode behaviour (e.g.
    # dropout, if any) cannot make the two passes differ.
    # noinspection PyUnresolvedReferences
    ctt = ContactTracingTransformer(num_sabs=1).eval()
    with torch.no_grad(), ctt.diagnose():
        output = ctt(batch)
        padded_output = ctt(padded_batch)

    # Drop the padded encounter slots before comparing with the unpadded run.
    expected_encounter = output["encounter_variables"][..., 0]
    actual_encounter = padded_output["encounter_variables"][..., :-pad_size, 0]
    expected_latent = output["latent_variable"]
    actual_latent = padded_output["latent_variable"]

    self.assertSequenceEqual(expected_encounter.shape, actual_encounter.shape)
    self.assertSequenceEqual(actual_latent.shape, expected_latent.shape)
    # assertTrue replaces assert_, which was removed in Python 3.12.
    self.assertTrue(torch.allclose(expected_encounter, actual_encounter))
    self.assertTrue(torch.allclose(expected_latent, actual_latent))
def test_losses(self):
    """Check that each loss can be evaluated on a real batch and model output.

    One batch from ``self.DATASET_PATH`` is run through a
    ContactTracingTransformer without latent pooling or logit sink. The output
    is then passed to several loss configurations. The test only requires each
    loss call to succeed; the loss values are not inspected.
    """
    from loader import ContactDataset
    from torch.utils.data import DataLoader
    from models import ContactTracingTransformer
    from losses import ContagionLoss, InfectiousnessLoss
    from addict import Dict

    batch_size = 5
    dataset = ContactDataset(self.DATASET_PATH)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=ContactDataset.collate_fn
    )
    batch = next(iter(dataloader))
    ctt = ContactTracingTransformer(pool_latent_entities=False, use_logit_sink=False)
    output = Dict(ctt(batch))

    # (loss class, constructor kwargs) pairs; each is constructed and called in order.
    loss_configs = [
        (ContagionLoss, {"allow_multiple_exposures": True}),
        (ContagionLoss, {"allow_multiple_exposures": False}),
        (InfectiousnessLoss, {}),
    ]
    for loss_cls, loss_kwargs in loss_configs:
        # subTest labels a failure with the loss class and its kwargs.
        with self.subTest(loss=loss_cls.__name__, **loss_kwargs):
            loss_fn = loss_cls(**loss_kwargs)
            loss_fn(batch, output)
import torch

# Run the exported ONNX model through the TensorFlow backend, then run the
# original PyTorch model on the same batch for comparison.
# NOTE(review): `onnx`, `prepare`, `get_dataloader` and
# `ContactTracingTransformer` are not imported in this snippet; presumably they
# are imported elsewhere (e.g. `from onnx_tf.backend import prepare`) — confirm.
path = "output.pkl"
dataloader = get_dataloader(batch_size=1, shuffle=False, num_workers=0, path=path)
batch = next(iter(dataloader))

# Load the ONNX model and convert it to a TensorFlow representation.
onnx_model = onnx.load("model_onnx_10.onnx")
tf_model = prepare(onnx_model)

# Input and output nodes of the converted model.
print('inputs:', tf_model.inputs)
print('outputs:', tf_model.outputs)

tf_output = tf_model.run(batch)
print(tf_output)
tf_model.export_graph('tf_graph2.pb')

# Sanity check with the PyTorch model.
# NOTE(review): these constructor kwargs must describe the same architecture the
# ONNX graph was exported from — confirm they match the export configuration.
ctt = ContactTracingTransformer(pool_latent_entities=False, use_logit_sink=False)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine.
ctt.load_state_dict(torch.load('model.pth', map_location="cpu"))
ctt.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    torch_output = ctt(batch)
print(torch_output)
def _build_model(self):
    """Create the ContactTracingTransformer and place it on ``self.device``.

    Constructor keyword arguments come from the ``"model/kwargs"`` config entry
    (an empty mapping when that entry is absent). The device-placed model is
    stored on ``self.model``.
    """
    model_kwargs = self.get("model/kwargs", {})
    transformer = ContactTracingTransformer(**model_kwargs)
    self.model: nn.Module = to_device(transformer, self.device)
def _build(self):
    """Set up the preprocessor and model, then call ``self.load()``.

    ``relative_days`` for the ContactPreprocessor is read from
    ``"data/loader_kwargs/relative_days"`` (default True). The transformer's
    keyword arguments are read from ``"model/kwargs"`` (default: no kwargs).
    """
    use_relative_days = self.get("data/loader_kwargs/relative_days", True)
    self.preprocessor = ContactPreprocessor(relative_days=use_relative_days)

    model_kwargs = self.get("model/kwargs", {})
    self.model: torch.nn.Module = ContactTracingTransformer(**model_kwargs)

    self.load()
# Export a trained ContactTracingTransformer checkpoint to ONNX (opset 10),
# tracing it with one batch from the data loader.
import torch
import torch.onnx

from loader import get_dataloader
from models import ContactTracingTransformer

model = ContactTracingTransformer()
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine. The model above is constructed on the CPU anyway.
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
model.eval()

path = "output.pkl"
dataloader = get_dataloader(batch_size=1, shuffle=False, num_workers=0, path=path)
batch = next(iter(dataloader))

# Name the graph inputs after the batch keys, in the batch's iteration order.
input_names = list(batch)

# NOTE(review): ONNX output names are assigned positionally to the model's
# flattened outputs. Other code reads the model output under the key
# "encounter_variables" (plural). Confirm the output order, and whether the
# singular "encounter_variable" is intended, before renaming; consumers of the
# exported graph may rely on these names.
torch.onnx.export(
    model,
    batch,
    "model_onnx_10.onnx",
    export_params=True,
    opset_version=10,
    do_constant_folding=True,
    input_names=input_names,
    output_names=['latent_variable', 'encounter_variable'],
)