def main():
    """Train a classifier with BCE on one-hot labels and evaluate it every epoch.

    Configuration is read from the module-level ``args`` namespace and the
    ``image_dir`` / ``env`` globals. Each epoch trains on the "train" split,
    evaluates on the "test" split, prints the mean training loss plus overall
    and per-class accuracy, logs scalars to a fresh numbered TensorBoard run
    directory under ``./tensorboard/OOD`` and overwrites
    ``./save_model/<env>/<net_type>/checkpoint_last.pth.tar``. With
    ``args.resume`` that same checkpoint file is loaded first.

    Raises:
        ValueError: if ``args.net_type`` is neither "resnet50" nor "resnet34".
    """
    start_epoch = 0
    save_model = "./save_model"
    tensorboard_dir = "./tensorboard/OOD"
    checkpoint_dir = os.path.join(save_model, env, args.net_type)
    # Resuming reads exactly the file that every epoch writes.
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_last.pth.tar')
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ### data config
    train_dataset = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="train")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=2)
    test_dataset = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="test")
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=2)
    steps_per_epoch = len(train_loader)

    ##### model, optimizer config
    if args.net_type == "resnet50":
        model = models.resnet50(num_c=args.num_classes, pretrained=True)
    elif args.net_type == "resnet34":
        model = models.resnet34(num_c=args.num_classes, pretrained=True)
    else:
        raise ValueError("unsupported net_type: {}".format(args.net_type))
    model = model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.init_lr, momentum=0.9, nesterov=True)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs * steps_per_epoch)
    if args.resume:
        print("load checkpoint_last")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        ##### load model
        model.load_state_dict(checkpoint["model"])
        # The saved epoch already finished; continue with the next one.
        start_epoch = checkpoint["epoch"] + 1
        # Rebuild the optimizer at the saved learning rate together with a
        # scheduler bound to it: a scheduler left attached to a replaced
        # optimizer would never change the new optimizer's learning rate.
        optimizer = optim.SGD(model.parameters(), lr=checkpoint["init_lr"], momentum=0.9, nesterov=True)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, max(1, (args.num_epochs - start_epoch) * steps_per_epoch))

    #### loss config
    criterion = nn.BCEWithLogitsLoss()

    #### create folder
    Path(checkpoint_dir).mkdir(exist_ok=True, parents=True)
    if args.board_clear:
        for f in glob.glob(tensorboard_dir + "/*"):
            shutil.rmtree(f)
    # Each run logs to the first unused numbered sub-directory.
    run_idx = 0
    while Path(os.path.join(tensorboard_dir, str(run_idx))).exists():
        run_idx += 1
    run_dir = os.path.join(tensorboard_dir, str(run_idx))
    Path(run_dir).mkdir(exist_ok=True, parents=True)
    summary = SummaryWriter(run_dir)

    # Start training
    for epoch in range(start_epoch, args.num_epochs):
        # Per-class number of correctly classified samples (plain lists: writes
        # through locals() are not reliable inside a function).
        train_correct = [0] * args.num_classes
        test_correct = [0] * args.num_classes
        total_loss = 0
        train_acc = 0
        test_acc = 0
        stime = time.time()

        model.train()
        for train_data in train_loader:
            org_image = train_data['input'].to(device)
            gt = train_data['label'].type(torch.FloatTensor).to(device)
            optimizer.zero_grad()
            #### forward path
            output = model(org_image)
            #### calc loss
            class_loss = criterion(output, gt)
            #### calc accuracy (sigmoid is monotonic, so the logits' arg-max is the prediction)
            target = gt.argmax(dim=1)
            hits = (output.argmax(dim=1) == target).tolist()
            train_acc += sum(hits)
            for label, hit in zip(target.tolist(), hits):
                train_correct[label] += hit
            # NOTE: anomaly detection slows backward down considerably; it is a debugging aid.
            with autograd.detect_anomaly():
                class_loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += class_loss.item()

        model.eval()
        with torch.no_grad():
            for test_data in test_loader:
                org_image = test_data['input'].to(device)
                gt = test_data['label'].type(torch.FloatTensor).to(device)
                output = model(org_image)
                target = gt.argmax(dim=1)
                hits = (output.argmax(dim=1) == target).tolist()
                test_acc += sum(hits)
                for label, hit in zip(target.tolist(), hits):
                    test_correct[label] += hit

        # Normalise by the dataset sizes.
        train_total_acc = train_acc / train_dataset.num_image
        test_total_acc = test_acc / test_dataset.num_image
        mean_loss = total_loss / steps_per_epoch
        print('Epoch [{}/{}], Step {}, loss = {:.4f}, exe time: {:.2f}, lr: {:.4f}*e-4'
              .format(epoch, args.num_epochs, steps_per_epoch, mean_loss, time.time() - stime,
                      scheduler.get_last_lr()[0] * 10 ** 4))
        print("train accuracy total : {:.4f}".format(train_total_acc))
        for num in range(args.num_classes):
            print("label{} : {:.4f}".format(num, train_correct[num] / train_dataset.len_list[num]), end=" ")
        print()
        print("test accuracy total : {:.4f}".format(test_total_acc))
        for num in range(args.num_classes):
            print("label{} : {:.4f}".format(num, test_correct[num] / test_dataset.len_list[num]), end=" ")
        print("\n")

        summary.add_scalar('loss/loss', mean_loss, epoch)
        summary.add_scalar('acc/train_acc', train_total_acc, epoch)
        summary.add_scalar('acc/test_acc', test_total_acc, epoch)
        summary.add_scalar("learning_rate/lr", scheduler.get_last_lr()[0], epoch)
        time.sleep(0.001)
        torch.save({
            'model': model.state_dict(),
            'epoch': epoch,
            'init_lr': scheduler.get_last_lr()[0]
        }, checkpoint_path)
