def test_roc_trainer():
    """Smoke-test ``hyperfactory.roc_trainer``.

    Chains a ``MetaphlanNet`` into a ``DiseaseClassifier``, asks
    ``hyperfactory.roc_trainer`` for a trainer factory, calls that factory
    with a ``roc_bias`` parameter message and checks that the returned
    object exposes ``run`` and can be run on the test examples.
    """
    # The same classifier configuration is handed to both the classifier and
    # the trainer factory; each receives its own copy of the dict.
    classifier_components = {"in_width": 55, "out_width": 49}

    data = get_test_examples()
    metanet = module.MetaphlanNet(
        components={'widths': [12365, 6000, 2000, 55]},
    )
    classifier = module.DiseaseClassifier(
        components=dict(classifier_components),
        input=metanet,
    )

    if torch.cuda.is_available():
        for movable in (data, metanet, classifier):
            movable.cuda()

    bce = nn.BCELoss()

    def loss(batch):
        predicted = batch['predictions']
        expected = batch['label']
        return bce(predicted, expected)

    trainer = hyperfactory.roc_trainer(
        classifier,
        loss,
        components=dict(classifier_components),
        input=metanet,
    )

    roc_parameters = Message({'roc_bias': torch.ones(49)})
    evaluator = trainer(roc_parameters)

    assert hasattr(evaluator, 'run')
    evaluator.run(data)
def test_MetaphlanNet():
    """Check ``MetaphlanNet`` construction, ``get_state`` and output sizes.

    A configured network and a default-constructed one must both answer
    ``get_state()``. The configured network is then composed by hand with a
    ``DiseaseClassifier`` and run on the first ten test examples; the
    embedding width (100) and the prediction shape (10 x 55) are asserted.
    """
    embedding_width, label_count, batch_rows = 100, 55, 10

    metanet = module.MetaphlanNet(
        components={'widths': [12365, 6000, 2000, embedding_width]},
    )
    # Return values are not inspected; the calls only have to succeed.
    metanet.get_state()
    module.MetaphlanNet().get_state()

    classifier = module.DiseaseClassifier(
        components={'in_width': embedding_width, 'out_width': label_count},
    )
    data = get_test_examples()

    if torch.cuda.is_available():
        for movable in (metanet, classifier, data):
            movable.cuda()

    def forward():
        # Slice afresh on every call, exactly as the two passes below expect.
        return classifier(metanet(data[0:batch_rows]))

    first_pass = forward()
    assert len(first_pass['embeddings'][0]) == embedding_width

    second_pass = forward()
    assert second_pass['predictions'].shape == torch.Size([batch_rows, label_count])
def trainer(parameters: dict, max_epochs: int = 15):
    """Build a model from ``parameters``, run the training engine, return an evaluator.

    Args:
        parameters: Mapping passed through ``convert_keys_to_variables``; the
            converted result must provide ``widths``, ``in_width``,
            ``out_width`` and ``learning_rate``.
        max_epochs: Forwarded to the training engine's ``run`` call.

    Returns:
        The object produced by ``module.get_evaluator`` for the embedder and
        classifier, created with ``attach=False``.
    """
    hyper = convert_keys_to_variables(parameters)

    embedder = module.MetaphlanNet(hyper['widths'])
    classifier = module.DiseaseClassifier(hyper['in_width'], hyper['out_width'])

    # NOTE(review): the learning rate is looked up but never forwarded to
    # ``module.get_trainer`` -- confirm whether it should be.
    learning_rate = hyper['learning_rate']

    training_engine = module.get_trainer(embedder, classifier)
    # ``train_dataset`` is a module-level name defined outside this function.
    training_engine.run(
        BatchingPipe(inputs=train_dataset),
        max_epochs=max_epochs,
    )

    evaluator = module.get_evaluator(embedder, classifier, attach=False)

    # Drop the local reference to the training engine before returning.
    del training_engine
    return evaluator
