def test_set_state_get_state():
    classifier = module.DiseaseClassifier(
        components={'in_width': 100, 'out_width': 55})
    state = classifier.get_state()
    new_classifier = module.DiseaseClassifier(
        components={'in_width': 100, 'out_width': 55})
    # Two freshly initialized classifiers should have different weights.
    assert not (classifier.state_dict()['classification_layer.weight'] ==
                new_classifier.state_dict()['classification_layer.weight']).all()
    new_state = new_classifier.get_state()
    # After copying the first classifier's state over, the weights should match.
    new_classifier.set_state(state, reset=False)
    assert (classifier.state_dict()['classification_layer.weight'] ==
            new_classifier.state_dict()['classification_layer.weight']).all()
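# Since get_state()/set_state() round-trip the classifier's weights, the same
# state should also survive plain torch serialization (a sketch, assuming the
# state object is picklable; not an API documented in this module):
#
#     torch.save(classifier.get_state(), '/tmp/classifier_state.torch')
#     classifier.set_state(torch.load('/tmp/classifier_state.torch'), reset=False)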
def test_roc_trainer():
    data = get_test_examples()
    metanet = module.MetaphlanNet(
        components={'widths': [12365, 6000, 2000, 55]})
    classifier = module.DiseaseClassifier(
        components={'in_width': 55, 'out_width': 49}, input=metanet)
    if torch.cuda.is_available():
        data.cuda()
        metanet.cuda()
        classifier.cuda()
    bce = nn.BCELoss()

    def loss(batch):
        return bce(batch['predictions'], batch['label'])

    trainer = hyperfactory.roc_trainer(
        classifier, loss,
        components={'in_width': 55, 'out_width': 49}, input=metanet)
    # The trainer factory maps a Message of hyperparameters to an evaluator.
    evaluator = trainer(Message({'roc_bias': torch.ones(49)}))
    assert hasattr(evaluator, 'run')
    evaluator.run(data)
def trainer(parameters: dict, max_epochs: int = 15):
    converted_params = convert_keys_to_variables(parameters)
    embedder = module.MetaphlanNet(converted_params['widths'])
    classifier = module.DiseaseClassifier(
        converted_params['in_width'], converted_params['out_width'])
    learning_rate = converted_params['learning_rate']
    engine = module.get_trainer(embedder, classifier)
    train_loader = BatchingPipe(inputs=train_dataset)
    engine.run(train_loader, max_epochs=max_epochs)
    eval_engine = module.get_evaluator(embedder, classifier, attach=False)
    # Release the training engine before handing back the evaluator.
    engine = None
    return eval_engine
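# `convert_keys_to_variables` is used above but not defined in this section.
# The helper below is a hypothetical sketch of what it might do: coerce raw
# hyperparameter values (e.g. from a search library) into the plain Python
# types the modules expect. The name and behavior are assumptions, not the
# actual implementation.
def _example_convert_keys_to_variables(parameters: dict) -> dict:
    """Hypothetical sketch: unwrap tensor-valued hyperparameters."""
    converted = {}
    for key, value in parameters.items():
        if isinstance(value, torch.Tensor):
            # Scalars become Python numbers; everything else becomes a list.
            converted[key] = value.item() if value.numel() == 1 else value.tolist()
        else:
            converted[key] = value
    return converted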
def test_DiseaseClassifier():
    metanet = module.MetaphlanNet(
        components={'widths': [12365, 6000, 2000, 100]})
    classifier = module.DiseaseClassifier(
        components={'in_width': 100, 'out_width': 55})
    data = get_test_examples()
    if torch.cuda.is_available():
        metanet.cuda()
        classifier.cuda()
        data.cuda()
    # A single example should embed to the classifier's input width.
    output = classifier(metanet(data[2]))
    assert len(output['embeddings'][0]) == 100
    # A batch of 7 examples should yield one 55-way prediction per example.
    output = classifier(metanet(data[3:10]))
    assert output['predictions'].shape == torch.Size([7, 55])
def test_roc_bias_generator():
    in_width = 50
    out_width = 55
    classifier = module.DiseaseClassifier(
        components={'in_width': in_width, 'out_width': out_width})
    bias_generator = hyperfactory.roc_bias_generator(classifier)
    new_bias = bias_generator(None, None)
    assert type(new_bias) is Message
    assert 'roc_bias' in new_bias
    assert len(new_bias['roc_bias'][0]) == out_width
    assert len(new_bias['roc_bias']) == 1
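# The Message returned by the bias generator has the same layout as the
# `Message({'roc_bias': torch.ones(49)})` fed to the trainer factory in
# test_roc_trainer, so the two should compose directly (a sketch under that
# assumption, not a tested path):
#
#     trainer = hyperfactory.roc_trainer(classifier, loss, components=..., input=...)
#     evaluator = trainer(bias_generator(None, None))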
def test_PrevalenceNet():
    metanet = module.MetaphlanNet(components={'widths': [12365, 6000, 1000]})
    # Random prior over 55 classes, normalized to sum to 1.
    prior = torch.empty(55).uniform_(0, 1)
    prior = prior / prior.sum()
    prevalence = module.Concatenator(
        input=metanet,
        components={'in_column': 'embeddings', 'out_column': 'embeddings',
                    'concatenate_column': 'prevalence'})
    posterior = module.PosteriorNet(
        input=prevalence, components={'widths': [1055, 100]})
    classifier = module.DiseaseClassifier(
        input=posterior, components={'in_width': 100, 'out_width': 55})
    data = get_test_examples()
    batch = data[0:10]
    # Attach the same prior to every example in the batch.
    batch['prevalence'] = prior.expand(len(batch), len(prior))
    if torch.cuda.is_available():
        metanet.cuda()
        prevalence.cuda()
        posterior.cuda()
        classifier.cuda()
        batch.cuda()
    output = classifier(batch)
# NOTE: the opening of this call was truncated; `coo = torch.load(...)` is an
# assumption based on the `.torch` filename and the use of `coo` below.
coo = torch.load(
    'phylogenetic_tree.torch'
)  # This is the adjacency matrix for the tree in COO format.
conv_net = module.DeepConvNet(components={
    'channels': [1, 64, 64, 1],
    'edge_index': coo,
    'num_nodes': 12365
})
deep_net = module.MetaphlanNet(
    components={'widths': [12365, 350], 'in_column': 'embeddings'},
    input=conv_net)
classifier = module.DiseaseClassifier(
    input=deep_net,
    components={'in_column': 'embeddings', 'in_width': 350,
                'out_width': len(study_labels)})
single_classifier = module.MulticlassDiseaseClassifier(
    input=classifier,
    components={'in_column': 'embeddings', 'out_column': 'top_prediction'})
# `size_average=False` is deprecated; `reduction='sum'` is the equivalent.
bce = nn.BCELoss(reduction='sum')
ce = nn.CrossEntropyLoss()
if torch.cuda.is_available():