import torch
import torch.nn.functional as F


def train(args, pan, model, device, train_loader, target_create_fn, optimizer, epoch):
    # Freeze the base model; only the PAN is trained here.
    model.eval()
    pan.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        logits, feature = model(data, out_feature=True)
        # Select the PAN input according to the configured pan type.
        if args.pan_type == "feature":
            output = pan(feature)
        elif args.pan_type == "logits":
            output = pan(logits)
        elif args.pan_type == "agnostic_feature":
            output = pan(compute_agnostic_stats(feature))
        elif args.pan_type == "agnostic_logits":
            output = pan(compute_agnostic_stats(logits))
        else:
            raise NotImplementedError("Not an eligible pan type.")
        pan_target = target_create_fn(target).to(device)
        loss = F.cross_entropy(output, pan_target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
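# `compute_agnostic_stats` is referenced throughout this file but defined
# elsewhere in the repo. Below is a minimal sketch of one plausible version,
# assuming the "agnostic" statistics are fixed-size order statistics (mean,
# std, min, max) over the feature/logit dimension, so the PAN input width is
# independent of the expert's output width. The exact statistics used are an
# assumption, not taken from this file.
def compute_agnostic_stats(x):
    # x: (batch, d) tensor of features or logits with arbitrary width d.
    stats = [
        x.mean(dim=1),
        x.std(dim=1),
        x.min(dim=1)[0],
        x.max(dim=1)[0],
    ]
    # Result: (batch, 4) tensor whose width does not depend on d.
    return torch.stack(stats, dim=1)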
def predict_with_agnostic_pan(args, model1, model2, pan1, pan2, m1_data, m2_data):
    """
    Make a prediction with PAN using the agnostic statistics of the models.

    We take a winner-takes-all approach here, since two classifiers classify
    one input with one intended label (output). In theory, a multi-label
    (multi-output) approach is also possible, with multiple networks working
    together to classify one input into multiple classes.
    """
    output1, feature1 = model1(m1_data, out_feature=True)
    output2, feature2 = model2(m2_data, out_feature=True)

    if args.pan_type == "agnostic_feature":
        stats1 = compute_agnostic_stats(feature1)
        stats2 = compute_agnostic_stats(feature2)
    elif args.pan_type == "agnostic_logits":
        stats1 = compute_agnostic_stats(output1)
        stats2 = compute_agnostic_stats(output2)
    else:
        raise NotImplementedError("Not an eligible pan type.")

    p1_out = pan1(stats1)
    p2_out = pan2(stats2)

    # Counters for debugging: how often each expert wins, or neither does.
    p1_count = 0
    p2_count = 0
    p0_count = 0

    # Winner takes all.
    # Use len(m1_data) since m1 and m2 are the same data with different channel sizes.
    combined_output = [0] * len(m1_data)
    for i in range(len(combined_output)):
        if p1_out[i].max(0)[1] == 1 and p2_out[i].max(0)[1] == 0:
            # p1 true and p2 false: keep model1's logits and pad model2's
            # half with model1's minimum logit so it can never win the argmax.
            combined_output[i] = torch.cat([
                output1[i],
                torch.Tensor([torch.min(output1[i])] * len(output1[i])).to(args.device),
            ])
            p1_count += 1
        elif p1_out[i].max(0)[1] == 0 and p2_out[i].max(0)[1] == 1:
            # p1 false and p2 true: the symmetric case.
            combined_output[i] = torch.cat([
                torch.Tensor([torch.min(output2[i])] * len(output2[i])).to(args.device),
                output2[i],
            ])
            p2_count += 1
        else:
            # No clear winner: fall back to concatenating both raw outputs.
            combined_output[i] = torch.cat([output1[i], output2[i]])
            p0_count += 1
    combined_output = torch.stack(combined_output, 0)
    print(p1_count, p2_count, p0_count)
    return combined_output
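# Hedged usage sketch for the winner-takes-all prediction above. The loader
# and its (m1_data, m2_data, target) batch layout are assumptions, not taken
# from this repo. The final class is the argmax over the concatenated logit
# vector, where indices [0, n1) map to model1's label space and [n1, n1 + n2)
# to model2's.
def _example_combined_inference(args, model1, model2, pan1, pan2, loader):
    for m1_data, m2_data, _target in loader:
        combined = predict_with_agnostic_pan(
            args, model1, model2, pan1, pan2,
            m1_data.to(args.device), m2_data.to(args.device),
        )
        pred = combined.argmax(dim=1)  # global class index across both experts
        yield pred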
def eval_upan(args, upan, all_experts_output, device):
    upan.eval()
    fpan_target = []
    for expert_output in all_experts_output:
        upan_output = []
        for logits in expert_output:
            if args.upan_type == "logits":
                output = upan(logits)
            elif args.upan_type == "agnostic_logits":
                output = upan(compute_agnostic_stats(logits))
            else:
                raise NotImplementedError("Not an eligible upan type.")
            output = F.log_softmax(output, dim=-1)
            upan_output.append(output)
        # Concatenate batches of UPAN outputs
        upan_output = torch.cat(upan_output)
        # Extract the output of UPAN (i.e. the probability of the expert truly
        # belonging to the input data)
        upan_output = torch.index_select(upan_output, 1, torch.tensor([1]).to(device))
        upan_output = torch.flatten(upan_output)
        fpan_target.append(upan_output)
    # Concatenate UPAN predictions on different experts given the same input data
    fpan_target = torch.stack(fpan_target, dim=1)
    # Extract index of the max log-probability (the expert chosen by UPAN)
    fpan_target = torch.argmax(fpan_target, dim=1)
    return fpan_target
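# Hedged sketch of how `all_experts_output` can be assembled; the expert list
# and loader names are assumptions. Each entry is one expert's logits over the
# same evaluation stream, batch by batch, which matches the nested
# list-of-batches structure that eval_upan iterates over.
def _collect_expert_outputs(experts, eval_loader, device):
    all_experts_output = []
    with torch.no_grad():
        for expert in experts:
            expert.eval()
            batches = [expert(data.to(device)) for data, _target in eval_loader]
            all_experts_output.append(batches)
    return all_experts_output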
def predict_with_agnostic_pan(args, model1, model2, upan, data1, data2):
    """
    Make a prediction with PAN using the agnostic statistics of the models.

    We take a winner-takes-all approach here, since two classifiers classify
    one input with one intended label (output). In theory, a multi-label
    (multi-output) approach is also possible, with multiple networks working
    together to classify one input into multiple classes.
    """
    output1 = model1(data1)
    output2 = model2(data2)

    p1_out = upan(compute_agnostic_stats(output1))
    p2_out = upan(compute_agnostic_stats(output2))
    p1_out = F.log_softmax(p1_out, dim=-1)
    p2_out = F.log_softmax(p2_out, dim=-1)

    # Counters for debugging: how often each expert wins.
    p1_count = 0
    p2_count = 0

    # Winner takes all: compare the UPAN log-probability (class index 1) that
    # each expert is the true owner of the input.
    combined_output = [0] * len(data1)
    for i in range(len(combined_output)):
        if p1_out[i][1] > p2_out[i][1]:
            # p1 wins: keep model1's logits and pad model2's half with
            # model1's minimum logit so it can never win the argmax.
            combined_output[i] = torch.cat([
                output1[i],
                torch.Tensor([torch.min(output1[i])] * len(output1[i])).to(args.device),
            ])
            p1_count += 1
        else:
            # p2 wins: the symmetric case.
            combined_output[i] = torch.cat([
                torch.Tensor([torch.min(output2[i])] * len(output2[i])).to(args.device),
                output2[i],
            ])
            p2_count += 1
    combined_output = torch.stack(combined_output, 0)
    print(p1_count, p2_count)
    return combined_output
def test(args, pan, model, device, test_loader, target_create_fn):
    model.eval()
    pan.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            logits, feature = model(data, out_feature=True)
            if args.pan_type == "feature":
                output = pan(feature)
            elif args.pan_type == "logits":
                output = pan(logits)
            elif args.pan_type == "agnostic_feature":
                output = pan(compute_agnostic_stats(feature))
            elif args.pan_type == "agnostic_logits":
                output = pan(compute_agnostic_stats(logits))
            else:
                raise NotImplementedError("Not an eligible pan type.")
            pan_target = target_create_fn(target).to(device)
            # Sum up batch loss
            test_loss += F.cross_entropy(output, pan_target, reduction="sum").item()
            # Get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(pan_target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = 100.0 * correct / len(test_loader.dataset)
    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), acc
        )
    )
    return test_loss, acc
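# Hedged sketch of a `target_create_fn`, which is passed in but not defined in
# this file. The assumption, based on the PAN reading out class index 1 as
# "the expert truly belongs to the input", is a binary task: the target is 1
# when the original label falls in this expert's class set, else 0. The
# `expert_classes` argument is a placeholder.
def make_target_create_fn(expert_classes):
    expert_classes = torch.tensor(sorted(expert_classes))

    def target_create_fn(target):
        classes = expert_classes.to(target.device)
        # 1 where the original label is one of this expert's classes, else 0.
        return (target.unsqueeze(1) == classes.unsqueeze(0)).any(dim=1).long()

    return target_create_fn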
def test(args, device, upan, upan_test_loader):
    upan.eval()
    test_loss = "N/A"  # Not used for UPAN; kept for a consistent return signature
    with torch.no_grad():
        model_pred = []
        model_target = []
        for expert_data in upan_test_loader:
            upan_output = []
            upan_target = []
            for logits, target in expert_data:
                if args.upan_type == "logits":
                    output = upan(logits)
                elif args.upan_type == "agnostic_logits":
                    output = upan(compute_agnostic_stats(logits))
                else:
                    raise NotImplementedError("Not an eligible upan type.")
                output = F.log_softmax(output, dim=-1)
                upan_output.append(output)
                upan_target.append(target)
            # Concatenate batches of UPAN outputs and targets
            upan_output = torch.cat(upan_output)
            upan_target = torch.cat(upan_target)
            # Extract the output of UPAN (i.e. the probability of the expert
            # truly belonging to the input data)
            upan_output = torch.index_select(upan_output, 1, torch.tensor([1]).to(device))
            upan_output = torch.flatten(upan_output)
            # Append UPAN output and target for this expert
            model_pred.append(upan_output)
            model_target.append(upan_target)

    # Concatenate UPAN predictions on different experts given the same input data
    model_pred = torch.stack(model_pred, dim=1)
    # Extract index of the max log-probability (the expert chosen by UPAN)
    model_pred = torch.argmax(model_pred, dim=1)
    # Concatenate UPAN targets on different experts given the same input data
    model_target = torch.stack(model_target, dim=1)
    # Extract index of the true target (the correct expert)
    model_target = torch.nonzero(model_target, as_tuple=True)[1]

    correct = model_pred.eq(model_target).sum().item()
    total_data = len(model_pred)
    acc = 100.0 * correct / total_data
    print("\nTest set: Accuracy: {}/{} ({:.0f}%)".format(correct, total_data, acc))
    return test_loss, acc
def train(args, upan, upan_train_loader, optimizer, epoch):
    # Use the collected training set to train the UPAN
    upan.train()
    total_data = sum(len(data) for data, target in upan_train_loader)
    for batch_idx, (data, upan_target) in enumerate(upan_train_loader):
        if args.upan_type == "logits":
            output = upan(data)
        elif args.upan_type == "agnostic_logits":
            output = upan(compute_agnostic_stats(data))
        else:
            raise NotImplementedError("Not an eligible upan type.")
        optimizer.zero_grad()
        loss = F.cross_entropy(output, upan_target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print(
                "Train UPAN Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    total_data,
                    100.0 * batch_idx * len(data) / total_data,
                    loss.item(),
                )
            )
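# Hedged driver sketch for the UPAN training loop above; the optimizer choice,
# learning rate, and `args.epochs` attribute are assumptions, not taken from
# this file.
def _example_train_upan(args, upan, upan_train_loader):
    optimizer = torch.optim.Adam(upan.parameters(), lr=1e-3)  # assumed optimizer
    for epoch in range(1, args.epochs + 1):
        train(args, upan, upan_train_loader, optimizer, epoch)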