def main():
    """Pre-train a two-head ("pendis") network and track in-distribution vs. OOD AUROC.

    Configuration is read from ``args`` and the ``image_dir`` / ``OOD_dir`` /
    ``env`` globals. Each epoch trains both heads returned by
    ``model.pendis_forward`` with BCE, optionally adds ``Membership_loss``
    (weighted 0.3) and, with ``args.transfer``, a BCE "transfer" loss of an
    extra linear head on auxiliary OOD images. It then scores the "test"
    split (label 0) and the "OOD" split (label 1) by the L1 distance between
    the two heads' softmax outputs and reports ``evaluate(..., metric='roc')``.
    The model is saved to
    ``./save_model_dis/pre_training/<env>/<net_type>/checkpoint_last_pre.pth.tar``.

    Raises:
        ValueError: for an unsupported ``args.net_type``, or an empty auxiliary
            OOD loader when ``args.transfer`` is set.
    """
    import itertools

    start_epoch = 0
    pretrained_model = os.path.join("./pre_trained", args.dataset, args.net_type + ".pth.tar")
    save_model = "./save_model_dis/pre_training"
    tensorboard_dir = "./tensorboard/OOD_dis/pre_training/" + args.dataset
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ### data config
    train_dataset = load_data.Dog_metric_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="train",
                                                    soft_label=args.soft_label)
    if args.custom_sampler:
        MySampler = load_data.customSampler(train_dataset, args.batch_size, args.num_instances)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=MySampler, num_workers=2)
    else:
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                                   num_workers=2)
    test_dataset = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="test")
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=2)
    out_test_dataset = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="OOD")
    out_test_loader = torch.utils.data.DataLoader(out_test_dataset, batch_size=args.batch_size, shuffle=False,
                                                  num_workers=2)
    if args.transfer:
        ### perfectly OOD data
        OOD_dataset = load_data.Dog_dataloader(image_dir=OOD_dir, num_class=args.OOD_num_classes, mode="OOD")
        OOD_loader = torch.utils.data.DataLoader(OOD_dataset, batch_size=args.batch_size, shuffle=True,
                                                 num_workers=2)
        if len(OOD_loader) == 0:
            raise ValueError("transfer needs a non-empty OOD dataset in {}".format(OOD_dir))
        # Restart (and reshuffle) the OOD loader whenever it runs out, so every
        # training step gets an OOD batch whatever the relative dataset sizes.
        OOD_batches = itertools.chain.from_iterable(itertools.repeat(OOD_loader))

    ##### model, optimizer config
    supported = ("resnet50", "resnet34", "vgg19", "vgg16", "vgg19_bn", "vgg16_bn")
    if args.net_type not in supported:
        raise ValueError("unsupported net_type: {}".format(args.net_type))
    model = getattr(models, args.net_type)(num_c=args.num_classes, pretrained=True)
    if args.load:
        print("loading model")
        checkpoint = torch.load(pretrained_model, map_location=device)
        ##### load model
        model.load_state_dict(checkpoint["model"])
    model = model.to(device)

    params = list(model.parameters())
    if args.transfer:
        # NOTE(review): 2048 matches a ResNet-50 feature width; confirm what
        # model.gen_forward returns for the other backbones.
        extra_fc = nn.Linear(2048, args.num_classes + args.OOD_num_classes).to(device)
        # The transfer head must be optimized as well, or it keeps its random init.
        params += list(extra_fc.parameters())

    # NOTE(review): with a batch_sampler, len(train_loader) is len(MySampler); if
    # customSampler.__len__ already counts batches, dividing by batch_size
    # under-counts the steps - confirm against load_data.customSampler.
    batch_num = len(train_loader) / args.batch_size if args.custom_sampler else len(train_loader)
    optimizer = optim.SGD(params, lr=args.init_lr, momentum=0.9, nesterov=args.nesterov)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs * batch_num)

    #### loss config
    criterion = nn.BCEWithLogitsLoss()

    #### create folder
    Path(os.path.join(save_model, env, args.net_type)).mkdir(exist_ok=True, parents=True)
    if args.board_clear:
        for f in glob.glob(tensorboard_dir + "/*"):
            shutil.rmtree(f)
    # Each run logs to the first unused numbered sub-directory.
    run_idx = 0
    while Path(os.path.join(tensorboard_dir, str(run_idx))).exists():
        run_idx += 1
    run_dir = os.path.join(tensorboard_dir, str(run_idx))
    Path(run_dir).mkdir(exist_ok=True, parents=True)
    summary = SummaryWriter(run_dir)

    # Start training; disabled optional losses stay at zero.
    membership_loss = torch.tensor(0)
    transfer_loss = torch.tensor(0)
    for epoch in range(start_epoch, args.num_epochs):
        running_loss = 0
        running_membership_loss = 0
        running_transfer_loss = 0
        running_class_loss = 0
        train_acc = 0
        test_acc = 0
        stime = time.time()

        model.train()
        if args.transfer:
            extra_fc.train()
        for train_data in train_loader:
            # Small Gaussian input noise as augmentation.
            org_image = (train_data['input'] + 0.01 * torch.randn_like(train_data['input'])).to(device)
            gt = train_data['label'].type(torch.FloatTensor).to(device)
            optimizer.zero_grad()
            #### forward path
            out1, out2 = model.pendis_forward(org_image)
            if args.membership:
                membership_loss = (Membership_loss(out2, gt, args.num_classes) +
                                   Membership_loss(out1, gt, args.num_classes))
                running_membership_loss += membership_loss.item()
            if args.transfer:
                OOD_data = next(OOD_batches)
                OOD_image = (OOD_data['input'] + 0.01 * torch.randn_like(OOD_data['input'])).to(device)
                OOD_label = OOD_data['label'].type(torch.FloatTensor)
                # In-distribution slots are all zero; size them from the actual
                # batch because the last OOD batch can be smaller than batch_size.
                OOD_gt = torch.cat((torch.zeros(OOD_label.shape[0], args.num_classes), OOD_label),
                                   dim=1).to(device)
                _, feature = model.gen_forward(OOD_image)
                transfer_loss = criterion(extra_fc(feature), OOD_gt)
                running_transfer_loss += transfer_loss.item()
            #### calc loss
            class_loss = criterion(out1, gt) + criterion(out2, gt)
            total_loss = class_loss + membership_loss * 0.3 + transfer_loss
            #### calc accuracy of both heads
            target = gt.argmax(dim=1)
            train_acc += (out1.argmax(dim=1) == target).sum().item()
            train_acc += (out2.argmax(dim=1) == target).sum().item()
            total_loss.backward()
            optimizer.step()
            scheduler.step()
            running_class_loss += class_loss.item()
            running_loss += total_loss.item()

        # Score = L1 distance between the two heads' softmax outputs;
        # in-distribution samples are labelled 0, OOD samples 1.
        model.eval()
        dist_chunks, label_chunks = [], []
        with torch.no_grad():
            for test_data in test_loader:
                org_image = test_data['input'].to(device)
                gt = test_data['label'].type(torch.FloatTensor).to(device)
                out1, out2 = model.pendis_forward(org_image)
                score_1 = nn.functional.softmax(out1, dim=1)
                score_2 = nn.functional.softmax(out2, dim=1)
                dist_chunks.append(torch.sum(torch.abs(score_1 - score_2), dim=1).reshape((org_image.shape[0], -1)))
                label_chunks.append(torch.zeros((org_image.shape[0],)))
                target = gt.argmax(dim=1)
                test_acc += (out1.argmax(dim=1) == target).sum().item()
                test_acc += (out2.argmax(dim=1) == target).sum().item()
            for out_org_data in out_test_loader:
                out_org_image = out_org_data['input'].to(device)
                out1, out2 = model.pendis_forward(out_org_image)
                score_1 = nn.functional.softmax(out1, dim=1)
                score_2 = nn.functional.softmax(out2, dim=1)
                dist_chunks.append(
                    torch.sum(torch.abs(score_1 - score_2), dim=1).reshape((out_org_image.shape[0], -1)))
                label_chunks.append(torch.ones((out_org_image.shape[0],)))
        dists = torch.cat(dist_chunks, dim=0)
        labels = torch.cat(label_chunks, dim=0)
        roc = evaluate(labels.cpu(), dists.cpu(), metric='roc')

        # Two heads per sample, hence the division by 2.
        train_total_acc = train_acc / train_dataset.num_image / 2
        test_total_acc = test_acc / test_dataset.num_image / 2
        print('Epoch{} AUROC: {:.3f}, test accuracy : {:.4f}'.format(epoch, roc, test_total_acc))
        print(
            'Epoch [{}/{}], Step {}, total_loss = {:.4f}, class = {:.4f}, membership = {:.4f}, transfer = {:.4f}, exe time: {:.2f}, lr: {:.4f}*e-4'
            .format(epoch, args.num_epochs, len(train_loader), running_loss / batch_num,
                    running_class_loss / batch_num, running_membership_loss / batch_num,
                    running_transfer_loss / batch_num, time.time() - stime,
                    scheduler.get_last_lr()[0] * 10**4))
        print('exe time: {:.2f}, lr: {:.4f}*e-4'.format(time.time() - stime, scheduler.get_last_lr()[0] * 10**4))
        print("train accuracy total : {:.4f}".format(train_total_acc))
        print("test accuracy total : {:.4f}".format(test_total_acc))

        summary.add_scalar('loss/total_loss', running_loss / batch_num, epoch)
        summary.add_scalar('loss/class_loss', running_class_loss / batch_num, epoch)
        summary.add_scalar('loss/membership_loss', running_membership_loss / batch_num, epoch)
        summary.add_scalar('acc/train_acc', train_total_acc, epoch)
        summary.add_scalar('acc/test_acc', test_total_acc, epoch)
        summary.add_scalar("learning_rate/lr", scheduler.get_last_lr()[0], epoch)
        time.sleep(0.001)
        torch.save(
            {
                'model': model.state_dict(),
                'epoch': epoch,
                'init_lr': scheduler.get_last_lr()[0]
            }, os.path.join(save_model, env, args.net_type, 'checkpoint_last_pre.pth.tar'))
