Example #1
0
    def test_model_runs(self):
        """Smoke-test ContactTracingTransformer forward passes on one real batch.

        Loads a single batch of ``batch_size`` samples from ``self.DATASET_PATH``
        and runs it through the model under several constructor configurations.
        For each one, checks that the leading dimension of both
        ``latent_variable`` and ``encounter_variables`` equals ``batch_size``.
        """
        from ctt.data_loading.loader import ContactDataset
        from torch.utils.data import DataLoader
        from ctt.models.transformer import ContactTracingTransformer
        from addict import Dict

        batch_size = 5
        dataset = ContactDataset(self.DATASET_PATH)
        dataloader = DataLoader(
            dataset, batch_size=batch_size, collate_fn=ContactDataset.collate_fn
        )
        batch = next(iter(dataloader))

        # Constructor kwargs to exercise; the empty dict is the default model.
        model_configs = [
            {},
            {"use_encounter_partner_id_embedding": False},
            {"use_learned_time_embedding": True},
            {"encounter_duration_embedding_mode": "sines"},
        ]
        for model_kwargs in model_configs:
            # subTest labels a failure with the offending kwargs and keeps
            # running the remaining configurations.
            with self.subTest(**model_kwargs):
                ctt = ContactTracingTransformer(**model_kwargs)
                output = Dict(ctt(batch))
                self.assertEqual(output.latent_variable.shape[0], batch_size)
                self.assertEqual(output.encounter_variables.shape[0], batch_size)
Example #2
0
    def test_model_padding(self):
        """Check that zero-padding the encounter axis leaves model outputs unchanged.

        Appends ``pad_size`` zero entries along dim 1 of every batch tensor whose
        key starts with ``"encounter"``, and of ``"mask"``. It then runs the
        model in eval mode (no grad) on both the original and the padded batch.
        Finally it asserts that ``latent_variable`` and the un-padded part of
        ``encounter_variables`` agree (``torch.allclose``).
        """
        import torch
        from ctt.data_loading.loader import ContactDataset
        from torch.utils.data import DataLoader
        from ctt.models.transformer import ContactTracingTransformer

        torch.random.manual_seed(42)

        batch_size = 5
        dataset = ContactDataset(self.DATASET_PATH)
        dataloader = DataLoader(
            dataset, batch_size=batch_size, collate_fn=ContactDataset.collate_fn
        )
        batch = next(iter(dataloader))

        # Padding test -- pad everything that has to do with encounters, and check
        # whether it changes results
        pad_size = 1

        def pad(tensor):
            # Append `pad_size` zeros along dim 1, preserving all other dims.
            # Valid for any tensor with at least two dimensions.
            pad_shape = (tensor.shape[0], pad_size, *tensor.shape[2:])
            zeros = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
            return torch.cat([tensor, zeros], dim=1)

        padded_batch = {
            key: (pad(tensor) if key.startswith("encounter") else tensor)
            for key, tensor in batch.items()
        }
        # The mask is padded with zeros too; presumably 0 marks the new
        # positions as invalid so the model ignores them -- TODO confirm.
        padded_batch["mask"] = pad(padded_batch["mask"])

        # Make the model and set it to eval
        # noinspection PyUnresolvedReferences
        ctt = ContactTracingTransformer(num_sabs=1).eval()
        with torch.no_grad(), ctt.diagnose():
            output = ctt(batch)
            padded_output = ctt(padded_batch)

        expected_encounter = output["encounter_variables"][..., 0]
        # Drop the padded trailing positions before comparing.
        actual_encounter = padded_output["encounter_variables"][..., :-pad_size, 0]
        expected_latent = output["latent_variable"]
        actual_latent = padded_output["latent_variable"]
        self.assertSequenceEqual(expected_encounter.shape, actual_encounter.shape)
        self.assertSequenceEqual(expected_latent.shape, actual_latent.shape)
        # assertTrue rather than the `assert_` alias, which Python 3.12 removed.
        self.assertTrue(
            torch.allclose(expected_encounter, actual_encounter),
            "padding changed encounter_variables",
        )
        self.assertTrue(
            torch.allclose(expected_latent, actual_latent),
            "padding changed latent_variable",
        )
Example #3
0
    def test_tflite_model_conversion(self):
        """Check that converting a fresh model to TFLite stays numerically close.

        For each fixed message count, passes a freshly constructed
        ``ContactTracingTransformer`` (eval mode) to
        ``convert_pytorch_model_fixed_messages`` and requires the returned value
        (presumably the max abs output difference -- TODO confirm) to be below
        ``0.005``.
        """
        from ctt.models.transformer import ContactTracingTransformer
        from ctt.conversion.export_to_tflite import convert_pytorch_model_fixed_messages

        # Instantiate new model
        model = ContactTracingTransformer()
        model.eval()

        # Test the conversion to TFLite
        for nb_messages in [10, 50, 100]:
            # subTest reports which message count failed and still runs the
            # remaining sizes after a failure.
            with self.subTest(nb_messages=nb_messages):
                max_diff = convert_pytorch_model_fixed_messages(
                    model, nb_messages, working_directory="./tmp/test_dir/",
                    dataset_path=self.DATASET_PATH)
                self.assertLess(max_diff, 0.005)