def test_PrevalenceNet():
    """Run a prevalence-augmented classifier chain on ten test examples.

    The chain is ``MetaphlanNet`` -> ``Concatenator`` (joins the
    ``prevalence`` column onto ``embeddings``) -> ``PosteriorNet`` ->
    ``DiseaseClassifier``. A random prior over 55 labels, normalised to sum
    to one, is broadcast to every row of the batch as the ``prevalence``
    column before the forward pass.
    """
    metanet = module.MetaphlanNet(components={'widths': [12365, 6000, 1000]})

    # Random prior over the 55 labels, normalised with the tensor's own
    # reduction rather than the Python builtin ``sum``.
    prior = torch.empty(55).uniform_(0, 1)
    prior = prior / prior.sum()

    prevalence = module.Concatenator(
        input=metanet,
        components={
            "in_column": "embeddings",
            "out_column": "embeddings",
            "concatenate_column": "prevalence",
        },
    )
    # 1055 is presumably the 1000-wide embedding plus the 55 prevalence
    # values appended by the Concatenator -- TODO confirm.
    posterior = module.PosteriorNet(
        input=prevalence,
        components={'widths': [1055, 100]},
    )
    classifier = module.DiseaseClassifier(
        input=posterior,
        components={'in_width': 100, 'out_width': 55},
    )

    data = get_test_examples()
    batch = data[0:10]
    batch['prevalence'] = prior.expand(len(batch), len(prior))

    if torch.cuda.is_available():
        metanet.cuda()
        prevalence.cuda()
        posterior.cuda()
        classifier.cuda()
        batch.cuda()

    output = classifier(batch)

    # Previously the result was discarded, so the test could only fail on an
    # exception. The expected shape mirrors the check in test_MetaphlanNet
    # (ten rows, ``out_width`` of 55).
    assert output['predictions'].shape == torch.Size([10, 55])
device=device) print("Initializing Model.") coo = torch.load( 'phylogenetic_tree.torch' ) # This is the adjacency matrix for the tree in COO format. conv_net = module.DeepConvNet(components={ "channels": [1, 64, 64, 1], "edge_index": coo, "num_nodes": 12365 }) deep_net = module.MetaphlanNet(components={ "widths": [12365, 350], "in_column": 'embeddings' }, input=conv_net) classifier = module.DiseaseClassifier(input=deep_net, components={ 'in_column': 'embeddings', 'in_width': 350, 'out_width': len(study_labels) }) single_classifier = module.MulticlassDiseaseClassifier(input=classifier, components={ 'in_column': 'embeddings', 'out_column': 'top_prediction'
normalizer.compile() normalizer.enable_inference() normalizer.disable_updates() oversampler.compile() minibatcher = BatchingPipe(oversampler, batch_size=20) l = len(minibatcher) oversample_weights = torch.Tensor([ 1 - 1/len(study_labels) for label in study_labels[:]]) # (1/n_i) / { (1\n_i) + (1\(n-n_i)) } def inject_oversample_weights(x): """ Adds a 'prevalence' column, where the prevalence values are tuned such that each individual classifier is oversampled to 50/50 """ x['prevalence'] = oversample_weights.expand(len(x), len(oversample_weights)) return x training_set = TensorPipe(FunctionPipe(minibatcher, function=inject_oversample_weights), columns=['examples', 'label', 'prevalence', 'label_index'], device=device) deep_only_net = module.MetaphlanNet(components={'widths': [12365, 1000, 350]}) deep_only_classifier = module.DiseaseClassifier(input=deep_only_net, components={'in_column': 'embeddings', 'in_width': 350, 'out_width': len(study_labels)}) deep_only_single_classifier = module.MulticlassDiseaseClassifier(input=deep_only_classifier, components={'in_column': 'embeddings', 'out_column': 'top_prediction'}) bce = nn.BCELoss(size_average=False) ce = nn.CrossEntropyLoss() if torch.cuda.is_available(): deep_only_net.cuda(device=device) deep_only_classifier.cuda(device=device) deep_only_single_classifier.cuda(device=device) bce.cuda(device=device) ce.cuda(device=device) def loss(batch): loss_multiplier = batch['prevalence']