def data_config(image_dir, OOD_dir, num_classes, OOD_num_classes, batch_size, num_instances, soft_label,
                custom_sampler, not_test_ODIN, transfer, resize=(160, 160)):
    """Create the datasets and data loaders used for training and evaluation.

    Args:
        image_dir: directory passed to the train, test and novelty ("OOD" mode) datasets.
        OOD_dir: directory of the auxiliary OOD dataset; only read when ``transfer`` is set.
        num_classes: class count passed to the ``image_dir`` datasets.
        OOD_num_classes: class count passed to the ``OOD_dir`` dataset.
        batch_size: batch size of the training and auxiliary OOD loaders.
        num_instances: forwarded to ``load_data.customSampler``.
        soft_label: forwarded to the training dataset.
        custom_sampler: if true, batches come from ``load_data.customSampler``
            instead of plain shuffling.
        not_test_ODIN: if true, build the novelty dataset/loader.
        transfer: if true, build the auxiliary OOD dataset/loader.
        resize: forwarded to every dataset.

    Returns:
        ``(train_dataset, train_loader, test_dataset, test_loader,
        out_test_dataset, out_test_loader, OOD_dataset, OOD_loader)``; each of
        the last four entries is ``0`` when its flag is off.
    """
    workers = 2

    train_set = load_data.Dog_metric_dataloader(image_dir=image_dir, num_class=num_classes, mode="train",
                                                resize=resize, soft_label=soft_label)
    if not custom_sampler:
        train_iter = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=workers)
    else:
        batch_sampler = load_data.customSampler(train_set, batch_size, num_instances)
        train_iter = DataLoader(train_set, batch_sampler=batch_sampler, num_workers=workers)

    test_set = load_data.Dog_dataloader(image_dir=image_dir, num_class=num_classes, mode="test", resize=resize)
    test_iter = DataLoader(test_set, batch_size=8, shuffle=False, num_workers=workers)

    # Optional slots stay 0 when their flag is off.
    novelty_set = novelty_iter = ood_set = ood_iter = 0
    if not_test_ODIN:
        # novelty data: "OOD" mode of the in-distribution directory
        novelty_set = load_data.Dog_dataloader(image_dir=image_dir, num_class=num_classes, mode="OOD",
                                               resize=resize)
        novelty_iter = DataLoader(novelty_set, batch_size=8, shuffle=True, num_workers=workers)
    if transfer:
        # perfectly OOD data from a separate directory
        ood_set = load_data.Dog_dataloader(image_dir=OOD_dir, num_class=OOD_num_classes, mode="OOD",
                                           resize=resize)
        ood_iter = DataLoader(ood_set, batch_size=batch_size, shuffle=True, num_workers=workers)

    return train_set, train_iter, test_set, test_iter, novelty_set, novelty_iter, ood_set, ood_iter
