def test_dataset(self):
    from loader import ContactDataset
    from addict import Dict

    path = self.DATASET_PATH
    dataset = ContactDataset(path)

    def validate(sample):
        # Check that extracted features carry the expected trailing dimensions.
        self.assertIsInstance(sample, Dict)
        self.assertEqual(
            dataset.extract(sample, "preexisting_conditions").shape[-1],
            len(dataset.DEFAULT_PREEXISTING_CONDITIONS),
        )
        self.assertEqual(dataset.extract(sample, "test_results").shape[-1], 1)
        self.assertEqual(dataset.extract(sample, "age").shape[-1], 1)
        self.assertEqual(dataset.extract(sample, "sex").shape[-1], 1)
        self.assertEqual(
            dataset.extract(sample, "reported_symptoms_at_encounter").shape[-1], 28
        )
        self.assertEqual(
            dataset.extract(sample, "test_results_at_encounter").shape[-1], 1
        )

    sample = dataset.get(890, 3)
    validate(sample)
def test_losses(self):
    from loader import ContactDataset
    from torch.utils.data import DataLoader
    from models import ContactTracingTransformer
    from losses import ContagionLoss, InfectiousnessLoss
    from addict import Dict

    batch_size = 5
    path = self.DATASET_PATH
    dataset = ContactDataset(path)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=ContactDataset.collate_fn
    )
    batch = next(iter(dataloader))
    ctt = ContactTracingTransformer(pool_latent_entities=False, use_logit_sink=False)
    output = Dict(ctt(batch))
    # Smoke-test all loss variants on the same forward pass.
    loss_fn = ContagionLoss(allow_multiple_exposures=True)
    loss = loss_fn(batch, output)
    loss_fn = ContagionLoss(allow_multiple_exposures=False)
    loss = loss_fn(batch, output)
    loss_fn = InfectiousnessLoss()
    loss = loss_fn(batch, output)
def test_model_runs(self):
    from loader import ContactDataset
    from torch.utils.data import DataLoader
    from models import ContactTracingTransformer
    from addict import Dict

    batch_size = 5
    path = self.DATASET_PATH
    dataset = ContactDataset(path)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=ContactDataset.collate_fn
    )
    batch = next(iter(dataloader))

    def test_output(mod):
        output = Dict(mod(batch))
        # print(output.latent_variable.shape)
        self.assertEqual(output.latent_variable.shape[0], batch_size)
        self.assertEqual(output.encounter_variables.shape[0], batch_size)

    ctt = ContactTracingTransformer()
    test_output(ctt)
    ctt = ContactTracingTransformer(use_encounter_partner_id_embedding=False)
    test_output(ctt)
    ctt = ContactTracingTransformer(use_learned_time_embedding=True)
    test_output(ctt)
    ctt = ContactTracingTransformer(encounter_duration_embedding_mode="sines")
    test_output(ctt)
def test_model_padding(self):
    import torch
    from loader import ContactDataset
    from torch.utils.data import DataLoader
    from models import ContactTracingTransformer

    torch.random.manual_seed(42)
    batch_size = 5
    path = self.DATASET_PATH
    dataset = ContactDataset(path)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=ContactDataset.collate_fn
    )
    batch = next(iter(dataloader))

    # Padding test -- pad everything that has to do with encounters, and check
    # whether it changes the results.
    pad_size = 1

    def pad(tensor):
        if tensor.dim() == 3:
            zeros = torch.zeros(
                (tensor.shape[0], pad_size, tensor.shape[2]), dtype=tensor.dtype
            )
        else:
            zeros = torch.zeros((tensor.shape[0], pad_size), dtype=tensor.dtype)
        return torch.cat([tensor, zeros], dim=1)

    padded_batch = {
        key: (pad(tensor) if key.startswith("encounter") else tensor)
        for key, tensor in batch.items()
    }
    # Pad the mask as well, so the padded entries are ignored by the model.
    padded_batch["mask"] = pad(padded_batch["mask"])

    # Make the model and set it to eval.
    # noinspection PyUnresolvedReferences
    ctt = ContactTracingTransformer(num_sabs=1).eval()
    with torch.no_grad(), ctt.diagnose():
        output = ctt(batch)
        padded_output = ctt(padded_batch)

    # Outputs on the padded batch ("ist") should match those on the original
    # batch ("soll") once the padded encounter slots are stripped off.
    encounter_soll_wert = output["encounter_variables"][..., 0]
    encounter_ist_wert = padded_output["encounter_variables"][..., :-pad_size, 0]
    latent_soll_wert = output["latent_variable"]
    latent_ist_wert = padded_output["latent_variable"]
    self.assertSequenceEqual(encounter_soll_wert.shape, encounter_ist_wert.shape)
    self.assertSequenceEqual(latent_ist_wert.shape, latent_soll_wert.shape)
    self.assertTrue(torch.allclose(encounter_soll_wert, encounter_ist_wert))
    self.assertTrue(torch.allclose(latent_soll_wert, latent_ist_wert))
def load(self):
    # NOTE: checkpoint loading is currently disabled.
    # path = os.path.join(self.checkpoint_directory, "best.ckpt")
    # assert os.path.exists(path)
    # state = torch.load(path)
    # self.model.load_state_dict(state["model"])
    # self.model.eval()
    return self

def infer(self, human_day_info):
    with torch.no_grad():
        model_input = self.preprocessor.preprocess(human_day_info, as_batch=True)
        model_output = self.model(model_input.to_dict())
        contagion_proba = (
            model_output["encounter_variables"].sigmoid().numpy()[0, :, 0]
        )
        infectiousness = model_output["latent_variable"].numpy()[0, :, 0]
    return dict(contagion_proba=contagion_proba, infectiousness=infectiousness)


if __name__ == "__main__":
    from loader import ContactDataset

    # Smoke test: run inference on a single (human, day) record from a local dataset.
    dataset = ContactDataset(
        path="/Users/nrahaman/Python/ctt/data/sim_people-1000_days-60_init-0"
    )
    hdi = dataset.read(0, 0)
    engine = InferenceEngine("/Users/nrahaman/Python/ctt/exp/DEBUG-0")
    output = engine.infer(hdi)
    pass