Example #4
0
    def test_losses(self):
        """Check that each loss can be evaluated on a real batch and is finite.

        Runs one batch from ``self.DATASET_PATH`` through a default
        ``ContactTracingTransformer``. It then evaluates ``ContagionLoss`` (with
        and without multiple exposures) and ``InfectiousnessLoss`` on it.
        """
        import torch
        from ctt.data_loading.loader import ContactDataset
        from torch.utils.data import DataLoader
        from ctt.models.transformer import ContactTracingTransformer
        from ctt.losses import ContagionLoss, InfectiousnessLoss
        from addict import Dict

        batch_size = 5
        dataset = ContactDataset(self.DATASET_PATH)
        dataloader = DataLoader(
            dataset, batch_size=batch_size, collate_fn=ContactDataset.collate_fn
        )
        batch = next(iter(dataloader))

        ctt = ContactTracingTransformer()
        output = Dict(ctt(batch))

        loss_fns = [
            ("ContagionLoss(allow_multiple_exposures=True)",
             ContagionLoss(allow_multiple_exposures=True)),
            ("ContagionLoss(allow_multiple_exposures=False)",
             ContagionLoss(allow_multiple_exposures=False)),
            ("InfectiousnessLoss()", InfectiousnessLoss()),
        ]
        for label, loss_fn in loss_fns:
            with self.subTest(loss=label):
                loss = loss_fn(batch, output)
                # Guards against a NaN/inf loss passing silently. Assumes the
                # loss is a tensor or a plain number -- TODO confirm.
                self.assertTrue(
                    torch.isfinite(torch.as_tensor(loss)).all(),
                    f"{label} produced a non-finite loss",
                )
Example #5
0
 def _build_model(self):
     """Create the ContactTracingTransformer and place it on ``self.device``.

     Constructor kwargs come from the ``"model/kwargs"`` config entry
     (an empty dict when absent); the device move is done by ``to_device``.
     """
     model_kwargs = self.get("model/kwargs", {})
     network = ContactTracingTransformer(**model_kwargs)
     self.model: nn.Module = to_device(network, self.device)
Example #6
0
                abs_pytorch_tf_deltas.mean())
        f.write("  Max abs diff between outputs : %f \n\n" %
                abs_pytorch_tf_deltas.max())

        f.write("Conversion from TF model to TFLite model\n")
        f.write("  Min abs diff between outputs : %f \n" %
                abs_tf_tflite_deltas.min())
        f.write("  Mean abs diff between outputs : %f \n" %
                abs_tf_tflite_deltas.mean())
        f.write("  Max abs diff between outputs : %f \n\n" %
                abs_tf_tflite_deltas.max())

        f.write("Overall conversion from pytorch model to TFLite model\n")
        f.write("  Min abs diff between outputs : %f \n" %
                abs_pytorch_tflite_deltas.min())
        f.write("  Mean abs diff between outputs : %f \n" %
                abs_pytorch_tflite_deltas.mean())
        f.write("  Max abs diff between outputs : %f \n\n" %
                abs_pytorch_tflite_deltas.max())

    return abs_pytorch_tflite_deltas.max()


if __name__ == '__main__':
    # Load pytorch model from its saved state dict. map_location="cpu" allows a
    # checkpoint written from a GPU run to be loaded on a CPU-only machine; the
    # TFLite conversion below does not need a GPU.
    pytorch_model = ContactTracingTransformer()
    pytorch_model.load_state_dict(torch.load("models/model.pth", map_location="cpu"))
    pytorch_model.eval()

    # Launch conversion
    convert_pytorch_model(pytorch_model)
Example #7
0
 def _build(self):
     """Set up the preprocessor and model from config, then call ``self.load()``.

     ``relative_days`` for the preprocessor is read from
     ``"data/loader_kwargs/relative_days"`` (default ``True``); model
     constructor kwargs come from ``"model/kwargs"`` (default ``{}``).
     """
     use_relative_days = self.get("data/loader_kwargs/relative_days", True)
     self.preprocessor = ContactPreprocessor(relative_days=use_relative_days)
     model_kwargs = self.get("model/kwargs", {})
     self.model: torch.nn.Module = ContactTracingTransformer(**model_kwargs)
     self.load()