def main():
    """Evaluate a ResNet-50 on the "OOD" split and summarise its max-softmax scores.

    Configuration is read from ``args`` and the ``image_dir`` global. With
    ``args.resume`` the weights come from ``./pre_trained/resnet50.pth.tar``.
    The function prints a histogram of the maximum softmax probability (MSP)
    of every sample over contiguous bins, then the overall and per-class
    accuracy.

    Raises:
        ValueError: if ``args.net_type`` is not "resnet".
    """
    save_model = "./pre_trained"
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ### data config
    test_data = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="OOD")
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True, num_workers=2)

    ##### model config
    if args.net_type != "resnet":
        raise ValueError("unsupported net_type: {}".format(args.net_type))
    model = models.resnet50(num_c=args.num_classes, pretrained=True)
    if args.resume:
        print("load checkpoint_last")
        checkpoint = torch.load(os.path.join(save_model, "resnet50.pth.tar"), map_location=device)
        ##### load model
        model.load_state_dict(checkpoint["model"])
    model = model.to(device).eval()

    test_correct = [0] * args.num_classes
    test_acc = 0
    msp_chunks = []
    with torch.no_grad():
        for batch in test_loader:
            # NOTE(review): Dog_dataloader batches are indexed as dicts
            # ('input' / 'label') elsewhere; (image, label) tuples are still accepted.
            if isinstance(batch, dict):
                org_image, gt = batch['input'], batch['label']
            else:
                org_image, gt = batch
            org_image = org_image.to(device)
            gt = gt.type(torch.FloatTensor).to(device)
            #### forward path
            output = model(org_image)
            target = gt.argmax(dim=1)
            hits = (output.argmax(dim=1) == target).tolist()
            test_acc += sum(hits)
            for label, hit in zip(target.tolist(), hits):
                test_correct[label] += hit
            # Maximum softmax probability of every sample in the batch.
            msp_chunks.append(torch.softmax(output, dim=1).max(dim=1).values.cpu())
    MSP = torch.cat(msp_chunks) if msp_chunks else torch.tensor([])

    # Contiguous MSP bins: <= first threshold, (lo, hi] between consecutive
    # thresholds, and > last threshold, so every sample is counted exactly once.
    thres_list = [0.501, 0.601, 0.701, 0.801, 0.901]
    print("total # of data : {}".format(test_data.num_image))
    print("<= {}".format(thres_list[0]), int(torch.sum(MSP <= thres_list[0])))
    for lo, hi in zip(thres_list, thres_list[1:]):
        print("({}, {}]".format(lo, hi), int(torch.sum((MSP > lo) & (MSP <= hi))))
    print("> {}".format(thres_list[-1]), int(torch.sum(MSP > thres_list[-1])))

    print("test accuracy total : {:.4f}".format(test_acc / test_data.num_image))
    for num in range(args.num_classes):
        print("label{} : {:.4f}".format(num, test_correct[num] / test_data.len_list[num]), end=" ")
    print("\n")
    time.sleep(0.001)
