def main(): global args, best_prec1 args = parser.parse_args() # ensuring reproducibility SEED = 42 torch.manual_seed(SEED) torch.backends.cudnn.benchmark = False kwargs = {'num_workers': 1, 'pin_memory': True} device = torch.device("cuda") num_epochs = 7 # create model model = WideResNet(args.layers, 10, args.widen_factor, dropRate=args.droprate).to(device) optimizer = torch.optim.Adam(model.parameters(), args.learning_rate, weight_decay=args.weight_decay) # instantiate loaders train_loader = get_data_loader(args.data_dir, args.batch_size, **kwargs) test_loader = get_test_loader(args.data_dir, 128, **kwargs) tic = time.time() for epoch in range(1, num_epochs + 1): train(model, device, train_loader, optimizer, epoch) test(model, device, test_loader, epoch) toc = time.time() print("Time Elapsed: {}s".format(toc - tic))
def get_model(args):
    if args.seed is not None:
        set_seed(args)
    if args.dataset == "cifar10":
        depth, widen_factor = 28, 2
    elif args.dataset == 'cifar100':
        depth, widen_factor = 28, 8
    student_model = WideResNet(num_classes=args.num_classes,
                               depth=depth,
                               widen_factor=widen_factor,
                               dropout=0,
                               dense_dropout=args.dense_dropout)
    if os.path.isfile(args.resume):
        print(f"=> loading checkpoint '{args.resume}'")
        checkpoint = torch.load(args.resume, map_location='cpu')
        if checkpoint['avg_state_dict'] is not None:
            model_load_state_dict(student_model, checkpoint['avg_state_dict'])
        else:
            model_load_state_dict(student_model, checkpoint['student_state_dict'])
        print(f"=> loaded checkpoint '{args.resume}' (step {checkpoint['step']})")
    else:
        print(f"=> no checkpoint found at '{args.resume}'")
        exit(1)
    if args.device != 'cpu':
        student_model.cuda()
    return student_model
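# model_load_state_dict is called above but not defined in this snippet. A
# minimal sketch, assuming the only mismatch between checkpoint and model keys
# is an optional 'module.' prefix left over from (Distributed)DataParallel:
def model_load_state_dict(model, state_dict):
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        if any(k.startswith('module.') for k in state_dict):
            # checkpoint was saved from a wrapped model: strip the prefix
            state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
        else:
            # model is wrapped but the checkpoint is not: add the prefix
            state_dict = {'module.' + k: v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)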
def main(): global args, best_prec1 args = parser.parse_args() # ensuring reproducibility SEED = 42 torch.manual_seed(SEED) torch.backends.cudnn.benchmark = False kwargs = {'num_workers': 1, 'pin_memory': True} device = torch.device("cuda") num_epochs_transient = 2 num_epochs_steady = 7 perc_to_remove = 10 torch.manual_seed(SEED) # create model model = WideResNet(args.layers, 10, args.widen_factor, dropRate=args.droprate).to(device) optimizer = torch.optim.Adam( model.parameters(), args.learning_rate, weight_decay=args.weight_decay ) # instantiate loaders train_loader = get_data_loader(args.data_dir, args.batch_size, **kwargs) test_loader = get_test_loader(args.data_dir, 128, **kwargs) tic = time.time() seen_losses = None for epoch in range(1, 3): if epoch == 1: seen_losses = train_transient(model, device, train_loader, optimizer, epoch, track=True) else: train_transient(model, device, train_loader, optimizer, epoch) test(model, device, test_loader, epoch) for epoch in range(3, 4): seen_losses = [v for sublist in seen_losses for v in sublist] sorted_loss_idx = sorted(range(len(seen_losses)), key=lambda k: seen_losses[k][1], reverse=True) removed = sorted_loss_idx[-int((perc_to_remove / 100) * len(sorted_loss_idx)):] sorted_loss_idx = sorted_loss_idx[:-int((perc_to_remove / 100) * len(sorted_loss_idx))] to_add = list(np.random.choice(removed, int(0.33*len(sorted_loss_idx)), replace=False)) sorted_loss_idx = sorted_loss_idx + to_add sorted_loss_idx.sort() weights = [seen_losses[idx][1] for idx in sorted_loss_idx] train_loader = get_weighted_loader(args.data_dir, 64*2, weights, **kwargs) seen_losses = train_steady_state(model, device, train_loader, optimizer, epoch) test(model, device, test_loader, epoch) for epoch in range(4, 8): train_transient(model, device, train_loader, optimizer, epoch) test(model, device, test_loader, epoch) toc = time.time() print("Time Elapsed: {}s".format(toc-tic))
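# get_weighted_loader is not defined in this snippet. A minimal sketch built on
# torch.utils.data.WeightedRandomSampler, assuming CIFAR-10 training data; note
# that the call above passes only the per-sample weights, so a real
# implementation would presumably also need the retained indices to restrict
# the dataset (e.g. via torch.utils.data.Subset):
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms

def get_weighted_loader(data_dir, batch_size, weights, **kwargs):
    dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                               transform=transforms.ToTensor())
    # draw samples with replacement, proportionally to their tracked loss
    sampler = WeightedRandomSampler(weights, num_samples=len(weights),
                                    replacement=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, **kwargs)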
def __init__(self, input_shape, output_dim, patience=4,
             structure='wide_res_net'):
    self.model = None
    if structure == 'wide_res_net':
        self.model = WideResNet(input_shape=input_shape, output_dim=output_dim)
    elif structure == 'res_net':
        self.model = ResNet(input_shape=input_shape, output_dim=output_dim)
    else:
        raise Exception(f'unknown structure: {structure}')

    self.criterion = tf.keras.losses.CategoricalCrossentropy()
    self.optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    self.train_loss = tf.keras.metrics.Mean()
    self.train_acc = tf.keras.metrics.CategoricalAccuracy()
    self.val_loss = tf.keras.metrics.Mean()
    self.val_acc = tf.keras.metrics.CategoricalAccuracy()
    self.history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': []
    }
    self.es = {'loss': float('inf'), 'patience': patience, 'step': 0}
    self.save_dir = './logs'
    if not os.path.exists(self.save_dir):
        os.mkdir(self.save_dir)
def get_model_for_attack(model_name):
    if model_name == 'model1':
        model = ResNet34()
        load_w(model, "./models/weights/resnet34.pt")
    elif model_name == 'model2':
        model = ResNet18()
        load_w(model, "./models/weights/resnet18_AT.pt")
    elif model_name == 'model3':
        model = SmallResNet()
        load_w(model, "./models/weights/res_small.pth")
    elif model_name == 'model4':
        model = WideResNet34()
        pref = next(model.parameters())
        model.load_state_dict(
            filter_state_dict(
                torch.load("./models/weights/trades_wide_resnet.pt",
                           map_location=pref.device)))
    elif model_name == 'model5':
        model = WideResNet()
        load_w(model, "./models/weights/wideres34-10-pgdHE.pt")
    elif model_name == 'model6':
        model = WideResNet28()
        pref = next(model.parameters())
        model.load_state_dict(
            filter_state_dict(
                torch.load('models/weights/RST-AWP_cifar10_linf_wrn28-10.pt',
                           map_location=pref.device)))
    elif model_name == 'model_vgg16bn':
        model = vgg16_bn(pretrained=True)
    elif model_name == 'model_resnet18_imgnet':
        model = resnet18(pretrained=True)
    elif model_name == 'model_inception':
        model = inception_v3(pretrained=True)
    elif model_name == 'model_vitb':
        from mnist_vit import ViT, MegaSizer
        model = MegaSizer(
            ImageNetRenormalize(ViT('B_16_imagenet1k', pretrained=True)))
    elif model_name.startswith('model_hub:'):
        _, a, b = model_name.split(":")
        model = torch.hub.load(a, b, pretrained=True)
        model = Cifar10Renormalize(model)
    elif model_name.startswith('model_mnist:'):
        _, a = model_name.split(":")
        model = torch.load('mnist.pt')[a]
    elif model_name.startswith('model_ex:'):
        _, a = model_name.split(":")
        model = torch.load(a)
    else:
        raise ValueError(f"unknown model name: {model_name}")
    return model
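# load_w and filter_state_dict are helpers assumed by the snippet above. A
# minimal sketch: load_w loads a checkpoint onto the model's current device,
# and filter_state_dict strips the 'module.' prefix that nn.DataParallel
# prepends to parameter names (the nested-'model' handling is an assumption):
def load_w(model, path):
    device = next(model.parameters()).device
    model.load_state_dict(torch.load(path, map_location=device))

def filter_state_dict(state_dict):
    if 'model' in state_dict:  # some checkpoints nest the weights
        state_dict = state_dict['model']
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}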
def create_model(config):
    model_type = config["model_type"]
    if model_type == "SimpleConvNet":
        if model_type not in config:
            config[model_type] = {
                "conv1_size": 32,
                "conv2_size": 64,
                "fc_size": 128
            }
        model = SimpleConvNet(**config[model_type])
    elif model_type == "MiniVGG":
        if model_type not in config:
            config[model_type] = {
                "conv1_size": 128,
                "conv2_size": 256,
                "classifier_size": 1024
            }
        model = MiniVGG(**config[model_type])
    elif model_type == "WideResNet":
        if model_type not in config:
            config[model_type] = {
                "depth": 34,
                "num_classes": 10,
                "widen_factor": 10
            }
        model = WideResNet(**config[model_type])
    # elif model_type == "ShuffleNetv2":
    #     if model_type not in config:
    #         config[model_type] = {}
    #     model = shufflenet_v2_x0_5()
    elif model_type == "MobileNetv2":
        if model_type not in config:
            config[model_type] = {"pretrained": False}
        model = mobilenet_v2(num_classes=10,
                             pretrained=config[model_type]["pretrained"])
    else:
        print(f"Error: MODEL_TYPE {model_type} unknown.")
        exit()
    config["num_parameters"] = sum(p.numel() for p in model.parameters())
    config["num_trainable_parameters"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    return model
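# Hypothetical usage of create_model (names here are illustrative): when the
# per-model sub-config is absent, the defaults above are filled in, and the
# parameter counts are written back into the config for logging:
config = {"model_type": "WideResNet"}
model = create_model(config)
print(config["WideResNet"])  # the defaults that create_model filled in
print(config["num_parameters"], config["num_trainable_parameters"])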
def get_model_for_attack(model_name):
    if model_name == 'model1':
        model = ResNet34()
        model.load_state_dict(torch.load("models/weights/resnet34.pt"))
    elif model_name == 'model2':
        model = ResNet18()
        model.load_state_dict(torch.load('models/weights/resnet18_AT.pt'))
    elif model_name == 'model3':
        model = SmallResNet()
        model.load_state_dict(torch.load('models/weights/res_small.pth'))
    elif model_name == 'model4':
        model = WideResNet34()
        model.load_state_dict(
            filter_state_dict(torch.load('models/weights/trades_wide_resnet.pt')))
    elif model_name == 'model5':
        model = WideResNet()
        model.load_state_dict(torch.load('models/weights/wideres34-10-pgdHE.pt'))
    elif model_name == 'model6':
        model = WideResNet28()
        model.load_state_dict(
            filter_state_dict(
                torch.load('models/weights/RST-AWP_cifar10_linf_wrn28-10.pt')))
    else:
        raise ValueError(f"unknown model name: {model_name}")
    return model
def get_model(weight_decay=0.0005):
    # parameters for the WideResNet model
    k = 10         # widening factor
    N = 4          # number of blocks per stage; depth = 6*N + 4
    dropout = 0.3  # WRN-28-10 with dropout 0.3
    model = WideResNet([16 * k, 32 * k, 64 * k], [N] * 3, dropout,
                       weight_decay, nb_classes=100,
                       batchnorm_training=False, use_bias=False)

    # save the random initialization on the first run and reload it afterwards,
    # so every experiment starts from identical weights
    weights_location = file_loc + 'saved_weights/initial_weights_C100_WRN.h5'
    if 'initial_weights_C100_WRN.h5' not in os.listdir(file_loc + 'saved_weights'):
        model.save_weights(weights_location)
    else:
        model.load_weights(weights_location)
    return model
class TrainerV2(object):
    def __init__(self, input_shape, encode_dim, output_dim,
                 model='efficient_net', loss='emd'):
        self.model = None
        if model == 'efficient_net':
            self.model = EfficientNet(input_shape, encode_dim, output_dim)
        elif model == 'wide_res_net':
            self.model = WideResNet(input_shape, output_dim)
        else:
            raise Exception(f'no matching model name: {model}')

        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.1)
        if loss == 'emd':
            loss_func = EMD
        elif loss == 'categorical_crossentropy':
            loss_func = 'categorical_crossentropy'
        else:
            raise Exception(f'no matching loss function: {loss}')
        self.model.compile(optimizer=optimizer, loss=loss_func, metrics=['acc'])

    def train(self, x_train, t_train, x_val, t_val, epochs, batch_size,
              image_path, save_name):
        train_gen = DataGenerator(x_train, t_train, image_path=image_path,
                                  batch_size=batch_size)
        val_gen = DataGenerator(x_val, t_val, image_path=image_path,
                                batch_size=batch_size)
        callbacks = [
            tf.keras.callbacks.ModelCheckpoint(save_name, monitor='val_loss',
                                               verbose=1, save_best_only=True,
                                               mode='min')
        ]
        self.history = self.model.fit_generator(
            train_gen,
            len(train_gen),
            epochs=epochs,  # was hardcoded to 30, ignoring the argument
            validation_data=val_gen,
            validation_steps=len(val_gen),
            callbacks=callbacks,
        )

    def evaluate(self, x_test, t_test, batch_size, image_path):
        test_gen = DataGenerator(x_test, t_test, image_path=image_path,
                                 batch_size=batch_size)
        preds = self.model.predict_generator(test_gen, len(test_gen))
        idx = np.array([0, 1, 2, 3, 4])
        # acc1: argmax prediction; acc2: truncated expected value over the
        # five ordinal classes
        acc1 = accuracy_score(np.argmax(t_test, axis=1),
                              np.argmax(preds, axis=1))
        acc2 = accuracy_score(np.argmax(t_test, axis=1),
                              np.sum(preds * idx, axis=1).astype(np.int32))
        cm = confusion_matrix(np.argmax(t_test, axis=1),
                              np.argmax(preds, axis=1))
        print(acc1, acc2, cm)
        return (acc1, acc2, cm)
dropout_rate = 0.2
initializer = 'he_normal'
weight_decay = 5e-4
regularizer = l2(weight_decay)

# training parameters
epochs = 200
batch_size = 32
learning_rate = 0.01
max_learning_rate = 0.1

clr = OneCycleLR(num_samples=X_train.shape[0], batch_size=batch_size,
                 max_lr=max_learning_rate)
chk = ModelCheckpoint(filepath='results/wrn1028', save_weights_only=True,
                      monitor='val_loss', mode='min', save_best_only=True)

# fit the model
model = WideResNet(width, depth, classes, filters, input_shape, activation,
                   dropout_rate, initializer, regularizer).get_model()
model.compile(optimizer=SGD(lr=learning_rate),
              loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(generator.flow(X_train, Y_train, batch_size=batch_size),
          epochs=epochs, verbose=2,  # batch_size is set on the generator, not in fit()
          validation_data=(X_test, Y_test),
          callbacks=[clr, chk])
elif backbone_network == 'ResNext':
    from models import ResNext
    model = ResNext(n_layers=n_layers, n_groups=opt.n_groups,
                    dataset=opt.dataset, attention=opt.attention_module,
                    group_size=opt.group_size)
elif backbone_network == 'WideResNet':
    from models import WideResNet
    model = WideResNet(n_layers=n_layers,
                       widening_factor=opt.widening_factor,
                       dataset=opt.dataset, attention=opt.attention_module,
                       group_size=opt.group_size)

model = nn.DataParallel(model).to(device)
criterion = nn.CrossEntropyLoss()

if dataset_name in ['CIFAR10', 'CIFAR100']:
    optim = torch.optim.SGD(model.parameters(), lr=opt.lr,
                            momentum=opt.momentum,
                            weight_decay=opt.weight_decay)
    milestones = [150, 225]
parser.add_argument('--perturb_steps', type=int, default=20,
                    help='iterations for pgd attack (default pgd20)')
parser.add_argument('--model_name', type=str, default="")
parser.add_argument('--model_path', type=str,
                    default="./models/weights/model-wideres-pgdHE-wide10.pt")
parser.add_argument('--gpu_id', type=str, default="0")
return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    # select which GPU cards to use on a multi-GPU machine
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    gpu_num = max(len(args.gpu_id.split(',')), 1)
    device = torch.device('cuda')
    if args.model_name != "":
        # pick the model to attack based on model_name
        model = get_model_for_attack(args.model_name).to(device)
        model = nn.DataParallel(model, device_ids=[i for i in range(gpu_num)])
    else:
        # defense task: change to your model here
        model = WideResNet()
        model.load_state_dict(torch.load('models/weights/wideres34-10-pgdHE.pt'))
        model = nn.DataParallel(model, device_ids=[i for i in range(gpu_num)])
    # attack task: change to your attack function here
    # here is an attack baseline: the PGD attack
    attack = PGDAttack(args.step_size, args.epsilon, args.perturb_steps)
    model.eval()
    test_loader = get_test_cifar(args.batch_size)
    natural_acc, robust_acc, distance = eval_model_with_attack(
        model, test_loader, attack, args.epsilon, device)
    print(f"Natural Acc: {natural_acc:.5f}, "
          f"Robust acc: {robust_acc:.5f}, "
          f"distance: {distance:.5f}")
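# PGDAttack is imported from elsewhere in this repo; its constructor signature
# (step_size, epsilon, perturb_steps) suggests a standard L-infinity PGD
# attack. A minimal sketch of that attack; the exact interface here is an
# assumption:
import torch
import torch.nn.functional as F

class PGDAttack:
    def __init__(self, step_size, epsilon, perturb_steps):
        self.step_size = step_size
        self.epsilon = epsilon
        self.perturb_steps = perturb_steps

    def __call__(self, model, x, y):
        # random start inside the epsilon ball
        x_adv = x + torch.empty_like(x).uniform_(-self.epsilon, self.epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
        for _ in range(self.perturb_steps):
            x_adv.requires_grad_(True)
            loss = F.cross_entropy(model(x_adv), y)
            grad = torch.autograd.grad(loss, x_adv)[0]
            # ascend the loss, then project back onto the epsilon ball
            x_adv = x_adv.detach() + self.step_size * grad.sign()
            x_adv = torch.min(torch.max(x_adv, x - self.epsilon),
                              x + self.epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
        return x_adv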
def main(args):
    # writer = SummaryWriter('./runs/CIFAR_100_exp')
    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomRotation(15),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761))
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761))
    ])
    train_dataset = datasets.CIFAR100('./dataset', train=True,
                                      transform=train_transform, download=True)
    test_dataset = datasets.CIFAR100('./dataset', train=False,
                                     transform=test_transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             shuffle=False, num_workers=args.num_workers)

    Teacher = WideResNet(depth=args.teacher_depth, num_classes=100,
                         widen_factor=args.teacher_width_factor, drop_rate=0.3)
    Teacher.cuda()
    Teacher.eval()
    teacher_weight_path = path.join(args.teacher_root_path, 'model_best.pth.tar')
    t_load = torch.load(teacher_weight_path)['state_dict']
    Teacher.load_state_dict(t_load)

    Student = WideResNet(depth=args.student_depth, num_classes=100,
                         widen_factor=args.student_width_factor, drop_rate=0.0)
    Student.cuda()
    cudnn.benchmark = True

    optimizer = torch.optim.SGD(Student.parameters(), lr=args.lr, momentum=0.9,
                                weight_decay=5e-4, nesterov=True)
    opt_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[60, 120, 160], gamma=2e-1)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0
    best_acc5 = 0
    best_flag = False
    for epoch in range(args.total_epochs):
        for iter_, data in enumerate(train_loader):
            images, labels = data
            images, labels = images.cuda(), labels.cuda()
            t_outs, *t_acts = Teacher(images)
            s_outs, *s_acts = Student(images)
            cls_loss = criterion(s_outs, labels)

            # statistical matching and AdaIN losses
            if args.aux_flag == 0:
                aux_loss_1 = SM_Loss(t_acts[2], s_acts[2])  # group conv2
            else:
                aux_loss_1 = 0
                for i in range(3):
                    aux_loss_1 += SM_Loss(t_acts[i], s_acts[i])
            F_hat = AdaIN(t_acts[2], s_acts[2])
            interim_out_q = Teacher.bn1(F_hat)
            interim_out_q = Teacher.relu(interim_out_q)
            interim_out_q = F.avg_pool2d(interim_out_q, 8)
            interim_out_q = interim_out_q.view(-1, Teacher.last_ch)
            q = Teacher.fc(interim_out_q)
            aux_loss_2 = torch.mean(torch.pow(t_outs - q, 2))

            total_loss = cls_loss + aux_loss_1 + aux_loss_2
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        top1, top5 = evaluator(test_loader, Student)
        if top1 > best_acc:
            best_acc = top1
            best_acc5 = top5
            best_flag = True
        if best_flag:
            state = {'epoch': epoch + 1,
                     'state_dict': Student.state_dict(),
                     'optimizer': optimizer.state_dict()}
            save_ckpt(state, is_best=best_flag,
                      root_path=args.student_weight_path)
            best_flag = False
        opt_scheduler.step()
        # writer.add_scalar('acc/top1', top1, epoch)
        # writer.add_scalar('acc/top5', top5, epoch)
    # writer.close()
    print("Best top 1 acc: {}".format(best_acc))
    print("Best top 5 acc: {}".format(best_acc5))
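# SM_Loss and AdaIN are used above but defined elsewhere. A minimal sketch of
# one plausible implementation: SM_Loss matches per-channel mean/std between
# teacher and student feature maps, and AdaIN re-normalizes the student
# features with the teacher's channel statistics (the argument order and exact
# formulation here are assumptions):
import torch

def _channel_stats(x, eps=1e-5):
    # x: (N, C, H, W) -> per-sample, per-channel mean and std
    mean = x.mean(dim=(2, 3), keepdim=True)
    std = x.var(dim=(2, 3), keepdim=True).add(eps).sqrt()
    return mean, std

def SM_Loss(t_feat, s_feat):
    t_mean, t_std = _channel_stats(t_feat)
    s_mean, s_std = _channel_stats(s_feat)
    return ((t_mean - s_mean) ** 2).mean() + ((t_std - s_std) ** 2).mean()

def AdaIN(t_feat, s_feat):
    t_mean, t_std = _channel_stats(t_feat)
    s_mean, s_std = _channel_stats(s_feat)
    # align student features to the teacher's channel statistics
    return t_std * (s_feat - s_mean) / s_std + t_mean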
def load_paper_settings(args):
    WRN_path = os.path.join(args.data_path, 'WRN28-4_21.09.pt')
    Pyramid_path = os.path.join(args.data_path, 'pyramid200_mixup_15.6.tar')

    if args.paper_setting in ('a', 'b', 'c', 'd'):
        # settings a-d share a WRN-28-4 teacher
        teacher = WRN.WideResNet(depth=28, widen_factor=4, num_classes=100)
        state = torch.load(WRN_path, map_location={'cuda:0': 'cpu'})['model']
        teacher.load_state_dict(state)
        if args.paper_setting == 'a':
            student = WRN.WideResNet(depth=16, widen_factor=4, num_classes=100)
        elif args.paper_setting == 'b':
            student = WRN.WideResNet(depth=28, widen_factor=2, num_classes=100)
        elif args.paper_setting == 'c':
            student = WRN.WideResNet(depth=16, widen_factor=2, num_classes=100)
        else:
            student = RN.ResNet(depth=56, num_classes=100)
    elif args.paper_setting in ('e', 'f'):
        # settings e-f share a PyramidNet-200 teacher
        teacher = PYN.PyramidNet(depth=200, alpha=240, num_classes=100,
                                 bottleneck=True)
        state = torch.load(Pyramid_path,
                           map_location={'cuda:0': 'cpu'})['state_dict']
        from collections import OrderedDict
        new_state = OrderedDict()
        for k, v in state.items():
            name = k[7:]  # remove the 'module.' prefix added by DataParallel
            new_state[name] = v
        teacher.load_state_dict(new_state)
        if args.paper_setting == 'e':
            student = WRN.WideResNet(depth=28, widen_factor=4, num_classes=100)
        else:
            student = PYN.PyramidNet(depth=110, alpha=84, num_classes=100,
                                     bottleneck=False)
    else:
        raise ValueError('Undefined setting name !!!')

    return teacher, student, args
parser.add_argument('--regu', type=str, default='no',
                    help='type of regularization. Possible values are: '
                         'no: no regularization; '
                         'random-svd: employ random-svd in regularization')

if __name__ == "__main__":
    args = parser.parse_args()

    # create model
    n_classes = 10 if args.dataset == 'cifar10' else 100
    if args.model == 'resnet':
        net = resnet110(num_classes=n_classes)
    elif args.model == 'wideresnet':
        net = WideResNet(depth=28, widen_factor=10, dropRate=0.3,
                         num_classes=n_classes)
    elif args.model == 'resnext':
        net = CifarResNeXt(cardinality=8, depth=29, base_width=64,
                           widen_factor=4, nlabels=n_classes)
    else:
        raise Exception('Invalid model name')

    # create optimizer
    optimizer = torch.optim.SGD(net.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)
def main():
    args = parser.parse_args()
    args.best_top1 = 0.
    args.best_top5 = 0.

    if args.local_rank != -1:
        args.gpu = args.local_rank
        torch.distributed.init_process_group(backend='nccl')
        args.world_size = torch.distributed.get_world_size()
    else:
        args.gpu = 0
        args.world_size = 1

    args.device = torch.device('cuda', args.gpu)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARNING)

    logger.warning(f"Process rank: {args.local_rank}, "
                   f"device: {args.device}, "
                   f"distributed training: {bool(args.local_rank != -1)}, "
                   f"16-bits training: {args.amp}")
    logger.info(dict(args._get_kwargs()))

    if args.local_rank in [-1, 0]:
        args.writer = SummaryWriter(f"results/{args.name}")

    if args.seed is not None:
        set_seed(args)

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[
        args.dataset](args)

    if args.local_rank == 0:
        torch.distributed.barrier()

    train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
    labeled_loader = DataLoader(labeled_dataset,
                                sampler=train_sampler(labeled_dataset),
                                batch_size=args.batch_size,
                                num_workers=args.workers,
                                drop_last=True)
    unlabeled_loader = DataLoader(unlabeled_dataset,
                                  sampler=train_sampler(unlabeled_dataset),
                                  batch_size=args.batch_size * args.mu,
                                  num_workers=args.workers,
                                  drop_last=True)
    test_loader = DataLoader(test_dataset,
                             sampler=SequentialSampler(test_dataset),
                             batch_size=args.batch_size,
                             num_workers=args.workers)

    if args.dataset == "cifar10":
        depth, widen_factor = 28, 2
    elif args.dataset == 'cifar100':
        depth, widen_factor = 28, 8

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    # test dropout
    teacher_model = WideResNet(num_classes=args.num_classes,
                               depth=depth,
                               widen_factor=widen_factor,
                               dropout=0,
                               dense_dropout=args.dense_dropout)
    student_model = WideResNet(num_classes=args.num_classes,
                               depth=depth,
                               widen_factor=widen_factor,
                               dropout=0,
                               dense_dropout=args.dense_dropout)

    if args.local_rank == 0:
        torch.distributed.barrier()

    teacher_model.to(args.device)
    student_model.to(args.device)
    avg_student_model = None
    if args.ema > 0:
        avg_student_model = ModelEMA(student_model, args.ema)

    criterion = create_loss_fn(args)

    no_decay = ['bn']
    teacher_parameters = [
        {'params': [p for n, p in teacher_model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in teacher_model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    student_parameters = [
        {'params': [p for n, p in student_model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in student_model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]

    t_optimizer = optim.SGD(teacher_parameters,
                            lr=args.lr,
                            momentum=args.momentum,
                            # weight_decay=args.weight_decay,
                            nesterov=args.nesterov)
    s_optimizer = optim.SGD(student_parameters,
                            lr=args.lr,
                            momentum=args.momentum,
                            # weight_decay=args.weight_decay,
                            nesterov=args.nesterov)

    t_scheduler = get_cosine_schedule_with_warmup(t_optimizer,
                                                  args.warmup_steps,
                                                  args.total_steps)
    s_scheduler = get_cosine_schedule_with_warmup(s_optimizer,
                                                  args.warmup_steps,
                                                  args.total_steps,
                                                  args.student_wait_steps)

    t_scaler = amp.GradScaler(enabled=args.amp)
    s_scaler = amp.GradScaler(enabled=args.amp)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info(f"=> loading checkpoint '{args.resume}'")
            loc = f'cuda:{args.gpu}'
            checkpoint = torch.load(args.resume, map_location=loc)
            args.best_top1 = checkpoint['best_top1'].to(torch.device('cpu'))
            args.best_top5 = checkpoint['best_top5'].to(torch.device('cpu'))
            if not (args.evaluate or args.finetune):
                args.start_step = checkpoint['step']
                t_optimizer.load_state_dict(checkpoint['teacher_optimizer'])
                s_optimizer.load_state_dict(checkpoint['student_optimizer'])
                t_scheduler.load_state_dict(checkpoint['teacher_scheduler'])
                s_scheduler.load_state_dict(checkpoint['student_scheduler'])
                t_scaler.load_state_dict(checkpoint['teacher_scaler'])
                s_scaler.load_state_dict(checkpoint['student_scaler'])
                model_load_state_dict(teacher_model,
                                      checkpoint['teacher_state_dict'])
                if avg_student_model is not None:
                    model_load_state_dict(avg_student_model,
                                          checkpoint['avg_state_dict'])
            else:
                if checkpoint['avg_state_dict'] is not None:
                    model_load_state_dict(student_model,
                                          checkpoint['avg_state_dict'])
                else:
                    model_load_state_dict(student_model,
                                          checkpoint['student_state_dict'])
            logger.info(f"=> loaded checkpoint '{args.resume}' "
                        f"(step {checkpoint['step']})")
        else:
            logger.info(f"=> no checkpoint found at '{args.resume}'")

    if args.local_rank != -1:
        teacher_model = nn.parallel.DistributedDataParallel(
            teacher_model, device_ids=[args.local_rank],
            output_device=args.local_rank, find_unused_parameters=True)
        student_model = nn.parallel.DistributedDataParallel(
            student_model, device_ids=[args.local_rank],
            output_device=args.local_rank, find_unused_parameters=True)

    if args.finetune:
        del t_scaler, t_scheduler, t_optimizer, teacher_model, unlabeled_loader
        del s_scaler, s_scheduler, s_optimizer
        finetune(args, labeled_loader, test_loader, student_model, criterion)
        return

    if args.evaluate:
        del t_scaler, t_scheduler, t_optimizer, teacher_model, unlabeled_loader, labeled_loader
        del s_scaler, s_scheduler, s_optimizer
        evaluate(args, test_loader, student_model, criterion)
        return

    teacher_model.zero_grad()
    student_model.zero_grad()
    train_loop(args, labeled_loader, unlabeled_loader, test_loader,
               teacher_model, student_model, avg_student_model, criterion,
               t_optimizer, s_optimizer, t_scheduler, s_scheduler,
               t_scaler, s_scaler)
    return
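# get_cosine_schedule_with_warmup is assumed above. A minimal sketch built on
# torch.optim.lr_scheduler.LambdaLR: linear warmup followed by cosine decay,
# where the fourth argument (used for the student) holds the LR at zero first;
# the exact schedule used by the repo is an assumption:
import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps,
                                    wait_steps=0, last_epoch=-1):
    def lr_lambda(step):
        if step < wait_steps:
            return 0.0
        if step < wait_steps + warmup_steps:
            return (step - wait_steps) / max(1, warmup_steps)
        progress = ((step - wait_steps - warmup_steps) /
                    max(1, total_steps - wait_steps - warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return LambdaLR(optimizer, lr_lambda, last_epoch)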