def train(args):
    json_options = json_file_to_pyobj(args.config)
    no_teacher_configurations = json_options.training

    wrn_depth = no_teacher_configurations.wrn_depth
    wrn_width = no_teacher_configurations.wrn_width
    M = no_teacher_configurations.M
    dataset = no_teacher_configurations.dataset
    seeds = [int(seed) for seed in no_teacher_configurations.seeds]
    # Compare against 'true': the config value has already been lower-cased.
    log = no_teacher_configurations.log.lower() == 'true'

    if log:
        net_str = "WideResNet-{}-{}".format(wrn_depth, wrn_width)
        logfile = "No_Teacher-{}-{}-M-{}.txt".format(net_str, no_teacher_configurations.dataset, M)
        with open(os.path.join('./', logfile), "w") as temp:
            temp.write('No teacher {} in {} with M={}\n'.format(net_str, no_teacher_configurations.dataset, M))
    else:
        logfile = ''

    checkpoint = bool(no_teacher_configurations.checkpoint)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    test_set_accuracies = []

    for seed in seeds:
        set_seed(seed)

        if dataset.lower() == 'cifar10':
            # Full data
            if M == 5000:
                from utils import cifar10loaders
                loaders = cifar10loaders()
            # No data
            elif M == 0:
                from utils import cifar10loaders
                _, test_loader = cifar10loaders()
            else:
                from utils import cifar10loadersM
                loaders = cifar10loadersM(M)
        elif dataset.lower() == 'svhn':
            # Full data
            if M == 5000:
                from utils import svhnLoaders
                loaders = svhnLoaders()
            # No data
            elif M == 0:
                from utils import svhnLoaders
                _, test_loader = svhnLoaders()
            else:
                from utils import svhnloadersM
                loaders = svhnloadersM(M)
        else:
            raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

        if log:
            with open(os.path.join('./', logfile), "a") as temp:
                temp.write('------------------- SEED {} -------------------\n'.format(seed))

        strides = [1, 1, 2, 2]
        net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
        net = net.to(device)

        checkpointFile = 'No_teacher_wrn-{}-{}-M-{}-seed-{}-{}-dict.pth'.format(wrn_depth, wrn_width, M, seed, dataset) if checkpoint else ''
        best_test_set_accuracy = _train_seed_no_teacher(net, M, loaders, device, dataset, log, checkpoint, logfile, checkpointFile)

        if log:
            with open(os.path.join('./', logfile), "a") as temp:
                temp.write('Best test set accuracy of seed {} is {}\n'.format(seed, best_test_set_accuracy))

        test_set_accuracies.append(best_test_set_accuracy)

    mean_test_set_accuracy, std_test_set_accuracy = np.mean(test_set_accuracies), np.std(test_set_accuracies)

    if log:
        with open(os.path.join('./', logfile), "a") as temp:
            temp.write('Mean test set accuracy is {} with standard deviation equal to {}\n'.format(mean_test_set_accuracy, std_test_set_accuracy))

def train(args):
    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    wandb.init(name=f"{training_configurations.checkpoint}_subset_{args.subset_index}_ensemble")
    device = torch.device(f'cuda:{args.device}')

    flag = False
    if training_configurations.train_pickle != 'None' and training_configurations.test_pickle != 'None':
        pickle_files = [training_configurations.train_pickle, training_configurations.test_pickle]
        flag = True

    if args.subset_index is None:
        model = build_model(args)
        model = model.to(device)
        epochs = 40
        optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    dataset = args.dataset.lower()

    # Entropy-margin hyperparameters: weight b and margin m.
    b = 0.2
    m = 0.4

    if not flag:
        trainloader, val_loader, testloader = fine_grained_image_loaders_subset(dataset, subset_index=args.subset_index, validation_test_split=800, save_to_pickle=True)
    else:
        pickle_files[0] = "pickle_files/" + pickle_files[0].split(".pickle")[0] + f"_subset_{args.subset_index}.pickle"
        pickle_files[1] = "pickle_files/" + pickle_files[1].split(".pickle")[0] + f"_subset_{args.subset_index}.pickle"
        trainloader, val_loader, testloader, num_classes = fine_grained_image_loaders_subset(dataset, subset_index=args.subset_index, validation_test_split=800, pickle_files=pickle_files, ret_num_classes=True)
        train_ood_loader = fine_grained_image_loaders_subset(dataset, single=True, subset_index=args.subset_index, validation_test_split=800, pickle_files=pickle_files)

    if 'genOdin' in training_configurations.checkpoint:
        weight_decay = 1e-4
        # Generalized ODIN: every module gets weight decay except the cosine-similarity numerator head.
        optimizer = optim.SGD([
            {'params': model._conv_stem.parameters(), 'weight_decay': weight_decay},
            {'params': model._bn0.parameters(), 'weight_decay': weight_decay},
            {'params': model._blocks.parameters(), 'weight_decay': weight_decay},
            {'params': model._conv_head.parameters(), 'weight_decay': weight_decay},
            {'params': model._bn1.parameters(), 'weight_decay': weight_decay},
            {'params': model._fc_denominator.parameters(), 'weight_decay': weight_decay},
            {'params': model._denominator_batch_norm.parameters(), 'weight_decay': weight_decay},
            {'params': model._fc_nominator.parameters(), 'weight_decay': 0},
        ], lr=1.25e-2, momentum=0.9, nesterov=True)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    if args.subset_index is not None:
        model = build_model(args)
        model._fc = nn.Linear(model._fc.in_features, num_classes)
        model = model.to(device)
        epochs = 40
        optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    criterion = nn.CrossEntropyLoss()
    checkpoint_val_accuracy, best_val_acc, test_set_accuracy = 0, 0, 0
    ood_loader_iter = iter(train_ood_loader)

    for epoch in tqdm(range(epochs)):
        model.train()
        correct, total = 0, 0
        train_loss = 0

        for data in tqdm(trainloader):
            model.train()
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            ce_loss = criterion(outputs, labels)

            # Re-create the OOD iterator once it is exhausted.
            try:
                ood_inputs, _ = next(ood_loader_iter)
            except StopIteration:
                ood_loader_iter = iter(train_ood_loader)
                ood_inputs, _ = next(ood_loader_iter)

            ood_inputs = ood_inputs.to(device)
            ood_outputs = model(ood_inputs)

            # Softmax entropies of in-distribution and OOD predictions.
            entropy_input = -torch.mean(torch.sum(F.log_softmax(outputs, dim=1) * F.softmax(outputs, dim=1), dim=1))
            entropy_output = -torch.mean(torch.sum(F.log_softmax(ood_outputs, dim=1) * F.softmax(ood_outputs, dim=1), dim=1))
            # Hinge: OOD entropy should exceed in-distribution entropy by at least m.
            margin_loss = b * torch.clamp(m + entropy_input - entropy_output, min=0)

            loss = ce_loss + margin_loss
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        train_accuracy = correct / total
        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({'Train Set Loss': train_loss / trainloader.__len__(), 'epoch': epoch})
        wandb.log({'Train Set Accuracy': train_accuracy, 'epoch': epoch})

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data in val_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_val_accuracy = correct / total
            wandb.log({'Validation Set Accuracy': epoch_val_accuracy, 'epoch': epoch})

            if epoch_val_accuracy > best_val_acc:
                best_val_acc = epoch_val_accuracy
                if os.path.exists('/raid/ferles/'):
                    torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_subset_ens_{args.subset_index}.pth')
                else:
                    torch.save(model.state_dict(), f'/home/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_subset_ens_{args.subset_index}.pth')

                correct, total = 0, 0
                for data in testloader:
                    images, labels = data
                    images = images.to(device)
                    labels = labels.to(device)

                    if 'genodin' in training_configurations.checkpoint.lower():
                        outputs, h, g = model(images)
                    else:
                        outputs = model(images)

                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                test_set_accuracy = correct / total
                wandb.log({'Test Set Accuracy': test_set_accuracy, 'epoch': epoch})

        scheduler.step(epoch=epoch)

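
# A minimal, self-contained sketch (illustrative tensor names only, not part of the
# training scripts) of the entropy-margin term used above: it pushes OOD predictions
# to be at least m nats *more* uncertain (higher softmax entropy) than
# in-distribution predictions, and is inactive once that margin is met.
import torch
import torch.nn.functional as F

def softmax_entropy(logits):
    return -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * F.softmax(logits, dim=1), dim=1))

b, m = 0.2, 0.4
ind_logits, ood_logits = torch.randn(16, 10), torch.randn(16, 10)
margin_loss = b * torch.clamp(m + softmax_entropy(ind_logits) - softmax_entropy(ood_logits), min=0)
print(margin_loss)
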
def train(args):
    device = torch.device(f'cuda:{args.device}')
    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    traincsv = training_configurations.traincsv
    testcsv = training_configurations.testcsv
    gtFileName = training_configurations.gtFile
    out_classes = training_configurations.out_classes
    exclude_class = training_configurations.exclude_class
    exclude_class = None if exclude_class == "None" else exclude_class

    if exclude_class is None:
        wandb.init(name='oe_isic')
    else:
        wandb.init(name=f'oe_{exclude_class}')

    batch_size = 32
    if exclude_class is None:
        train_loader, val_loader, test_loader, columns = oversampling_loaders_custom(csvfiles=[traincsv, testcsv], train_batch_size=32, val_batch_size=16, gtFile=gtFileName)
    else:
        train_loader, val_loader, test_loader, columns = oversampling_loaders_exclude_class_custom_no_gts(csvfiles=[traincsv, testcsv], train_batch_size=32, val_batch_size=16, gtFile=gtFileName, exclude_class=exclude_class)

    ood_loader = imageNetLoader(dataset='isic', batch_size=batch_size)
    ood_loader_iter = iter(ood_loader)

    model = build_model(args).to(device)

    epochs = 40
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    # Uniform target distribution for the Outlier Exposure term.
    uniform = torch.ones(size=(batch_size, out_classes)) / float(out_classes)
    uniform = uniform.to(device)
    lamda = 0.5

    checkpoint_val_accuracy, best_val_acc, test_set_accuracy = 0, 0, 0

    for epoch in tqdm(range(epochs)):
        model.train()
        loss_acc = []

        for data in tqdm(train_loader):
            model.train()
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            try:
                ood_inputs, _ = next(ood_loader_iter)
            except StopIteration:
                ood_loader_iter = iter(ood_loader)
                ood_inputs, _ = next(ood_loader_iter)

            ood_inputs = ood_inputs.to(device)
            ood_outputs = model(ood_inputs)

            _labels = torch.argmax(labels, dim=1)
            ce_loss = criterion(outputs, _labels)

            # Shrink the uniform target for a trailing (smaller) OOD batch.
            if ood_outputs.size(0) < batch_size:
                uniform = torch.ones(size=(ood_outputs.size(0), out_classes)) / float(out_classes)
                uniform = uniform.to(device)

            # Outlier Exposure: cross-entropy between OOD predictions and the uniform
            # distribution, written as -(sum_j uniform_j * o_j - logsumexp(o)).
            outlier_loss = lamda * -((uniform * ood_outputs).sum(1) - torch.logsumexp(ood_outputs, dim=1)).mean()
            loss = ce_loss + outlier_loss
            loss_acc.append(loss.item())
            loss.backward()
            optimizer.step()

            # Restore the full-size uniform target after a trailing batch.
            if ood_outputs.size(0) < batch_size:
                uniform = torch.ones(size=(batch_size, out_classes)) / float(out_classes)
                uniform = uniform.to(device)

        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({'Train Set Loss': sum(loss_acc) / float(train_loader.__len__()), 'epoch': epoch})

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data in val_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                _labels = torch.argmax(labels, dim=1)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == _labels).sum().item()

            val_detection_accuracy = round(100 * correct / total, 2)
            wandb.log({'Validation Detection Accuracy': val_detection_accuracy, 'epoch': epoch})

            if val_detection_accuracy > best_val_acc:
                best_val_acc = val_detection_accuracy
                if os.path.exists('/raid/ferles/'):
                    if exclude_class is None:
                        torch.save(model.state_dict(), f'/raid/ferles/checkpoints/isic_classifiers/outlier_exposure_isic.pth')
                    else:
                        torch.save(model.state_dict(), f'/raid/ferles/checkpoints/isic_classifiers/outlier_exposure_{exclude_class}.pth')
                else:
                    if exclude_class is None:
                        torch.save(model.state_dict(), f'/home/ferles/checkpoints/isic_classifiers/outlier_exposure_isic.pth')
                    else:
                        torch.save(model.state_dict(), f'/home/ferles/checkpoints/isic_classifiers/outlier_exposure_{exclude_class}.pth')

                correct, total = 0, 0
                for data in test_loader:
                    # The test loader yields a 3-tuple; its first element is unused here.
                    _, images, labels = data
                    images = images.to(device)
                    labels = labels.to(device)

                    _labels = torch.argmax(labels, dim=1)
                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == _labels).sum().item()

                test_detection_accuracy = correct / total
                wandb.log({'Detection Accuracy': test_detection_accuracy, 'epoch': epoch})

        scheduler.step(epoch=epoch)

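
# A minimal, self-contained sketch (illustrative tensors only) of the identity
# behind the Outlier Exposure term above: cross-entropy against a uniform target
# equals -(sum_j uniform_j * o_j - logsumexp(o)), because
# log_softmax(o)_j = o_j - logsumexp(o) and the uniform weights sum to one.
import torch
import torch.nn.functional as F

ood_logits = torch.randn(8, 5)            # batch of 8 OOD samples, 5 classes
uniform = torch.full((8, 5), 1.0 / 5)     # uniform target distribution

ce_to_uniform = -(uniform * F.log_softmax(ood_logits, dim=1)).sum(1).mean()
logsumexp_form = -((uniform * ood_logits).sum(1) - torch.logsumexp(ood_logits, dim=1)).mean()

assert torch.allclose(ce_to_uniform, logsumexp_form, atol=1e-5)
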
def train(args):
    use_wandb = True
    device = torch.device(f'cuda:{args.device}')
    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    traincsv = training_configurations.traincsv
    testcsv = training_configurations.testcsv
    gtFileName = training_configurations.gtFile
    checkpointFileName = training_configurations.checkpointFile
    out_classes = training_configurations.out_classes
    exclude_class = training_configurations.exclude_class
    exclude_class = None if exclude_class == "None" else exclude_class

    if use_wandb:
        wandb.init(name=checkpointFileName)

    if exclude_class is None:
        train_loader, val_loader, test_loader, columns = oversampling_loaders_custom(csvfiles=[traincsv, testcsv], train_batch_size=32, val_batch_size=16, gtFile=gtFileName)
    else:
        train_loader, val_loader, test_loader, columns = oversampling_loaders_exclude_class_custom_no_gts(csvfiles=[traincsv, testcsv], train_batch_size=32, val_batch_size=16, gtFile=gtFileName, exclude_class=exclude_class)

    model = build_model(args).to(device)

    epochs = 40
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    best_val_detection_accuracy, test_detection_accuracy = 0, 0
    train_loss, val_loss, balanced_accuracies = [], [], []

    early_stopping = False
    early_stopping_cnt = 0

    for epoch in tqdm(range(epochs)):
        model.train()
        loss_acc = []

        for data in tqdm(train_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            # GenODIN models return (logits, numerator, denominator).
            if 'genodin' in training_configurations.checkpointFile.lower():
                outputs, _, _ = model(inputs)
            else:
                outputs = model(inputs)

            _labels = torch.argmax(labels, dim=1)
            loss = criterion(outputs, _labels)
            loss_acc.append(loss.item())
            loss.backward()
            optimizer.step()

        wandb.log({'Train Set Loss': sum(loss_acc) / float(train_loader.__len__()), 'epoch': epoch})
        wandb.log({'epoch': epoch}, commit=False)
        train_loss.append(sum(loss_acc) / float(train_loader.__len__()))
        loss_acc.clear()

        with torch.no_grad():
            model.eval()
            correct, total = 0, 0
            for data in tqdm(val_loader):
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                if 'genodin' in training_configurations.checkpointFile.lower():
                    outputs, _, _ = model(images)
                else:
                    outputs = model(images)

                softmax_outputs = torch.softmax(outputs, 1)
                max_idx = torch.argmax(softmax_outputs, axis=1)
                _labels = torch.argmax(labels, dim=1)
                correct += (max_idx == _labels).sum().item()
                total += max_idx.size()[0]
                loss = criterion(outputs, _labels)
                loss_acc.append(loss.item())

            epoch_val_loss = sum(loss_acc) / float(val_loader.__len__())
            val_loss.append(epoch_val_loss)
            val_detection_accuracy = round(100 * correct / total, 2)
            wandb.log({'Validation Detection Accuracy': val_detection_accuracy, 'epoch': epoch})

            if val_detection_accuracy > best_val_detection_accuracy:
                best_val_detection_accuracy = val_detection_accuracy
                if 'genodin' in training_configurations.checkpointFile.lower():
                    # test_loss, auc, balanced_accuracy, test_detection_accuracy = _test_set_eval(model, epoch, device, test_loader, out_classes, columns, gtFileName, gen=True)
                    test_loss, test_detection_accuracy = _test_set_eval(model, epoch, device, test_loader, out_classes, columns, gtFileName, gen=True)
                else:
                    # test_loss, auc, balanced_accuracy, test_detection_accuracy = _test_set_eval(model, epoch, device, test_loader, out_classes, columns, gtFileName)
                    test_loss, test_detection_accuracy = _test_set_eval(model, epoch, device, test_loader, out_classes, columns, gtFileName)

                checkpointFile = os.path.join(f'/raid/ferles/checkpoints/isic_classifiers/{checkpointFileName}-best-model.pth')
                # Save under /raid if that mount exists, otherwise fall back to /home.
                if os.path.exists('/raid/ferles/'):
                    torch.save(model.state_dict(), checkpointFile)
                else:
                    torch.save(model.state_dict(), checkpointFile.replace('raid', 'home'))
            else:
                if early_stopping:
                    early_stopping_cnt += 1
                    if early_stopping_cnt == 3:
                        break

        wandb.log({'Val Set Loss': epoch_val_loss, 'epoch': epoch})
        wandb.log({'Detection Accuracy': test_detection_accuracy, 'epoch': epoch})
        # wandb.log({'Balanced Accuracy': balanced_accuracy, 'epoch': epoch})
        # wandb.log({'AUC': auc, 'epoch': epoch})
        scheduler.step()

def train(args):
    json_options = json_file_to_pyobj(args.config)
    kd_att_configurations = json_options.training

    wrn_depth_teacher = kd_att_configurations.wrn_depth_teacher
    wrn_width_teacher = kd_att_configurations.wrn_width_teacher
    wrn_depth_student = kd_att_configurations.wrn_depth_student
    wrn_width_student = kd_att_configurations.wrn_width_student
    M = kd_att_configurations.M
    dataset = kd_att_configurations.dataset
    seeds = [int(seed) for seed in kd_att_configurations.seeds]
    # Compare against 'true': the config value has already been lower-cased.
    log = kd_att_configurations.log.lower() == 'true'

    if log:
        teacher_str = "WideResNet-{}-{}".format(wrn_depth_teacher, wrn_width_teacher)
        student_str = "WideResNet-{}-{}".format(wrn_depth_student, wrn_width_student)
        logfile = "Teacher-{}-Student-{}-{}-M-{}-seeds-1-2.txt".format(teacher_str, student_str, kd_att_configurations.dataset, M)
        print(logfile)
        with open(os.path.join('./', logfile), "w") as temp:
            temp.write('KD_ATT with teacher {} and student {} in {} with M={}\n'.format(teacher_str, student_str, kd_att_configurations.dataset, M))
    else:
        logfile = ''

    checkpoint = bool(kd_att_configurations.checkpoint)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    test_set_accuracies = []

    for seed in seeds:
        set_seed(seed)

        if dataset.lower() == 'cifar10':
            # Full data
            if M == 5000:
                from utils import cifar10loaders
                loaders = cifar10loaders()
            # No data
            elif M == 0:
                from utils import cifar10loaders
                _, test_loader = cifar10loaders()
            else:
                from utils import cifar10loadersM
                loaders = cifar10loadersM(M)
        elif dataset.lower() == 'svhn':
            # Full data
            if M == 5000:
                from utils import svhnLoaders
                loaders = svhnLoaders()
            # No data
            elif M == 0:
                from utils import svhnLoaders
                _, test_loader = svhnLoaders()
            else:
                from utils import svhnloadersM
                loaders = svhnloadersM(M)
        else:
            raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

        if log:
            with open(os.path.join('./', logfile), "a") as temp:
                temp.write('------------------- SEED {} -------------------\n'.format(seed))

        strides = [1, 1, 2, 2]
        teacher_net = WideResNet(d=wrn_depth_teacher, k=wrn_width_teacher, n_classes=10, input_features=3, output_features=16, strides=strides)
        teacher_net = teacher_net.to(device)

        if dataset.lower() == 'cifar10':
            torch_checkpoint = torch.load('./PreTrainedModels/PreTrainedScratches/CIFAR10/wrn-{}-{}-seed-{}-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, seed), map_location=device)
        else:
            torch_checkpoint = torch.load('./PreTrainedModels/PreTrainedScratches/SVHN/wrn-{}-{}-seed-svhn-{}-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, seed), map_location=device)
        teacher_net.load_state_dict(torch_checkpoint)

        student_net = WideResNet(d=wrn_depth_student, k=wrn_width_student, n_classes=10, input_features=3, output_features=16, strides=strides)
        student_net = student_net.to(device)

        checkpointFile = 'kd_att_teacher_wrn-{}-{}_student_wrn-{}-{}-M-{}-seed-{}-{}-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, wrn_depth_student, wrn_width_student, M, seed, dataset) if checkpoint else ''

        if M != 0:
            best_test_set_accuracy = _train_seed_kd_att(teacher_net, student_net, M, loaders, device, dataset, log, checkpoint, logfile, checkpointFile)

            if log:
                with open(os.path.join('./', logfile), "a") as temp:
                    temp.write('Best test set accuracy of seed {} is {}\n'.format(seed, best_test_set_accuracy))

            test_set_accuracies.append(best_test_set_accuracy)
        else:
            best_test_set_accuracy = _test_set_eval(student_net, device, test_loader)
            test_set_accuracies.append(best_test_set_accuracy)

    mean_test_set_accuracy, std_test_set_accuracy = np.mean(test_set_accuracies), np.std(test_set_accuracies)

    if log:
        with open(os.path.join('./', logfile), "a") as temp:
            temp.write('Mean test set accuracy is {} with standard deviation equal to {}\n'.format(mean_test_set_accuracy, std_test_set_accuracy))

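
# `_train_seed_kd_att` is defined elsewhere in the repo. As a reference, a minimal
# sketch of the attention-transfer distillation term such a routine typically
# minimises, following Zagoruyko & Komodakis, "Paying More Attention to Attention".
# The activation shapes and the beta value below are illustrative assumptions, not
# values taken from this code base.
import torch
import torch.nn.functional as F

def attention_map(fmap: torch.Tensor) -> torch.Tensor:
    # Collapse channels into an L2-normalised spatial attention map.
    return F.normalize(fmap.pow(2).mean(1).flatten(1), dim=1)

def attention_transfer_loss(student_fmaps, teacher_fmaps, beta=1e3):
    loss = 0.0
    for s, t in zip(student_fmaps, teacher_fmaps):
        # Channel counts may differ; the attention maps only depend on spatial size.
        loss = loss + (attention_map(s) - attention_map(t)).pow(2).mean()
    return beta * loss

# Usage with dummy activations from three residual groups:
s_maps = [torch.randn(4, 16, 32, 32), torch.randn(4, 32, 16, 16), torch.randn(4, 64, 8, 8)]
t_maps = [torch.randn(4, 64, 32, 32), torch.randn(4, 128, 16, 16), torch.randn(4, 256, 8, 8)]
print(attention_transfer_loss(s_maps, t_maps))
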
def train(args):
    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    wandb.init(name=f'rot_{training_configurations.checkpoint}')
    device = torch.device('cuda')

    flag = False
    if training_configurations.train_pickle != 'None' and training_configurations.test_pickle != 'None':
        pickle_files = [training_configurations.train_pickle, training_configurations.test_pickle]
        flag = True

    if args.checkpoint is None:
        model = build_model(args, rot=True)
        model = nn.DataParallel(model).to(device)
    else:
        model = build_model_with_checkpoint(modelName='rot' + training_configurations.model.lower(), model_checkpoint=args.checkpoint, device=device, out_classes=training_configurations.out_classes, rot=True)
        model = nn.DataParallel(model).to(device)

    dataset = args.dataset.lower()

    if 'wide' in training_configurations.model.lower():
        resize = False
        epochs = 200
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
        scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
    else:
        resize = True
        epochs = 40
        optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    if not flag:
        trainloader, val_loader, testloader = natural_image_loaders(dataset, train_batch_size=32, test_batch_size=16, validation_test_split=1000, save_to_pickle=True, resize=resize)
    else:
        trainloader, val_loader, testloader = natural_image_loaders(dataset, train_batch_size=32, test_batch_size=16, validation_test_split=1000, pickle_files=pickle_files, resize=resize)

    criterion = nn.CrossEntropyLoss()
    checkpoint_val_accuracy, best_val_acc, test_set_accuracy = 0, 0, 0

    for epoch in tqdm(range(epochs)):
        model.train()
        correct, total = 0, 0
        # Reset the running losses at the start of every epoch.
        train_loss, test_loss = 0, 0

        for data in tqdm(trainloader):
            # Inputs stay on the CPU; nn.DataParallel scatters them across the GPUs.
            inputs, labels = data
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            ce_loss = criterion(outputs, labels)

            # Self-supervised rotation task: predict 0/90/180/270-degree rotations.
            rot_gt = torch.cat((torch.zeros(inputs.size(0)), torch.ones(inputs.size(0)), 2 * torch.ones(inputs.size(0)), 3 * torch.ones(inputs.size(0))), 0).long().to(device)
            rot_inputs = inputs.detach().cpu().numpy()
            rot_inputs = np.concatenate((rot_inputs, np.rot90(rot_inputs, 1, axes=(2, 3)), np.rot90(rot_inputs, 2, axes=(2, 3)), np.rot90(rot_inputs, 3, axes=(2, 3))), 0)
            rot_inputs = torch.FloatTensor(rot_inputs)
            rot_preds = model(rot_inputs, rot=True)
            rot_loss = criterion(rot_preds, rot_gt)

            loss = ce_loss + rot_loss
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_accuracy = correct / total
        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({'Train Set Loss': train_loss / trainloader.__len__(), 'epoch': epoch})
        wandb.log({'Train Set Accuracy': train_accuracy, 'epoch': epoch})

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data in val_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_val_accuracy = correct / total
            wandb.log({'Validation Set Accuracy': epoch_val_accuracy, 'epoch': epoch})

            if epoch_val_accuracy > best_val_acc:
                best_val_acc = epoch_val_accuracy
                if os.path.exists('/raid/ferles/'):
                    torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/rot_{training_configurations.checkpoint}.pth')
                else:
                    torch.save(model.state_dict(), f'/home/ferles/checkpoints/eb0/{dataset}/rot_{training_configurations.checkpoint}.pth')

                correct, total = 0, 0
                for data in testloader:
                    images, labels = data
                    images = images.to(device)
                    labels = labels.to(device)

                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                test_set_accuracy = correct / total
                wandb.log({'Test Set Accuracy': test_set_accuracy, 'epoch': epoch})

        scheduler.step(epoch=epoch)

def train(args):
    json_options = json_file_to_pyobj(args.config)
    extra_M_configuration = json_options.training

    wrn_depth_teacher = extra_M_configuration.wrn_depth_teacher
    wrn_width_teacher = extra_M_configuration.wrn_width_teacher
    wrn_depth_student = extra_M_configuration.wrn_depth_student
    wrn_width_student = extra_M_configuration.wrn_width_student
    M = extra_M_configuration.M
    dataset = extra_M_configuration.dataset
    seeds = [int(seed) for seed in extra_M_configuration.seeds]
    # Compare against 'true': the config value has already been lower-cased.
    log = extra_M_configuration.log.lower() == 'true'

    if dataset.lower() == 'cifar10':
        epochs = 200
    elif dataset.lower() == 'svhn':
        epochs = 100
    else:
        raise ValueError('Unknown dataset')

    if log:
        teacher_str = 'WideResNet-{}-{}'.format(wrn_depth_teacher, wrn_width_teacher)
        student_str = 'WideResNet-{}-{}'.format(wrn_depth_student, wrn_width_student)
        logfile = 'Extra_M_samples_Reproducibility_Zero_Shot_Teacher-{}-Student-{}-{}-M-{}-Zero-Shot.txt'.format(teacher_str, student_str, extra_M_configuration.dataset, M)
        with open(logfile, 'w') as temp:
            temp.write('Zero-Shot with teacher {} and student {} in {} with M-{}\n'.format(teacher_str, student_str, extra_M_configuration.dataset, M))
    else:
        logfile = ''

    checkpoint = bool(extra_M_configuration.checkpoint)

    if torch.cuda.is_available():
        device = torch.device('cuda:2')
    else:
        device = torch.device('cpu')

    test_set_accuracies = []

    for seed in seeds:
        set_seed(seed)

        if dataset.lower() == 'cifar10':
            from utils import cifar10loadersM
            loaders = cifar10loadersM(M)
        elif dataset.lower() == 'svhn':
            from utils import svhnloadersM
            loaders = svhnloadersM(M)
        else:
            raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

        if log:
            with open(logfile, 'a') as temp:
                temp.write('------------------- SEED {} -------------------\n'.format(seed))

        strides = [1, 1, 2, 2]
        teacher_net = WideResNet(d=wrn_depth_teacher, k=wrn_width_teacher, n_classes=10, input_features=3, output_features=16, strides=strides)
        teacher_net = teacher_net.to(device)

        if dataset.lower() == 'cifar10':
            torch_checkpoint = torch.load('./PreTrainedModels/PreTrainedScratches/CIFAR10/wrn-{}-{}-seed-{}-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, seed), map_location=device)
        elif dataset.lower() == 'svhn':
            torch_checkpoint = torch.load('./PreTrainedModels/PreTrainedScratches/SVHN/wrn-{}-{}-seed-svhn-{}-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, seed), map_location=device)
        else:
            raise ValueError('Dataset not found')
        teacher_net.load_state_dict(torch_checkpoint)

        student_net = WideResNet(d=wrn_depth_student, k=wrn_width_student, n_classes=10, input_features=3, output_features=16, strides=strides)
        student_net = student_net.to(device)

        # Warm-start the student from the zero-shot checkpoint of the same seed.
        if dataset.lower() == 'cifar10':
            torch_checkpoint = torch.load('./PreTrainedModels/Zero-Shot/CIFAR10/reproducibility_zero_shot_teacher_wrn-{}-{}_student_wrn-{}-{}-M-0-seed-{}-CIFAR10-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, wrn_depth_student, wrn_width_student, seed), map_location=device)
        elif dataset.lower() == 'svhn':
            torch_checkpoint = torch.load('./PreTrainedModels/Zero-Shot/SVHN/reproducibility_zero_shot_teacher_wrn-{}-{}_student_wrn-{}-{}-M-0-seed-{}-SVHN-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, wrn_depth_student, wrn_width_student, seed), map_location=device)
        else:
            raise ValueError('Dataset not found')
        student_net.load_state_dict(torch_checkpoint)

        if checkpoint:
            teacher_str = 'WideResNet-{}-{}'.format(wrn_depth_teacher, wrn_width_teacher)
            student_str = 'WideResNet-{}-{}'.format(wrn_depth_student, wrn_width_student)
            checkpointFile = 'Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-{}-Student-{}-{}-M-{}-Zero-Shot-seed-{}.pth'.format(teacher_str, student_str, extra_M_configuration.dataset, M, seed)
        else:
            checkpointFile = ''

        best_test_set_accuracy = _train_extra_M(epochs, teacher_net, student_net, M, loaders, device, log, checkpoint, logfile, checkpointFile)
        test_set_accuracies.append(best_test_set_accuracy)

        if log:
            with open(logfile, 'a') as temp:
                temp.write('Best test set accuracy of seed {} is {}\n'.format(seed, best_test_set_accuracy))

    mean_test_set_accuracy, std_test_set_accuracy = np.mean(test_set_accuracies), np.std(test_set_accuracies)

    if log:
        with open(logfile, 'a') as temp:
            temp.write('Mean test set accuracy is {} with standard deviation equal to {}\n'.format(mean_test_set_accuracy, std_test_set_accuracy))

def train(args):
    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    wrn_depth = training_configurations.wrn_depth
    wrn_width = training_configurations.wrn_width
    dataset = training_configurations.dataset.lower()
    seeds = [int(seed) for seed in training_configurations.seeds]
    log = training_configurations.log.lower() == 'true'

    if log:
        logfile = 'WideResNet-{}-{}-{}.txt'.format(wrn_depth, wrn_width, training_configurations.dataset)
        with open(logfile, 'w') as temp:
            temp.write('WideResNet-{}-{} on {}\n'.format(wrn_depth, wrn_width, training_configurations.dataset))
    else:
        logfile = ''

    checkpoint = training_configurations.checkpoint.lower() == 'true'
    loaders = get_loaders(dataset)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    test_set_accuracies = []

    for seed in seeds:
        set_seed(seed)

        if log:
            with open(logfile, 'a') as temp:
                temp.write('------------------- SEED {} -------------------\n'.format(seed))

        strides = [1, 1, 2, 2]
        net = WideResNet(d=wrn_depth, k=wrn_width, n_classes=10, input_features=3, output_features=16, strides=strides)
        net = net.to(device)

        checkpointFile = 'wrn-{}-{}-seed-{}-{}-dict.pth'.format(wrn_depth, wrn_width, dataset, seed) if checkpoint else ''
        best_test_set_accuracy = _train_seed(net, loaders, device, dataset, log, checkpoint, logfile, checkpointFile)

        if log:
            with open(logfile, 'a') as temp:
                temp.write('Best test set accuracy of seed {} is {}\n'.format(seed, best_test_set_accuracy))

        test_set_accuracies.append(best_test_set_accuracy)

    mean_test_set_accuracy, std_test_set_accuracy = np.mean(test_set_accuracies), np.std(test_set_accuracies)

    if log:
        with open(logfile, 'a') as temp:
            temp.write('Mean test set accuracy is {} with standard deviation equal to {}\n'.format(mean_test_set_accuracy, std_test_set_accuracy))

def train(args):
    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    wandb.init(name=training_configurations.checkpoint)
    device = torch.device(f'cuda:{args.device}')

    flag = False
    if training_configurations.train_pickle != 'None' and training_configurations.test_pickle != 'None':
        pickle_files = [training_configurations.train_pickle, training_configurations.test_pickle]
        flag = True

    dataset = args.dataset.lower()

    model = build_model(args, dropout=0.5)
    # model = build_model(args)
    model = model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
    resize = True
    epochs = 40
    scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)
    # epochs = 90
    # scheduler = MultiStepLR(optimizer, milestones=[30, 60, 80], gamma=0.1)

    if 'genOdin' in training_configurations.checkpoint:
        weight_decay = 1e-4
        # Generalized ODIN: every module gets weight decay except the cosine-similarity numerator head.
        optimizer = optim.SGD([
            {'params': model._conv_stem.parameters(), 'weight_decay': weight_decay},
            {'params': model._bn0.parameters(), 'weight_decay': weight_decay},
            {'params': model._blocks.parameters(), 'weight_decay': weight_decay},
            {'params': model._conv_head.parameters(), 'weight_decay': weight_decay},
            {'params': model._bn1.parameters(), 'weight_decay': weight_decay},
            {'params': model._fc_denominator.parameters(), 'weight_decay': weight_decay},
            {'params': model._denominator_batch_norm.parameters(), 'weight_decay': weight_decay},
            {'params': model._fc_nominator.parameters(), 'weight_decay': 0},
        ], lr=1.25e-2, momentum=0.9, nesterov=True)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    if not flag:
        trainloader, val_loader, testloader = natural_image_loaders(dataset, train_batch_size=32, test_batch_size=32, validation_test_split=1000, save_to_pickle=True, resize=resize)
    else:
        trainloader, val_loader, testloader = natural_image_loaders(dataset, train_batch_size=32, test_batch_size=32, validation_test_split=1000, pickle_files=pickle_files, resize=resize)

    criterion = nn.CrossEntropyLoss()
    checkpoint_val_accuracy, best_val_acc, test_set_accuracy = 0, 0, 0

    for epoch in tqdm(range(epochs)):
        model.train()
        correct, total = 0, 0
        train_loss = 0

        for data in tqdm(trainloader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            if 'genOdin' in training_configurations.checkpoint:
                outputs, _, _ = model(inputs)
            else:
                outputs = model(inputs)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss = criterion(outputs, labels)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        # Step the learning-rate schedule once per epoch.
        scheduler.step()

        train_accuracy = correct / total
        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({'Train Set Loss': train_loss / trainloader.__len__(), 'epoch': epoch})
        wandb.log({'Train Set Accuracy': train_accuracy, 'epoch': epoch})

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data in val_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                if 'genOdin' in training_configurations.checkpoint:
                    outputs, _, _ = model(images)
                else:
                    outputs = model(images)

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_val_accuracy = correct / total
            wandb.log({'Validation Set Accuracy': epoch_val_accuracy, 'epoch': epoch})

            if epoch_val_accuracy > best_val_acc:
                best_val_acc = epoch_val_accuracy
                if os.path.exists('/raid/ferles/'):
                    torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}.pth')
                    # torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/extended_{training_configurations.checkpoint}.pth')
                else:
                    torch.save(model.state_dict(), f'/home/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}.pth')
                    # torch.save(model.state_dict(), f'/home/ferles/checkpoints/eb0/{dataset}/low_dropout_extended_{training_configurations.checkpoint}.pth')

                correct, total = 0, 0
                for data in testloader:
                    images, labels = data
                    images = images.to(device)
                    labels = labels.to(device)

                    if 'genOdin' in training_configurations.checkpoint:
                        outputs, _, _ = model(images)
                    else:
                        outputs = model(images)

                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                test_set_accuracy = correct / total
                wandb.log({'Test Set Accuracy': test_set_accuracy, 'epoch': epoch})

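
# Note on the parameter groups used for the genOdin runs above: following the
# Generalized ODIN design (Hsu et al., 2020), the logit is decomposed as
# f_i(x) = h_i(x) / g(x), and the cosine-similarity numerator head
# (`_fc_nominator`) is trained without weight decay while every other module
# keeps it. A minimal sketch of such a decomposed head, with illustrative
# dimensions (the repo's own CosineSimilarity module may differ):
import torch
import torch.nn as nn
import torch.nn.functional as F

class GenOdinHead(nn.Module):
    def __init__(self, feat_dim=1280, num_classes=10):
        super().__init__()
        # h: cosine similarity between features and per-class centers.
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))
        # g: a scalar, positive denominator computed from the same features.
        self.g = nn.Sequential(nn.Linear(feat_dim, 1), nn.BatchNorm1d(1), nn.Sigmoid())

    def forward(self, feats):
        h = F.normalize(feats, dim=1) @ F.normalize(self.centers, dim=1).t()
        g = self.g(feats)
        return h / g, h, g  # (logits, numerator, denominator), as the models above return

logits, h, g = GenOdinHead()(torch.randn(4, 1280))
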
def train(args):
    json_options = json_file_to_pyobj(args.config)
    modified_zero_shot_configurations = json_options.training

    wrn_depth_teacher = modified_zero_shot_configurations.wrn_depth_teacher
    wrn_width_teacher = modified_zero_shot_configurations.wrn_width_teacher
    wrn_depth_student = modified_zero_shot_configurations.wrn_depth_student
    wrn_width_student = modified_zero_shot_configurations.wrn_width_student
    M = modified_zero_shot_configurations.M
    dataset = modified_zero_shot_configurations.dataset
    seeds = [int(seed) for seed in modified_zero_shot_configurations.seeds]
    # Compare against 'true': the config value has already been lower-cased.
    log = modified_zero_shot_configurations.log.lower() == 'true'

    if log:
        teacher_str = 'WideResNet-{}-{}'.format(wrn_depth_teacher, wrn_width_teacher)
        student_str = 'WideResNet-{}-{}'.format(wrn_depth_student, wrn_width_student)
        logfile = 'Teacher-{}-Student-{}-{}-M-{}-Zero-Shot.txt'.format(teacher_str, student_str, modified_zero_shot_configurations.dataset, M)
        with open(logfile, 'w') as temp:
            temp.write('Zero-Shot with teacher {} and student {} in {} with M-{}\n'.format(teacher_str, student_str, modified_zero_shot_configurations.dataset, M))
    else:
        logfile = ''

    checkpoint = bool(modified_zero_shot_configurations.checkpoint)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    test_set_accuracies = []

    for seed in seeds:
        set_seed(seed)

        if dataset.lower() == 'cifar10':
            from utils import cifar10loaders
            _, test_loader = cifar10loaders()
        elif dataset.lower() == 'svhn':
            from utils import svhnLoaders
            _, test_loader = svhnLoaders()
        else:
            raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

        if log:
            with open(logfile, 'a') as temp:
                temp.write('------------------- SEED {} -------------------\n'.format(seed))

        strides = [1, 1, 2, 2]
        teacher_net = WideResNet(d=wrn_depth_teacher, k=wrn_width_teacher, n_classes=10, input_features=3, output_features=16, strides=strides)
        teacher_net = teacher_net.to(device)

        if dataset.lower() == 'cifar10':
            torch_checkpoint = torch.load('./PreTrainedModels/PreTrainedScratches/CIFAR10/wrn-{}-{}-seed-{}-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, seed), map_location=device)
        elif dataset.lower() == 'svhn':
            torch_checkpoint = torch.load('./PreTrainedModels/PreTrainedScratches/SVHN/wrn-{}-{}-seed-svhn-{}-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, seed), map_location=device)
        else:
            raise ValueError('Dataset not found')
        teacher_net.load_state_dict(torch_checkpoint)

        student_net = WideResNet(d=wrn_depth_student, k=wrn_width_student, n_classes=10, input_features=3, output_features=16, strides=strides)
        student_net = student_net.to(device)

        generator_net = Generator()
        generator_net = generator_net.to(device)

        checkpointFile = 'zero_shot_teacher_wrn-{}-{}_student_wrn-{}-{}-M-{}-seed-{}-{}-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, wrn_depth_student, wrn_width_student, M, seed, dataset) if checkpoint else ''
        finalCheckpointFile = 'zero_shot_teacher_wrn-{}-{}_student_wrn-{}-{}-M-{}-seed-{}-{}-final-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, wrn_depth_student, wrn_width_student, M, seed, dataset) if checkpoint else ''
        genCheckpointFile = 'zero_shot_teacher_wrn-{}-{}_student_wrn-{}-{}-M-{}-seed-{}-{}-generator-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, wrn_depth_student, wrn_width_student, M, seed, dataset) if checkpoint else ''

        best_test_set_accuracy = _train_seed_zero_shot(teacher_net, student_net, generator_net, test_loader, device, log, checkpoint, logfile, checkpointFile, finalCheckpointFile, genCheckpointFile)

        if log:
            with open(logfile, 'a') as temp:
                temp.write('Best test set accuracy of seed {} is {}\n'.format(seed, best_test_set_accuracy))

        test_set_accuracies.append(best_test_set_accuracy)

    mean_test_set_accuracy, std_test_set_accuracy = np.mean(test_set_accuracies), np.std(test_set_accuracies)

    if log:
        with open(logfile, 'a') as temp:
            temp.write('Mean test set accuracy is {} with standard deviation equal to {}\n'.format(mean_test_set_accuracy, std_test_set_accuracy))

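
# `_train_seed_zero_shot` is defined elsewhere in the repo. As a reference, a
# minimal sketch of the adversarial objective such a routine is expected to
# implement (Micaelli & Storkey, "Zero-Shot Knowledge Transfer via Adversarial
# Belief Matching"): the generator ascends the teacher-student KL divergence on
# generated pseudo-samples while the student descends it. Function names, the
# latent size, and the single-output model signatures are illustrative assumptions.
import torch
import torch.nn.functional as F

def teacher_student_kl(teacher_logits, student_logits):
    # KL(teacher || student), averaged over the batch.
    return F.kl_div(F.log_softmax(student_logits, dim=1),
                    F.softmax(teacher_logits, dim=1),
                    reduction='batchmean')

def zero_shot_step(generator_net, teacher_net, student_net, g_opt, s_opt, z_dim=100, batch=128, device='cpu'):
    z = torch.randn(batch, z_dim, device=device)

    # 1) Generator step: maximise teacher-student disagreement.
    g_opt.zero_grad()
    fake = generator_net(z)
    g_loss = -teacher_student_kl(teacher_net(fake), student_net(fake))
    g_loss.backward()
    g_opt.step()

    # 2) Student step: match the teacher on samples of the same kind.
    s_opt.zero_grad()
    fake = generator_net(z).detach()
    s_loss = teacher_student_kl(teacher_net(fake).detach(), student_net(fake))
    s_loss.backward()
    s_opt.step()
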
def train(args):
    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    wandb.init(name=training_configurations.checkpoint + 'Ensemble')
    device = torch.device(f'cuda:{args.device}')
    dataset = args.dataset.lower()

    pickle_files = [training_configurations.train_pickle, training_configurations.test_pickle]
    train_ind_loaders, train_ood_loaders, val_ind_loaders, test_ind_loaders, num_classes, dicts = create_ensemble_loaders(dataset, num_classes=training_configurations.out_classes, pickle_files=pickle_files)

    criterion = nn.CrossEntropyLoss()
    # Entropy-margin hyperparameters: weight b and margin m.
    b = 0.2
    m = 0.4

    for index in range(len(train_ind_loaders)):
        epochs = 40
        model = build_model(args)
        model._fc = nn.Linear(model._fc.in_features, num_classes[index])
        model = model.to(device)

        optimizer = optim.SGD(model.parameters(), lr=1.25e-02, momentum=0.9, nesterov=True, weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

        train_ind_loader, train_ood_loader = train_ind_loaders[index], train_ood_loaders[index]
        val_ind_loader = val_ind_loaders[index]
        test_ind_loader = test_ind_loaders[index]
        # Mapping from original class indices to this ensemble member's label space.
        dic = dicts[index]

        ood_loader_iter = iter(train_ood_loader)

        best_val_acc = 0
        test_epoch_accuracy = 0

        for epoch in tqdm(range(epochs)):
            model.train()
            correct, total = 0, 0
            train_loss = 0

            for data in tqdm(train_ind_loader):
                inputs, labels = data
                inputs = inputs.to(device)
                _labels = torch.LongTensor([dic[int(l)] for l in labels])
                labels = _labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                ce_loss = criterion(outputs, labels)

                try:
                    ood_inputs, _ = next(ood_loader_iter)
                except StopIteration:
                    ood_loader_iter = iter(train_ood_loader)
                    ood_inputs, _ = next(ood_loader_iter)

                ood_inputs = ood_inputs.to(device)
                ood_outputs = model(ood_inputs)

                entropy_input = -torch.mean(torch.sum(F.log_softmax(outputs, dim=1) * F.softmax(outputs, dim=1), dim=1))
                entropy_output = -torch.mean(torch.sum(F.log_softmax(ood_outputs, dim=1) * F.softmax(ood_outputs, dim=1), dim=1))
                margin_loss = b * torch.clamp(m + entropy_input - entropy_output, min=0)

                loss = ce_loss + margin_loss
                train_loss += loss.item()
                loss.backward()
                optimizer.step()

            train_accuracy = correct / total
            wandb.log({'epoch': epoch}, commit=False)
            epoch_train_set_loss = train_loss / train_ind_loader.__len__()
            wandb.log({f'Train Set Loss {index}': epoch_train_set_loss, 'epoch': epoch})
            wandb.log({f'Train Set Accuracy {index}': train_accuracy, 'epoch': epoch})

            with torch.no_grad():
                model.eval()
                v_correct, v_total = 0, 0
                for data in val_ind_loader:
                    images, labels = data
                    images = images.to(device)
                    _labels = torch.LongTensor([dic[int(l)] for l in labels])
                    labels = _labels.to(device)

                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    v_total += labels.size(0)
                    v_correct += (predicted == labels).sum().item()

                val_epoch_accuracy = v_correct / v_total
                wandb.log({f'Validation Set Accuracy {index}': val_epoch_accuracy, 'epoch': epoch})

                if val_epoch_accuracy > best_val_acc:
                    best_val_acc = val_epoch_accuracy
                    torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_best_accuracy_ensemble_{index}.pth')

                    correct, total = 0, 0
                    for data in test_ind_loader:
                        images, labels = data
                        images = images.to(device)
                        _labels = torch.LongTensor([dic[int(l)] for l in labels])
                        labels = _labels.to(device)

                        outputs = model(images)
                        _, predicted = torch.max(outputs.data, 1)
                        total += labels.size(0)
                        correct += (predicted == labels).sum().item()

                    test_epoch_accuracy = correct / total
                    wandb.log({f'Test Set Accuracy {index}': test_epoch_accuracy, 'epoch': epoch})

            scheduler.step()

def train(args):
    use_wandb = True
    device = torch.device('cuda')
    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    traincsv = training_configurations.traincsv
    testcsv = training_configurations.testcsv
    gtFileName = training_configurations.gtFile
    checkpointFileName = training_configurations.checkpointFile
    out_classes = training_configurations.out_classes
    exclude_class = training_configurations.exclude_class
    exclude_class = None if exclude_class == "None" else exclude_class

    if use_wandb:
        wandb.init(name=checkpointFileName, entity='ferles')

    if exclude_class is None:
        train_loader, val_loader, test_loader, columns = oversampling_loaders_custom(csvfiles=[traincsv, testcsv], train_batch_size=32, val_batch_size=16, gtFile=gtFileName)
    else:
        train_loader, val_loader, test_loader, columns = oversampling_loaders_exclude_class_custom_no_gts(csvfiles=[traincsv, testcsv], train_batch_size=32, val_batch_size=16, gtFile=gtFileName, exclude_class=exclude_class)

    model = build_model(args, rot=True)
    model = nn.DataParallel(model).to(device)

    epochs = 40
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    test_loss, best_val_detection_accuracy, test_detection_accuracy = 10, 0, 0
    train_loss, val_loss, balanced_accuracies = [], [], []

    # early_stopping = True
    early_stopping = False
    early_stopping_cnt = 0

    for epoch in tqdm(range(epochs)):
        model.train()
        loss_acc = []

        for data in tqdm(train_loader):
            # Inputs stay on the CPU; nn.DataParallel scatters them across the GPUs.
            inputs, labels = data
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            _labels = torch.argmax(labels, dim=1)
            ce_loss = criterion(outputs, _labels)

            # Rotation Loss
            rot_gt = torch.cat((torch.zeros(inputs.size(0)), torch.ones(inputs.size(0)), 2 * torch.ones(inputs.size(0)), 3 * torch.ones(inputs.size(0))), 0).long().to(device)
            rot_inputs = inputs.detach().cpu().numpy()
            rot_inputs = np.concatenate((rot_inputs, np.rot90(rot_inputs, 1, axes=(2, 3)), np.rot90(rot_inputs, 2, axes=(2, 3)), np.rot90(rot_inputs, 3, axes=(2, 3))), 0)
            rot_inputs = torch.FloatTensor(rot_inputs)
            rot_preds = model(rot_inputs, rot=True)
            rot_loss = 0.5 * criterion(rot_preds, rot_gt)

            loss = ce_loss + rot_loss
            loss_acc.append(loss.item())
            loss.backward()
            optimizer.step()

        wandb.log({'Train Set Loss': sum(loss_acc) / float(train_loader.__len__()), 'epoch': epoch})
        wandb.log({'epoch': epoch}, commit=False)
        train_loss.append(sum(loss_acc) / float(train_loader.__len__()))
        loss_acc.clear()

        with torch.no_grad():
            model.eval()
            correct, total = 0, 0
            for data in tqdm(val_loader):
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)
                softmax_outputs = torch.softmax(outputs, 1)
                max_idx = torch.argmax(softmax_outputs, axis=1)
                _labels = torch.argmax(labels, dim=1)
                correct += (max_idx == _labels).sum().item()
                total += max_idx.size()[0]
                loss = criterion(outputs, _labels)
                loss_acc.append(loss.item())

            val_detection_accuracy = round(100 * correct / total, 2)
            wandb.log({'Validation Detection Accuracy': val_detection_accuracy, 'epoch': epoch})

            if val_detection_accuracy > best_val_detection_accuracy:
                best_val_detection_accuracy = val_detection_accuracy
                test_loss, test_detection_accuracy = _test_set_eval(model, epoch, device, test_loader, out_classes, columns, gtFileName)

                if exclude_class is None:
                    checkpointFile = os.path.join(f'/raid/ferles/checkpoints/isic_classifiers/rot_isic-best-model.pth')
                else:
                    checkpointFile = os.path.join(f'/raid/ferles/checkpoints/isic_classifiers/rot_isic-_{exclude_class}-best-model.pth')

                if os.path.exists('/raid/ferles/'):
                    torch.save(model.state_dict(), checkpointFile)
                else:
                    torch.save(model.state_dict(), checkpointFile.replace('raid', 'home'))
            else:
                if early_stopping:
                    early_stopping_cnt += 1
                    if early_stopping_cnt == 3:
                        wandb.log({'Test Set Loss': test_loss, 'epoch': epoch})
                        wandb.log({'Detection Accuracy': test_detection_accuracy, 'epoch': epoch})
                        break

        # Fixed per-class stopping epochs.
        if exclude_class is None and epoch == 20:
            break
        elif exclude_class == 'AK' and epoch == 19:
            break
        elif exclude_class == 'BCC' and epoch == 12:
            break
        elif exclude_class == 'BKL' and epoch == 15:
            break
        elif exclude_class == 'DF' and epoch == 10:
            break
        elif exclude_class == 'MEL' and epoch == 12:
            break
        elif exclude_class == 'NV' and epoch == 27:
            break
        elif exclude_class == 'SCC' and epoch == 10:
            break
        elif exclude_class == 'VASC' and epoch == 10:
            break

        wandb.log({'Test Set Loss': test_loss, 'epoch': epoch})
        wandb.log({'Detection Accuracy': test_detection_accuracy, 'epoch': epoch})

        # val_loss, auc, balanced_accuracy = _test_set_eval(model, epoch, device, val_loader, out_classes, columns, gtFileName)
        # if auc > best_auc:
        #     best_auc = auc
        #     checkpointFile = os.path.join(f'{abs_path}/checkpoints/rotation/{checkpointFileName}-best-auc-model.pth')
        #     torch.save(model.state_dict(), checkpointFile)
        # if balanced_accuracy > best_balanced_accuracy:
        #     best_balanced_accuracy = balanced_accuracy
        #     checkpointFile = os.path.join(f'{abs_path}/checkpoints/rotation/{checkpointFileName}-best-balanced-accuracy-model.pth')
        #     torch.save(model.state_dict(), checkpointFile)
        # if val_loss < best_val_loss:
        #     best_val_loss = val_loss
        #     checkpointFile = os.path.join(f'{abs_path}/checkpoints/rotation/{checkpointFileName}-best-val-loss-model.pth')
        #     torch.save(model.state_dict(), checkpointFile)
        #     early_stopping_cnt = 0

        scheduler.step()

def train(args):
    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    wandb.init(name=f"{training_configurations.checkpoint}_subset_{args.subset_index}")
    device = torch.device(f'cuda:{args.device}')

    flag = False
    if training_configurations.train_pickle != 'None' and training_configurations.test_pickle != 'None':
        pickle_files = [training_configurations.train_pickle, training_configurations.test_pickle]
        flag = True

    model = build_model(args)
    model = model.to(device)
    dataset = args.dataset.lower()

    if 'wide' in training_configurations.model.lower():
        epochs = 100
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
        scheduler = MultiStepLR(optimizer, milestones=[20, 50, 80], gamma=0.2)
    else:
        epochs = 40
        optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    if not flag:
        if args.subset_index is None:
            trainloader, val_loader, testloader = fine_grained_image_loaders(dataset, train_batch_size=32, test_batch_size=32, validation_test_split=1000, save_to_pickle=True)
        else:
            trainloader, val_loader, testloader = fine_grained_image_loaders_subset(dataset, subset_index=args.subset_index, validation_test_split=800, save_to_pickle=True)
    else:
        if args.subset_index is None:
            trainloader, val_loader, testloader = fine_grained_image_loaders(dataset, train_batch_size=32, test_batch_size=32, validation_test_split=1000, pickle_files=pickle_files)
        else:
            pickle_files[0] = pickle_files[0].split(".pickle")[0] + f"_subset_{args.subset_index}.pickle"
            pickle_files[1] = pickle_files[1].split(".pickle")[0] + f"_subset_{args.subset_index}.pickle"
            trainloader, val_loader, testloader = fine_grained_image_loaders_subset(dataset, subset_index=args.subset_index, validation_test_split=800, pickle_files=pickle_files)

    if 'genOdin' in training_configurations.checkpoint:
        weight_decay = 1e-4
        # Generalized ODIN: every module gets weight decay except the cosine-similarity numerator head.
        optimizer = optim.SGD([
            {'params': model._conv_stem.parameters(), 'weight_decay': weight_decay},
            {'params': model._bn0.parameters(), 'weight_decay': weight_decay},
            {'params': model._blocks.parameters(), 'weight_decay': weight_decay},
            {'params': model._conv_head.parameters(), 'weight_decay': weight_decay},
            {'params': model._bn1.parameters(), 'weight_decay': weight_decay},
            {'params': model._fc_denominator.parameters(), 'weight_decay': weight_decay},
            {'params': model._denominator_batch_norm.parameters(), 'weight_decay': weight_decay},
            {'params': model._fc_nominator.parameters(), 'weight_decay': 0},
        ], lr=1.25e-2, momentum=0.9, nesterov=True)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    criterion = nn.CrossEntropyLoss()
    checkpoint_val_accuracy, best_val_acc, test_set_accuracy = 0, 0, 0

    for epoch in tqdm(range(epochs)):
        model.train()
        correct, total = 0, 0
        train_loss = 0

        for data in tqdm(trainloader):
            model.train()
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            if 'genodin' in training_configurations.checkpoint.lower():
                outputs, h, g = model(inputs)
            else:
                outputs = model(inputs)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss = criterion(outputs, labels)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

            # if epoch < 2:
            #     model.eval()
            #     v_correct, v_total = 0, 0
            #     with torch.no_grad():
            #         for v_data in testloader:
            #             v_images, v_labels = v_data
            #             v_images = v_images.to(device)
            #             v_labels = v_labels.to(device)
            #             v_outputs = model(v_images)
            #             _, v_predicted = torch.max(v_outputs.data, 1)
            #             v_total += v_labels.size(0)
            #             v_correct += (v_predicted == v_labels).sum().item()
            #     acc = v_correct / v_total
            #     if os.path.exists('/raid/ferles/'):
            #         torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_acc_{acc}.pth')
            #     else:
            #         torch.save(model.state_dict(), f'/home/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_acc_{acc}.pth')

        train_accuracy = correct / total
        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({'Train Set Loss': train_loss / trainloader.__len__(), 'epoch': epoch})
        wandb.log({'Train Set Accuracy': train_accuracy, 'epoch': epoch})

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data in val_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                if 'genodin' in training_configurations.checkpoint.lower():
                    outputs, h, g = model(images)
                else:
                    outputs = model(images)

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_val_accuracy = correct / total
            wandb.log({'Validation Set Accuracy': epoch_val_accuracy, 'epoch': epoch})

            if epoch_val_accuracy > best_val_acc:
                best_val_acc = epoch_val_accuracy
                if os.path.exists('/raid/ferles/'):
                    if args.subset_index is None:
                        torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}.pth')
                    else:
                        torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_subset_{args.subset_index}.pth')
                else:
                    if args.subset_index is None:
                        torch.save(model.state_dict(), f'/home/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}.pth')
                    else:
                        torch.save(model.state_dict(), f'/home/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_subset_{args.subset_index}.pth')

                # if best_val_acc - checkpoint_val_accuracy > 0.05:
                #     checkpoint_val_accuracy = best_val_acc
                #     torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_epoch_{epoch}_accuracy_{best_val_acc}.pth')

                correct, total = 0, 0
                for data in testloader:
                    images, labels = data
                    images = images.to(device)
                    labels = labels.to(device)

                    if 'genodin' in training_configurations.checkpoint.lower():
                        outputs, h, g = model(images)
                    else:
                        outputs = model(images)

                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                test_set_accuracy = correct / total
                wandb.log({'Test Set Accuracy': test_set_accuracy, 'epoch': epoch})

        scheduler.step(epoch=epoch)

def train(args):
    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    wandb.init(name=f"{training_configurations.checkpoint}_subset_{args.subset_index}")

    device = torch.device(f'cuda:{args.device}')

    # Resolve optional pickled dataset splits from the config.
    flag = False
    if training_configurations.train_pickle != 'None' and training_configurations.test_pickle != 'None':
        pickle_files = [training_configurations.train_pickle, training_configurations.test_pickle]
        flag = True

    if args.subset_index is None:
        model = build_model(args)
        model = model.to(device)

        if training_configurations.model == 'EfficientNet':
            epochs = 40
            optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
            scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)
            batch_size = 32
        elif training_configurations.model == 'DenseNet':
            model = torch.hub.load('pytorch/vision:v0.6.0', 'densenet121', pretrained=True)
            model = model.to(device)
            epochs = 200
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
            scheduler = MultiStepLR(optimizer, milestones=[int(0.5 * epochs), int(0.75 * epochs)], gamma=0.1)
            batch_size = 16
        else:
            # Without this guard, epochs/optimizer/batch_size stay undefined below.
            raise ValueError('Models to choose from: EfficientNet and DenseNet')

    dataset = args.dataset.lower()

    if not flag:
        if args.subset_index is None:
            trainloader, val_loader, testloader = fine_grained_image_loaders(
                dataset,
                train_batch_size=batch_size,
                test_batch_size=batch_size,
                validation_test_split=1000,
                save_to_pickle=True)
        else:
            trainloader, val_loader, testloader = fine_grained_image_loaders_subset(
                dataset,
                subset_index=args.subset_index,
                validation_test_split=800,
                save_to_pickle=True)
    else:
        if args.subset_index is None:
            trainloader, val_loader, testloader = fine_grained_image_loaders(
                dataset,
                train_batch_size=batch_size,
                test_batch_size=batch_size,
                validation_test_split=1000,
                pickle_files=pickle_files)
        else:
            pickle_files[0] = "pickle_files/" + pickle_files[0].split(".pickle")[0] + f"_subset_{args.subset_index}.pickle"
            pickle_files[1] = "pickle_files/" + pickle_files[1].split(".pickle")[0] + f"_subset_{args.subset_index}.pickle"
            trainloader, val_loader, testloader, num_classes = fine_grained_image_loaders_subset(
                dataset,
                subset_index=args.subset_index,
                validation_test_split=800,
                pickle_files=pickle_files,
                ret_num_classes=True)

    if args.subset_index is not None:
        model = build_model(args)
        if 'genodin' in training_configurations.checkpoint.lower():
            # Swap the classification head for a cosine-similarity "nominator"
            # sized to the number of classes in this subset.
            from efficientnet_pytorch.gen_odin_model import CosineSimilarity
            model._fc_nominator = CosineSimilarity(feat_dim=1280, num_centers=num_classes)
        else:
            model._fc = nn.Linear(model._fc.in_features, num_classes)
        model = model.to(device)

        epochs = 40
        optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    criterion = nn.CrossEntropyLoss()
    checkpoint_val_accuracy, best_val_acc, test_set_accuracy = 0, 0, 0

    # Normalized with .lower() so the check matches checkpoints named e.g. 'genOdin'.
    if 'genodin' in training_configurations.checkpoint.lower():
        # Generalized ODIN: weight decay everywhere except the cosine-similarity head.
        weight_decay = 1e-4
        optimizer = optim.SGD(
            [
                {'params': model._conv_stem.parameters(), 'weight_decay': weight_decay},
                {'params': model._bn0.parameters(), 'weight_decay': weight_decay},
                {'params': model._blocks.parameters(), 'weight_decay': weight_decay},
                {'params': model._conv_head.parameters(), 'weight_decay': weight_decay},
                {'params': model._bn1.parameters(), 'weight_decay': weight_decay},
                {'params': model._fc_denominator.parameters(), 'weight_decay': weight_decay},
                {'params': model._denominator_batch_norm.parameters(), 'weight_decay': weight_decay},
                {'params': model._fc_nominator.parameters(), 'weight_decay': 0},
            ],
            lr=1.25e-2, momentum=0.9, nesterov=True)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    for epoch in tqdm(range(epochs)):
        model.train()
        correct, total = 0, 0
        train_loss = 0

        for data in tqdm(trainloader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            if 'genodin' in training_configurations.checkpoint.lower():
                outputs, h, g = model(inputs)
            else:
                outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss = criterion(outputs, labels)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        train_accuracy = correct / total
        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({'Train Set Loss': train_loss / len(trainloader), 'epoch': epoch})
        wandb.log({'Train Set Accuracy': train_accuracy, 'epoch': epoch})

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data in val_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                if 'genodin' in training_configurations.checkpoint.lower():
                    outputs, h, g = model(images)
                else:
                    outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        epoch_val_accuracy = correct / total
        wandb.log({'Validation Set Accuracy': epoch_val_accuracy, 'epoch': epoch})

        if epoch_val_accuracy > best_val_acc:
            best_val_acc = epoch_val_accuracy
            root = '/raid/ferles' if os.path.exists('/raid/ferles/') else '/home/ferles'
            if args.subset_index is None:
                torch.save(model.state_dict(), f'{root}/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}.pth')
            else:
                torch.save(model.state_dict(), f'{root}/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_subset_{args.subset_index}.pth')

            # if best_val_acc - checkpoint_val_accuracy > 0.05:
            #     checkpoint_val_accuracy = best_val_acc
            #     torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_epoch_{epoch}_accuracy_{best_val_acc}.pth')

        correct, total = 0, 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                if 'genodin' in training_configurations.checkpoint.lower():
                    outputs, h, g = model(images)
                else:
                    outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_set_accuracy = correct / total
        wandb.log({'Test Set Accuracy': test_set_accuracy, 'epoch': epoch})

        scheduler.step(epoch=epoch)
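
# The generalized-ODIN checkpoints above rely on a decomposed head imported
# from `efficientnet_pytorch.gen_odin_model`. As a rough, illustrative
# approximation (not the repository's actual implementation): the logits are
# factored as h(x) / g(x), where h is a cosine similarity between features and
# learned class centers (the `_fc_nominator`, trained without weight decay)
# and g is a small, strictly positive "denominator" branch.
class _CosineSimilaritySketch(nn.Module):
    def __init__(self, feat_dim=1280, num_centers=10):
        super().__init__()
        self.centers = nn.Parameter(0.01 * torch.randn(num_centers, feat_dim))
        # Denominator branch: one value in (0, 1) per sample.
        self.denominator = nn.Sequential(
            nn.Linear(feat_dim, 1), nn.BatchNorm1d(1), nn.Sigmoid())

    def forward(self, features):
        # Cosine similarity between normalized features and class centers.
        h = F.normalize(features, dim=1) @ F.normalize(self.centers, dim=1).t()
        g = self.denominator(features)
        return h / g, h, g  # matches the (outputs, h, g) unpacking used above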
def adversarial_belief_matching(args):
    json_options = json_file_to_pyobj(args.config)
    abm_configurations = json_options.abm_setting

    wrn_depth_teacher = abm_configurations.wrn_depth_teacher
    wrn_width_teacher = abm_configurations.wrn_width_teacher
    wrn_depth_student = abm_configurations.wrn_depth_student
    wrn_width_student = abm_configurations.wrn_width_student

    dataset = abm_configurations.dataset.lower()
    seeds = abm_configurations.seeds
    mode = abm_configurations.mode
    eval_teacher = abm_configurations.eval_teacher.lower() == 'true'

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    for seed in seeds:
        teacher_net, student_net = _load_teacher_and_student(abm_configurations, seed, device)

        # Keep only test samples on which teacher and student initially agree
        # (and are both correct); cnt is the number of such samples.
        test_loader = get_matching_indices(dataset, teacher_net, student_net, device, n=1000)
        cnt = len(test_loader)

        teacher_net.eval()
        criterion = nn.CrossEntropyLoss()

        eta = 1   # adversarial step size
        K = 100   # adversarial steps per fake class

        student_image_average_transition_curves_acc, teacher_image_average_transition_curves_acc = [], []
        mean_transition_error = 0

        for data in tqdm(test_loader):
            images, _ = data
            images = images.to(device)

            student_net.eval()
            student_outputs = student_net(images)[0]
            _, student_predicted = torch.max(student_outputs.data, 1)

            teacher_outputs = teacher_net(images)[0]
            _, teacher_predicted = torch.max(teacher_outputs.data, 1)

            x0 = deepcopy(images.detach())
            student_transition_curves, teacher_transition_curves = [], []

            for fake_label in range(0, 10):
                # Skip the class both networks already predict.
                if fake_label == student_predicted.item():
                    continue

                fake_label = torch.Tensor([fake_label]).long().to(device)
                student_probs_acc, teacher_probs_acc = [], []

                x_adv = deepcopy(x0)
                x_adv.requires_grad = True

                for _ in range(K):
                    # Take a gradient step on one network towards the fake label,
                    # probing the other network on the same adversarial input.
                    if eval_teacher:
                        teacher_fake_outputs = teacher_net(x_adv)[0]
                        with torch.no_grad():
                            student_fake_outputs = student_net(x_adv)[0]
                        loss = criterion(teacher_fake_outputs, fake_label)
                        teacher_net.zero_grad()
                    else:
                        student_fake_outputs = student_net(x_adv)[0]
                        with torch.no_grad():
                            teacher_fake_outputs = teacher_net(x_adv)[0]
                        loss = criterion(student_fake_outputs, fake_label)
                        student_net.zero_grad()

                    loss.backward()
                    x_adv.data -= eta * x_adv.grad.data
                    x_adv.grad.data.zero_()

                    teacher_probs = F.softmax(teacher_fake_outputs, dim=1)
                    student_probs = F.softmax(student_fake_outputs, dim=1)

                    pj_b = teacher_probs[0][fake_label].item()
                    pj_a = student_probs[0][fake_label].item()

                    with torch.no_grad():
                        student_probs_acc.append(pj_a)
                        teacher_probs_acc.append(pj_b)
                        mean_transition_error += abs(pj_b - pj_a)

                student_transition_curves.append(student_probs_acc)
                teacher_transition_curves.append(teacher_probs_acc)

            student_image_average_transition_curves_acc.append(
                np.average(np.array(student_transition_curves), axis=0))
            teacher_image_average_transition_curves_acc.append(
                np.average(np.array(teacher_transition_curves), axis=0))

        student_image_average_transition_curves_acc_np = np.average(
            np.array(student_image_average_transition_curves_acc), axis=0)
        teacher_image_average_transition_curves_acc_np = np.average(
            np.array(teacher_image_average_transition_curves_acc), axis=0)

        np.savez(
            'Teacher_WRN-{}-{}_transition_curve-{}-seed-{}.npz'.format(
                wrn_depth_teacher, wrn_width_teacher, mode, seed),
            teacher_image_average_transition_curves_acc_np)
        np.savez(
            'Student_WRN-{}-{}_transition_curve-{}-seed-{}.npz'.format(
                wrn_depth_student, wrn_width_student, mode, seed),
            student_image_average_transition_curves_acc_np)

        # Average MTE over the C - 1 = 9 fake classes, K adversarial steps
        # and the cnt initially matching samples.
        mean_transition_error /= float(9 * K * cnt)

        write_mode = 'w' if seed == seeds[0] else 'a'
        with open(
                'Teacher_WRN-{}-{}-Student_WRN-{}-{}-{}-MTE.txt'.format(
                    wrn_depth_teacher, wrn_width_teacher,
                    wrn_depth_student, wrn_width_student, mode),
                write_mode) as logfile:
            logfile.write(
                'Teacher WideResNet-{}-{} and Student WideResNet-{}-{} trained with {}: '
                'Mean Transition Error on seed {}: {}\n'.format(
                    wrn_depth_teacher, wrn_width_teacher,
                    wrn_depth_student, wrn_width_student, mode,
                    seed, mean_transition_error))
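
# Offline inspection of the saved transition curves. `np.savez` stores a single
# positional array under the key 'arr_0'; the file contents follow the format
# strings above, while the plotting itself is an illustrative sketch
# (matplotlib is assumed to be available).
def plot_transition_curves(teacher_npz, student_npz):
    import matplotlib.pyplot as plt

    teacher_curve = np.load(teacher_npz)['arr_0']
    student_curve = np.load(student_npz)['arr_0']

    # Each curve holds the probability assigned to the fake class after each
    # of the K adversarial steps, averaged over images and fake classes.
    plt.plot(teacher_curve, label='teacher')
    plt.plot(student_curve, label='student')
    plt.xlabel('adversarial step')
    plt.ylabel('mean fake-class probability')
    plt.legend()
    plt.show()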