def main():
    """Train a classifier with BCE plus optional triplet and membership losses.

    Configuration is read from ``args`` and the ``image_dir`` / ``env``
    globals. Batches come from ``customSampler``. With ``args.metric`` a
    triplet loss is applied on the last entry of ``model.feature_list``, using
    one random positive and one random negative per anchor. With
    ``args.membership`` a membership penalty on the sigmoid outputs is added.
    Every epoch prints losses and accuracies, logs them to TensorBoard and
    overwrites ``<save_model>/<env>/<net_type>/checkpoint_last.pth.tar``.
    Every 10th epoch ``test_ODIN`` is run and its TNR/AUROC are logged.

    Raises:
        ValueError: if ``args.net_type`` is neither "resnet50" nor "resnet34".
    """
    start_epoch = 0
    if args.metric:
        save_model = "./save_model_" + args.dataset + "_metric"
    else:
        save_model = "./save_model_" + args.dataset
    tensorboard_dir = "./tensorboard/OOD_" + args.dataset
    checkpoint_dir = os.path.join(save_model, env, args.net_type)
    # Resuming reads exactly the file that every epoch writes.
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_last.pth.tar')
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ### data config
    train_dataset = load_data.Dog_metric_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="train",
                                                    soft_label=args.soft_label)
    MySampler = customSampler(train_dataset, args.batch_size, args.num_instances)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=MySampler, num_workers=2)
    test_dataset = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="test")
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=2)
    out_test_dataset = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="OOD")
    out_test_loader = torch.utils.data.DataLoader(out_test_dataset, batch_size=8, shuffle=True, num_workers=2)
    steps_per_epoch = len(train_loader)

    ##### model, optimizer config
    if args.net_type == "resnet50":
        model = models.resnet50(num_c=args.num_classes, pretrained=True)
    elif args.net_type == "resnet34":
        model = models.resnet34(num_c=args.num_classes, pretrained=True)
    else:
        raise ValueError("unsupported net_type: {}".format(args.net_type))
    model = model.to(device)

    # NOTE(review): T_max is 1/50 of all scheduler steps, so with per-batch
    # stepping the cosine schedule presumably cycles many times - confirm intended.
    # max(1, ...) avoids a zero T_max on tiny runs.
    cosine_period = max(1, args.num_epochs * steps_per_epoch // 50)
    optimizer = optim.SGD(model.parameters(), lr=args.init_lr, momentum=0.9, nesterov=args.nesterov)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_period, eta_min=args.init_lr / 10)
    if args.resume:
        print("load checkpoint_last")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        ##### load model
        model.load_state_dict(checkpoint["model"])
        # The saved epoch already finished; continue with the next one.
        start_epoch = checkpoint["epoch"] + 1
        # Same optimizer settings as a fresh run, at the saved learning rate, plus
        # a scheduler bound to this new optimizer (one left on the old optimizer
        # would never change the learning rate actually used).
        optimizer = optim.SGD(model.parameters(), lr=checkpoint["init_lr"], momentum=0.9,
                              nesterov=args.nesterov)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_period,
                                                               eta_min=args.init_lr / 10)

    #### loss config
    criterion = nn.BCEWithLogitsLoss()
    triplet = torch.nn.TripletMarginLoss(margin=0.5, p=2)

    #### create folder
    Path(checkpoint_dir).mkdir(exist_ok=True, parents=True)
    if args.board_clear:
        for f in glob.glob(tensorboard_dir + "/*"):
            shutil.rmtree(f)
    # Each run logs to the first unused numbered sub-directory.
    run_idx = 0
    while Path(os.path.join(tensorboard_dir, str(run_idx))).exists():
        run_idx += 1
    run_dir = os.path.join(tensorboard_dir, str(run_idx))
    Path(run_dir).mkdir(exist_ok=True, parents=True)
    summary = SummaryWriter(run_dir)

    # Start training; disabled optional losses stay at zero.
    triplet_loss = torch.tensor(0)
    membership_loss = torch.tensor(0)
    for epoch in range(start_epoch, args.num_epochs):
        train_correct = [0] * args.num_classes
        train_seen = [0] * args.num_classes
        test_correct = [0] * args.num_classes
        total_loss = 0
        triplet_running_loss = 0
        membership_running_loss = 0
        class_running_loss = 0
        train_acc = 0
        test_acc = 0
        stime = time.time()

        model.train()
        for train_data in train_loader:
            # Small Gaussian input noise as augmentation.
            org_image = (train_data['input'] + 0.01 * torch.randn_like(train_data['input'])).to(device)
            gt = train_data['label'].type(torch.FloatTensor).to(device)
            target = gt.argmax(dim=1)
            batch_size = gt.shape[0]
            optimizer.zero_grad()
            #### forward path
            output, output_list = model.feature_list(org_image)

            if args.metric:
                # One random negative (other class) and one random positive (same
                # class, not the anchor) per anchor. np.random.choice raises if a
                # pool is empty, so the sampler must provide >= 2 instances of each
                # class and >= 2 classes per batch.
                target_layer = output_list[-1]
                target_np = target.cpu().numpy()
                positives, negatives = [], []
                for anchor_idx in range(batch_size):
                    same = target_np == target_np[anchor_idx]
                    negative_idx = np.random.choice(np.where(~same)[0], 1)[0]
                    positive_pool = np.where(same)[0]
                    positive_idx = np.random.choice(positive_pool[positive_pool != anchor_idx], 1)[0]
                    negatives.append(target_layer[negative_idx])
                    positives.append(target_layer[positive_idx])
                triplet_loss = 0.5 * triplet(target_layer, torch.stack(positives), torch.stack(negatives))

            if args.membership:
                prob = torch.sigmoid(output)
                is_target = nn.functional.one_hot(target, num_classes=prob.shape[1]).bool()
                # Squared gap of the true-class probability to 1, plus the mean
                # probability assigned to the other classes, averaged over the batch.
                R_wrong = ((1 - prob[is_target]) ** 2).sum()
                R_correct = prob[~is_target].sum() / (args.num_classes - 1)
                membership_loss = (R_wrong + R_correct) / batch_size

            #### calc loss
            class_loss = criterion(output, gt)
            #### calc accuracy
            hits = (output.argmax(dim=1) == target).tolist()
            train_acc += sum(hits)
            for label, hit in zip(target.tolist(), hits):
                train_seen[label] += 1
                train_correct[label] += hit

            total_backward_loss = class_loss + triplet_loss + membership_loss
            total_backward_loss.backward()
            optimizer.step()
            scheduler.step()
            class_running_loss += class_loss.item()
            triplet_running_loss += triplet_loss.item()
            membership_running_loss += membership_loss.item()
            total_loss += total_backward_loss.item()

        model.eval()
        with torch.no_grad():
            for test_data in test_loader:
                org_image = test_data['input'].to(device)
                gt = test_data['label'].type(torch.FloatTensor).to(device)
                #### forward path
                output = model(org_image)
                target = gt.argmax(dim=1)
                hits = (output.argmax(dim=1) == target).tolist()
                test_acc += sum(hits)
                for label, hit in zip(target.tolist(), hits):
                    test_correct[label] += hit

        # The sampler decides which training samples are seen, so training
        # accuracy is normalised by the samples actually seen this epoch.
        train_total_acc = train_acc / max(1, sum(train_seen))
        test_total_acc = test_acc / test_dataset.num_image
        print('Epoch [{}/{}], Step {}, class_loss = {:.4f}, membership_loss = {:.4f}, total_loss = {:.4f}, exe time: {:.2f}, lr: {:.4f}*e-4'
              .format(epoch, args.num_epochs, steps_per_epoch, class_running_loss / steps_per_epoch,
                      membership_running_loss / steps_per_epoch, total_loss / steps_per_epoch,
                      time.time() - stime, scheduler.get_last_lr()[0] * 10 ** 4))
        print("train accuracy total : {:.4f}".format(train_total_acc))
        for num in range(args.num_classes):
            print("label{} : {:.4f}".format(num, train_correct[num] / max(1, train_seen[num])), end=" ")
        print()
        print("test accuracy total : {:.4f}".format(test_total_acc))
        for num in range(args.num_classes):
            print("label{} : {:.4f}".format(num, test_correct[num] / test_dataset.len_list[num]), end=" ")
        print("\n")

        if epoch % 10 == 9:
            best_TNR, best_AUROC = test_ODIN(model, test_loader, out_test_loader, args.net_type, args)
            summary.add_scalar('AD_acc/AUROC', best_AUROC, epoch)
            summary.add_scalar('AD_acc/TNR', best_TNR, epoch)
        summary.add_scalar('loss/loss', total_loss / steps_per_epoch, epoch)
        summary.add_scalar('loss/membership_loss', membership_running_loss / steps_per_epoch, epoch)
        summary.add_scalar('acc/train_acc', train_total_acc, epoch)
        summary.add_scalar('acc/test_acc', test_total_acc, epoch)
        summary.add_scalar("learning_rate/lr", scheduler.get_last_lr()[0], epoch)
        time.sleep(0.001)
        torch.save({
            'model': model.state_dict(),
            'epoch': epoch,
            'init_lr': scheduler.get_last_lr()[0]
        }, checkpoint_path)
        # NOTE(review): one extra scheduler step per epoch on top of the
        # per-batch steps - confirm this is intended.
        scheduler.step()
