示例#1
0
    def test_model_runs(self):
        """Smoke-test the ContactTracingTransformer forward pass under several
        constructor configurations: every output must keep the batch dimension."""
        from loader import ContactDataset
        from torch.utils.data import DataLoader
        from models import ContactTracingTransformer
        from addict import Dict

        batch_size = 5
        dataset = ContactDataset(self.DATASET_PATH)
        loader = DataLoader(
            dataset, batch_size=batch_size, collate_fn=ContactDataset.collate_fn
        )
        batch = next(iter(loader))

        def assert_batch_dim(model):
            # Leading dimension of both heads must match the batch size.
            result = Dict(model(batch))
            self.assertEqual(result.latent_variable.shape[0], batch_size)
            self.assertEqual(result.encounter_variables.shape[0], batch_size)

        # Same four configurations as before, exercised in order.
        for kwargs in (
            {},
            {"use_encounter_partner_id_embedding": False},
            {"use_learned_time_embedding": True},
            {"encounter_duration_embedding_mode": "sines"},
        ):
            assert_batch_dim(ContactTracingTransformer(**kwargs))
示例#2
0
    def test_model_padding(self):
        """Check that zero-padding the encounter tensors (and the mask) along the
        encounter axis leaves the model's outputs on the un-padded positions
        unchanged."""
        import torch
        from loader import ContactDataset
        from torch.utils.data import DataLoader
        from models import ContactTracingTransformer

        torch.random.manual_seed(42)

        batch_size = 5
        path = self.DATASET_PATH
        dataset = ContactDataset(path)
        dataloader = DataLoader(
            dataset, batch_size=batch_size, collate_fn=ContactDataset.collate_fn
        )
        batch = next(iter(dataloader))

        # Padding test -- pad everything that has to do with encounters, and check
        # whether it changes results
        pad_size = 1

        def pad(tensor):
            # Append `pad_size` rows of zeros along dim=1 (the encounter axis),
            # handling both 3-D (B, E, F) and 2-D (B, E) tensors.
            if tensor.dim() == 3:
                zeros = torch.zeros(
                    (tensor.shape[0], pad_size, tensor.shape[2]), dtype=tensor.dtype
                )
            else:
                zeros = torch.zeros((tensor.shape[0], pad_size), dtype=tensor.dtype)
            return torch.cat([tensor, zeros], dim=1)

        padded_batch = {
            key: (pad(tensor) if key.startswith("encounter") else tensor)
            for key, tensor in batch.items()
        }
        # Pad the mask too, so the model can ignore the appended zero entries.
        padded_batch["mask"] = pad(padded_batch["mask"])

        # Make the model and set it to eval (deterministic forward pass).
        # noinspection PyUnresolvedReferences
        ctt = ContactTracingTransformer(num_sabs=1).eval()
        with torch.no_grad(), ctt.diagnose():
            output = ctt(batch)
            padded_output = ctt(padded_batch)

        # soll/ist = expected/actual; drop the padded tail before comparing.
        encounter_soll_wert = output["encounter_variables"][..., 0]
        encounter_ist_wert = padded_output["encounter_variables"][..., :-pad_size, 0]
        latent_soll_wert = output["latent_variable"]
        latent_ist_wert = padded_output["latent_variable"]
        self.assertSequenceEqual(encounter_soll_wert.shape, encounter_ist_wert.shape)
        self.assertSequenceEqual(latent_ist_wert.shape, latent_soll_wert.shape)
        # FIX: `assert_` is a deprecated unittest alias (removed in Python 3.12);
        # use assertTrue instead.
        self.assertTrue(torch.allclose(encounter_soll_wert, encounter_ist_wert))
        self.assertTrue(torch.allclose(latent_soll_wert, latent_ist_wert))