def main():
    """Fine-tune a two-head network with a discrepancy loss on held-out OOD data.

    Configuration is read from ``args`` and the ``image_dir`` / ``env``
    globals. With ``args.load`` the weights are loaded from
    ``./save_model_dis/pre_training/<args.pretrained_model>/checkpoint_last_pre.pth.tar``.
    Each epoch first evaluates: "test" samples get label 0, "OOD" samples
    label 1, and each sample is scored by the L1 distance between the softmax
    outputs of the two heads of ``model.dis_forward``; ``evaluate(...,
    metric='roc')`` then gives the AUROC. The epoch then trains one pass over
    ``zip(train_loader, out_train_loader)`` with BCE on both heads for
    in-distribution images plus ``DiscrepancyLoss`` on "OOD_val" images.

    Raises:
        ValueError: if ``args.net_type`` is neither "resnet50" nor "resnet34".
    """
    start_epoch = 0
    save_model = "./save_model_dis/fine"
    pretrained_model_dir = "./save_model_dis/pre_training"
    tensorboard_dir = "./tensorboard/OOD_dis/fine/" + args.dataset
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ### data config
    train_dataset = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="train")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=2)
    test_dataset = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="test")
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.num_classes, shuffle=True,
                                              num_workers=2)
    out_train_dataset = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="OOD_val")
    out_train_loader = torch.utils.data.DataLoader(out_train_dataset, batch_size=args.batch_size, shuffle=True,
                                                   num_workers=2)
    out_test_dataset = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="OOD")
    out_test_loader = torch.utils.data.DataLoader(out_test_dataset, batch_size=args.batch_size, shuffle=False,
                                                  num_workers=2)
    # zip() in the training loop stops at the shorter of the two loaders.
    steps_per_epoch = min(len(train_loader), len(out_train_loader))

    ##### model, optimizer config
    if args.net_type == "resnet50":
        model = models.resnet50(num_c=args.num_classes, pretrained=True)
    elif args.net_type == "resnet34":
        model = models.resnet34(num_c=args.num_classes, pretrained=True)
    else:
        raise ValueError("unsupported net_type: {}".format(args.net_type))
    if args.load:
        print("loading model")
        checkpoint = torch.load(
            os.path.join(pretrained_model_dir, args.pretrained_model, "checkpoint_last_pre.pth.tar"),
            map_location=device)
        ##### load model
        model.load_state_dict(checkpoint["model"])
    model = model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=args.init_lr, momentum=0.9, nesterov=args.nesterov)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs * steps_per_epoch)

    #### loss config
    criterion = nn.BCEWithLogitsLoss()

    #### create folder
    Path(os.path.join(save_model, env, args.net_type)).mkdir(exist_ok=True, parents=True)
    if args.board_clear:
        for f in glob.glob(tensorboard_dir + "/*"):
            shutil.rmtree(f)
    # Each run logs to the first unused numbered sub-directory.
    run_idx = 0
    while Path(os.path.join(tensorboard_dir, str(run_idx))).exists():
        run_idx += 1
    run_dir = os.path.join(tensorboard_dir, str(run_idx))
    Path(run_dir).mkdir(exist_ok=True, parents=True)
    summary = SummaryWriter(run_dir)

    # Start training
    for epoch in range(start_epoch, args.num_epochs):
        total_class_loss = 0
        total_dis_loss = 0
        train_acc = 0
        train_seen = 0
        test_acc = 0
        stime = time.time()

        model.eval()
        dist_chunks, label_chunks = [], []
        with torch.no_grad():
            for test_data in test_loader:
                org_image = test_data['input'].to(device)
                gt = test_data['label'].type(torch.FloatTensor).to(device)
                out1, out2 = model.dis_forward(org_image)
                score_1 = nn.functional.softmax(out1, dim=1)
                score_2 = nn.functional.softmax(out2, dim=1)
                # Both distance chunks are 1-D so they can be concatenated.
                dist_chunks.append(torch.sum(torch.abs(score_1 - score_2), dim=1).reshape((org_image.shape[0],)))
                label_chunks.append(torch.zeros((org_image.shape[0],)))
                target = gt.argmax(dim=1)
                test_acc += (out1.argmax(dim=1) == target).sum().item()
                test_acc += (out2.argmax(dim=1) == target).sum().item()
            for out_org_data in out_test_loader:
                out_org_image = out_org_data['input'].to(device)
                out1, out2 = model.dis_forward(out_org_image)
                score_1 = nn.functional.softmax(out1, dim=1)
                score_2 = nn.functional.softmax(out2, dim=1)
                dist_chunks.append(
                    torch.sum(torch.abs(score_1 - score_2), dim=1).reshape((out_org_image.shape[0],)))
                label_chunks.append(torch.ones((out_org_image.shape[0],)))
        dists = torch.cat(dist_chunks, dim=0)
        labels = torch.cat(label_chunks, dim=0)
        roc = evaluate(labels.cpu(), dists.cpu(), metric='roc')
        # Two heads per sample, hence the division by 2.
        print('Epoch{} AUROC: {:.3f}, test accuracy : {:.4f}'.format(epoch, roc,
                                                                     test_acc / test_dataset.num_image / 2))

        model.train()
        for org_data, out_org_data in zip(train_loader, out_train_loader):
            #### initialized
            org_image = org_data['input'].to(device)
            out_org_image = out_org_data['input'].to(device)
            gt = org_data['label'].type(torch.FloatTensor).to(device)
            optimizer.zero_grad()
            #### forward path (in-distribution)
            out1, out2 = model.dis_forward(org_image)
            #### calc accuracy
            target = gt.argmax(dim=1)
            train_acc += (out1.argmax(dim=1) == target).sum().item()
            train_acc += (out2.argmax(dim=1) == target).sum().item()
            train_seen += gt.shape[0]
            #### calc loss
            class1_loss = criterion(out1, gt)
            class2_loss = criterion(out2, gt)
            out1, out2 = model.dis_forward(out_org_image)
            dis_loss = DiscrepancyLoss(out1, out2, args.m)
            loss = class1_loss + class2_loss + dis_loss
            total_class_loss += class1_loss.item() + class2_loss.item()
            total_dis_loss += dis_loss.item()
            loss.backward()
            optimizer.step()
            scheduler.step()

        print('Epoch [{}/{}], Step {}, class_loss = {:.4f}, dis_loss = {:.4f}, exe time: {:.2f}, lr: {:.4f}*e-4'
              .format(epoch, args.num_epochs, steps_per_epoch, total_class_loss / steps_per_epoch,
                      total_dis_loss / steps_per_epoch, time.time() - stime,
                      scheduler.get_last_lr()[0] * 10 ** 4))
        print("train accuracy total : {:.4f}".format(train_acc / max(1, train_seen) / 2))
        summary.add_scalar('loss/class_loss', total_class_loss / steps_per_epoch, epoch)
        summary.add_scalar('loss/dis_loss', total_dis_loss / steps_per_epoch, epoch)
        summary.add_scalar('acc/roc', roc, epoch)
        summary.add_scalar("learning_rate/lr", scheduler.get_last_lr()[0], epoch)
        time.sleep(0.001)
        torch.save({
            'model': model.state_dict(),
            'epoch': epoch,
            'init_lr': scheduler.get_last_lr()[0]
        }, os.path.join(save_model, env, args.net_type, 'checkpoint_last_fine.pth.tar'))
def main():
    """Write Grad-CAM visualisations for every image of the "test" split.

    Configuration is read from ``args`` and the ``image_dir`` global. Weights
    are loaded from ``args.model_path``. For each test image, the prediction
    and the target class are printed, and
    ``GradCAMmodule.saveGradCAM(image, predicted_class, index)`` is called
    with ``target_layer="layer4"`` and output directory ``./save_fig``.

    Raises:
        ValueError: for an unsupported ``args.net_type``.
    """
    output_dir = "./save_fig"
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ### data config
    test_dataset = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="test")
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)
    ### novelty data (built but not iterated below)
    out_test_dataset = load_data.Dog_dataloader(image_dir=image_dir, num_class=args.num_classes, mode="OOD")
    out_test_loader = torch.utils.data.DataLoader(out_test_dataset, batch_size=1, shuffle=False, num_workers=2)

    ##### model config
    if args.net_type == "resnet50":
        model = models.resnet50(num_c=args.num_classes, pretrained=True)
    elif args.net_type in ("resnet34", "vgg19", "vgg16", "vgg19_bn", "vgg16_bn"):
        model = getattr(models, args.net_type)(num_c=args.num_classes, num_cc=args.OOD_num_classes,
                                               pretrained=True)
    else:
        raise ValueError("unsupported net_type: {}".format(args.net_type))
    print("load checkpoint_last")
    checkpoint = torch.load(args.model_path, map_location=device)
    ##### load model
    model.load_state_dict(checkpoint["model"])

    #### create folder
    Path(output_dir).mkdir(exist_ok=True, parents=True)
    model = model.to(device).eval()

    # Start grad-CAM
    # NOTE(review): bp and gcam are never used below; they are kept in case
    # their constructors register hooks on ``model`` - confirm and drop if not.
    bp = BackPropagation(model=model)
    gcam = GradCAM(model=model)
    # NOTE(review): "layer4" names a ResNet stage; confirm it exists for the VGG variants.
    target_layer = "layer4"
    grad_cam = GradCAMmodule(target_layer, output_dir)
    grad_cam.model_config(model)
    for j, test_data in enumerate(test_loader):
        org_image = test_data['input'].to(device)
        target_class = int(test_data['label'].argmax())
        # The prediction needs no autograd graph; saveGradCAM does its own pass.
        with torch.no_grad():
            result = int(model(org_image).argmax())
        print("number: {} pred: {} target: {}".format(j, result, target_class))
        grad_cam.saveGradCAM(org_image, result, j)