示例#3
0
    def test_losses(self):
        """Smoke-test every loss head against one model forward pass: each loss
        call merely has to complete without raising."""
        from loader import ContactDataset
        from torch.utils.data import DataLoader
        from models import ContactTracingTransformer
        from losses import ContagionLoss, InfectiousnessLoss
        from addict import Dict

        batch_size = 5
        dataset = ContactDataset(self.DATASET_PATH)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=ContactDataset.collate_fn,
        )
        batch = next(iter(dataloader))

        model = ContactTracingTransformer(
            pool_latent_entities=False,
            use_logit_sink=False,
        )
        output = Dict(model(batch))

        # Same three loss configurations as before, evaluated in order.
        for loss_fn in (
            ContagionLoss(allow_multiple_exposures=True),
            ContagionLoss(allow_multiple_exposures=False),
            InfectiousnessLoss(),
        ):
            loss = loss_fn(batch, output)
示例#4
0
# Run the exported ONNX model through a TensorFlow backend and sanity-check it
# against the original PyTorch model on the same batch.
# NOTE(review): `get_dataloader`, `onnx`, `prepare` (presumably onnx_tf.backend)
# and `ContactTracingTransformer` are not imported in this snippet -- they are
# presumably imported earlier in the full script; verify before running standalone.
import torch
path = "output.pkl"
dataloader = get_dataloader(batch_size=1,
                            shuffle=False,
                            num_workers=0,
                            path=path)
batch = next(iter(dataloader))

# Load the previously exported ONNX model (opset 10, see the export script).
onnx_model = onnx.load("model_onnx_10.onnx")
tf_model = prepare(onnx_model)
# Input nodes of the converted model
print('inputs:', tf_model.inputs)
# Output nodes from the model
print('outputs:', tf_model.outputs)

# All nodes in the model
# print('tensor_dict:')
# print(tf_model.tensor_dict)
output = tf_model.run(batch)
print(output)
tf_model.export_graph('tf_graph2.pb')

# Sanity check with the PyTorch model: same batch through the original network.
ctt = ContactTracingTransformer(pool_latent_entities=False,
                                use_logit_sink=False)
ctt.load_state_dict(torch.load('model.pth'))
ctt.eval()
output = ctt(batch)
print(output)
 def _build_model(self):
     """Instantiate the model from the config's model kwargs and move it to
     the configured device.

     NOTE(review): `to_device`, `nn`, and `ContactTracingTransformer` are
     resolved by the surrounding module, and `self.get`/`self.device` come
     from the enclosing class -- not visible in this snippet.
     """
     self.model: nn.Module = to_device(
         ContactTracingTransformer(**self.get("model/kwargs", {})),
         self.device)
示例#6
0
 def _build(self):
     """Build the preprocessor and model from config, then load saved state.

     NOTE(review): `ContactPreprocessor`, `ContactTracingTransformer`, and
     `self.get`/`self.load` are defined outside this snippet.
     """
     # relative_days falls back to True when the config key is absent.
     self.preprocessor = ContactPreprocessor(
         relative_days=self.get("data/loader_kwargs/relative_days", True))
     self.model: torch.nn.Module = ContactTracingTransformer(
         **self.get("model/kwargs", {}))
     # presumably restores trained weights from a checkpoint -- defined elsewhere.
     self.load()
示例#7
0
# Export a trained ContactTracingTransformer to ONNX (opset 10), using one
# real batch from the dataloader as the tracing input.
import torch
from models import ContactTracingTransformer
import torch.onnx
from loader import get_dataloader

model = ContactTracingTransformer()
model.load_state_dict(torch.load("model.pth"))
model.eval()

path = "output.pkl"
dataloader = get_dataloader(batch_size=1,
                            shuffle=False,
                            num_workers=0,
                            path=path)
batch = next(iter(dataloader))

# ONNX input names mirror the batch's keys, in the mapping's iteration order.
# FIX: replaced the manual append loop with the idiomatic list() conversion.
input_names = list(batch)
torch.onnx.export(model,
                  batch,
                  "model_onnx_10.onnx",
                  export_params=True,
                  opset_version=10,
                  do_constant_folding=True,
                  input_names=input_names,
                  output_names=['latent_variable', 'encounter_variable'])