Example #1
def main():
    global args, best_prec1
    args = parser.parse_args()

    # ensuring reproducibility
    SEED = 42
    torch.manual_seed(SEED)
    torch.backends.cudnn.benchmark = False

    kwargs = {'num_workers': 1, 'pin_memory': True}
    device = torch.device("cuda")

    num_epochs = 7

    # create model
    model = WideResNet(args.layers,
                       10,
                       args.widen_factor,
                       dropRate=args.droprate).to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.learning_rate,
                                 weight_decay=args.weight_decay)

    # instantiate loaders
    train_loader = get_data_loader(args.data_dir, args.batch_size, **kwargs)
    test_loader = get_test_loader(args.data_dir, 128, **kwargs)

    tic = time.time()
    for epoch in range(1, num_epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader, epoch)
    toc = time.time()
    print("Time Elapsed: {}s".format(toc - tic))
Example #2
def get_model(args):
    if args.seed is not None:
        set_seed(args)

    if args.dataset == "cifar10":
        depth, widen_factor = 28, 2
    elif args.dataset == 'cifar100':
        depth, widen_factor = 28, 8
    else:
        # guard against depth/widen_factor being unbound further down
        raise ValueError(f'unsupported dataset: {args.dataset}')

    student_model = WideResNet(num_classes=args.num_classes,
                               depth=depth,
                               widen_factor=widen_factor,
                               dropout=0,
                               dense_dropout=args.dense_dropout)

    if os.path.isfile(args.resume):
        print(f"=> loading checkpoint '{args.resume}'")
        loc = 'cpu'
        checkpoint = torch.load(args.resume, map_location=loc)
        if checkpoint['avg_state_dict'] is not None:
            model_load_state_dict(student_model, checkpoint['avg_state_dict'])
        else:
            model_load_state_dict(student_model,
                                  checkpoint['student_state_dict'])

        print(
            f"=> loaded checkpoint '{args.resume}' (step {checkpoint['step']})"
        )
    else:
        print(f"=> no checkpoint found at '{args.resume}'")
        exit(1)

    if args.device != 'cpu':
        student_model.cuda()
    return student_model
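A hypothetical invocation of get_model, with an argparse.Namespace standing in for parsed CLI arguments; the attribute names mirror exactly what the function reads, and the checkpoint path is a placeholder:

from argparse import Namespace

args = Namespace(seed=None,
                 dataset='cifar10',
                 num_classes=10,
                 dense_dropout=0.0,
                 resume='checkpoint.pth.tar',  # placeholder path
                 device='cpu')
student = get_model(args)
student.eval()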
Example #3
def main():
    global args, best_prec1
    args = parser.parse_args()

    # ensuring reproducibility
    SEED = 42
    torch.manual_seed(SEED)
    torch.backends.cudnn.benchmark = False

    kwargs = {'num_workers': 1, 'pin_memory': True}
    device = torch.device("cuda")

    num_epochs = 7

    # create model
    model = WideResNet(args.layers, 10, args.widen_factor, dropRate=args.droprate).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        args.learning_rate,
        weight_decay=args.weight_decay
    )

    # instantiate loaders
    train_loader = get_data_loader(args.data_dir, args.batch_size, **kwargs)
    test_loader = get_test_loader(args.data_dir, 128, **kwargs)

    tic = time.time()
    for epoch in range(1, num_epochs+1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader, epoch)
    toc = time.time()
    print("Time Elapsed: {}s".format(toc-tic))
Example #4
def main():
    global args, best_prec1
    args = parser.parse_args()

    # ensuring reproducibility
    SEED = 42
    torch.manual_seed(SEED)
    torch.backends.cudnn.benchmark = False

    kwargs = {'num_workers': 1, 'pin_memory': True}
    device = torch.device("cuda")

    num_epochs_transient = 2
    num_epochs_steady = 7
    perc_to_remove = 10

    torch.manual_seed(SEED)

    # create model
    model = WideResNet(args.layers, 10, args.widen_factor, dropRate=args.droprate).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        args.learning_rate,
        weight_decay=args.weight_decay
    )

    # instantiate loaders
    train_loader = get_data_loader(args.data_dir, args.batch_size, **kwargs)
    test_loader = get_test_loader(args.data_dir, 128, **kwargs)

    tic = time.time()
    seen_losses = None
    for epoch in range(1, 3):
        if epoch == 1:
            seen_losses = train_transient(model, device, train_loader, optimizer, epoch, track=True)
        else:
            train_transient(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader, epoch)

    for epoch in range(3, 4):
        seen_losses = [v for sublist in seen_losses for v in sublist]
        sorted_loss_idx = sorted(range(len(seen_losses)), key=lambda k: seen_losses[k][1], reverse=True)
        removed = sorted_loss_idx[-int((perc_to_remove / 100) * len(sorted_loss_idx)):]
        sorted_loss_idx = sorted_loss_idx[:-int((perc_to_remove / 100) * len(sorted_loss_idx))]
        to_add = list(np.random.choice(removed, int(0.33*len(sorted_loss_idx)), replace=False))
        sorted_loss_idx = sorted_loss_idx + to_add
        sorted_loss_idx.sort()
        weights = [seen_losses[idx][1] for idx in sorted_loss_idx]
        train_loader = get_weighted_loader(args.data_dir, 64*2, weights, **kwargs)
        seen_losses = train_steady_state(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader, epoch)

    for epoch in range(4, 8):
        train_transient(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader, epoch)
    toc = time.time()
    print("Time Elapsed: {}s".format(toc-tic))
Example #5
 def __init__(
     self,
     input_shape,
     output_dim,
     patience=4,
     structure='wide_res_net',
 ):
     self.model = None
     if structure == 'wide_res_net':
         self.model = WideResNet(input_shape=input_shape,
                                 output_dim=output_dim)
     elif structure == 'res_net':
         self.model = ResNet(input_shape=input_shape, output_dim=output_dim)
     else:
         raise Exception('no structure')
     self.criterion = tf.keras.losses.CategoricalCrossentropy()
     self.optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
     self.train_loss = tf.keras.metrics.Mean()
     self.train_acc = tf.keras.metrics.CategoricalAccuracy()
     self.val_loss = tf.keras.metrics.Mean()
     self.val_acc = tf.keras.metrics.CategoricalAccuracy()
     self.history = {
         'train_loss': [],
         'val_loss': [],
         'train_acc': [],
         'val_acc': []
     }
     self.es = {'loss': float('inf'), 'patience': patience, 'step': 0}
     self.save_dir = './logs'
     if not os.path.exists(self.save_dir):
         os.mkdir(self.save_dir)
Example #6
 def __init__(self,
              input_shape,
              encode_dim,
              output_dim,
              model='efficient_net',
              loss='emd'):
     self.model = None
     if model == 'efficient_net':
         self.model = EfficientNet(input_shape, encode_dim, output_dim)
     elif model == 'wide_res_net':
         self.model = WideResNet(input_shape, output_dim)
     else:
         raise Exception('no match model name')
     optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.1)
     loss_func = None
     if loss == 'emd':
         loss_func = EMD
     elif loss == 'categorical_crossentropy':
         loss_func = 'categorical_crossentropy'
     else:
         raise Exception('no match loss function')
     self.model.compile(optimizer=optimizer,
                        loss=loss_func,
                        metrics=['acc'])
Example #7
def get_model_for_attack(model_name):
    if model_name == 'model1':
        model = ResNet34()
        load_w(model, "./models/weights/resnet34.pt")
    elif model_name == 'model2':
        model = ResNet18()
        load_w(model, "./models/weights/resnet18_AT.pt")
    elif model_name == 'model3':
        model = SmallResNet()
        load_w(model, "./models/weights/res_small.pth")
    elif model_name == 'model4':
        model = WideResNet34()
        pref = next(model.parameters())
        model.load_state_dict(
            filter_state_dict(
                torch.load("./models/weights/trades_wide_resnet.pt",
                           map_location=pref.device)))
    elif model_name == 'model5':
        model = WideResNet()
        load_w(model, "./models/weights/wideres34-10-pgdHE.pt")
    elif model_name == 'model6':
        model = WideResNet28()
        pref = next(model.parameters())
        model.load_state_dict(
            filter_state_dict(
                torch.load('models/weights/RST-AWP_cifar10_linf_wrn28-10.pt',
                           map_location=pref.device)))
    elif model_name == 'model_vgg16bn':
        model = vgg16_bn(pretrained=True)
    elif model_name == 'model_resnet18_imgnet':
        model = resnet18(pretrained=True)
    elif model_name == 'model_inception':
        model = inception_v3(pretrained=True)
    elif model_name == 'model_vitb':
        from mnist_vit import ViT, MegaSizer
        model = MegaSizer(
            ImageNetRenormalize(ViT('B_16_imagenet1k', pretrained=True)))
    elif model_name.startswith('model_hub:'):
        _, a, b = model_name.split(":")
        model = torch.hub.load(a, b, pretrained=True)
        model = Cifar10Renormalize(model)
    elif model_name.startswith('model_mnist:'):
        _, a = model_name.split(":")
        model = torch.load('mnist.pt')[a]
    elif model_name.startswith('model_ex:'):
        _, a = model_name.split(":")
        model = torch.load(a)
    else:
        # fail loudly instead of hitting UnboundLocalError at `return model`
        raise ValueError(f'unknown model_name: {model_name}')
    return model
Example #8
def create_model(config):
    model_type = config["model_type"]
    if model_type == "SimpleConvNet":
        if model_type not in config:
            config[model_type] = {
                "conv1_size": 32,
                "conv2_size": 64,
                "fc_size": 128
            }
        model = SimpleConvNet(**config[model_type])
    elif model_type == "MiniVGG":
        if model_type not in config:
            config[model_type] = {
                "conv1_size": 128,
                "conv2_size": 256,
                "classifier_size": 1024
            }
        model = MiniVGG(**config[model_type])
    elif model_type == "WideResNet":
        if model_type not in config:
            config[model_type] = {
                "depth": 34,
                "num_classes": 10,
                "widen_factor": 10
            }
        model = WideResNet(**config[model_type])
    # elif model_type == "ShuffleNetv2":
    #     if model_type not in config:
    #         config[model_type] = {}
    #     model = shufflenet_v2_x0_5()
    elif model_type == "MobileNetv2":
        if model_type not in config:
            config[model_type] = {"pretrained": False}
        model = mobilenet_v2(num_classes=10,
                             pretrained=config[model_type]["pretrained"])
    else:
        print(f"Error: MODEL_TYPE {model_type} unknown.")
        exit()

    config["num_parameters"] = sum(p.numel() for p in model.parameters())
    config["num_trainable_parameters"] = sum(p.numel()
                                             for p in model.parameters()
                                             if p.requires_grad)
    return model
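A hypothetical call showing how create_model both builds the network and writes the applied defaults and parameter counts back into the config dict (assumes the model classes above are importable):

config = {"model_type": "WideResNet"}
model = create_model(config)
print(config["WideResNet"])        # {'depth': 34, 'num_classes': 10, 'widen_factor': 10}
print(config["num_parameters"])    # total parameter count, filled in by create_model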
Example #9
def get_model_for_attack(model_name):
    if model_name=='model1':
        model = ResNet34()
        model.load_state_dict(torch.load("models/weights/resnet34.pt"))
    elif model_name=='model2':
        model = ResNet18()
        model.load_state_dict(torch.load('models/weights/resnet18_AT.pt'))
    elif model_name=='model3':
        model = SmallResNet()
        model.load_state_dict(torch.load('models/weights/res_small.pth'))
    elif model_name=='model4':
        model = WideResNet34()
        model.load_state_dict(filter_state_dict(torch.load('models/weights/trades_wide_resnet.pt')))
    elif model_name=='model5':
        model = WideResNet()
        model.load_state_dict(torch.load('models/weights/wideres34-10-pgdHE.pt'))
    elif model_name=='model6':
        model = WideResNet28()
        model.load_state_dict(filter_state_dict(torch.load('models/weights/RST-AWP_cifar10_linf_wrn28-10.pt')))
    else:
        # fail loudly instead of hitting UnboundLocalError at `return model`
        raise ValueError(f'unknown model_name: {model_name}')
    return model
Example #10
    def get_model(weight_decay=0.0005):
        # parameters for WideResnet model
        k = 10  # widening factor
        N = 4  # number of blocks per stage. Depth = 6*N+4
        dropout = 0.3

        # WRN 28 - 10 with dropout 0.3
        model = WideResNet([16 * k, 32 * k, 64 * k], [N] * 3,
                           dropout,
                           weight_decay,
                           nb_classes=100,
                           batchnorm_training=False,
                           use_bias=False)

        weights_location = file_loc + 'saved_weights/initial_weights_C100_WRN.h5'
        if 'initial_weights_C100_WRN.h5' not in os.listdir(file_loc +
                                                           'saved_weights'):
            model.save_weights(weights_location)
        else:
            model.load_weights(weights_location)

        return model
Example #11
class TrainerV2(object):
    def __init__(self,
                 input_shape,
                 encode_dim,
                 output_dim,
                 model='efficient_net',
                 loss='emd'):
        self.model = None
        if model == 'efficient_net':
            self.model = EfficientNet(input_shape, encode_dim, output_dim)
        elif model == 'wide_res_net':
            self.model = WideResNet(input_shape, output_dim)
        else:
            raise Exception('no match model name')
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.1)
        loss_func = None
        if loss == 'emd':
            loss_func = EMD
        elif loss == 'categorical_crossentropy':
            loss_func = 'categorical_crossentropy'
        else:
            raise Exception('no match loss function')
        self.model.compile(optimizer=optimizer,
                           loss=loss_func,
                           metrics=['acc'])

    def train(self, x_train, t_train, x_val, t_val, epochs, batch_size,
              image_path, save_name):
        train_gen = DataGenerator(x_train,
                                  t_train,
                                  image_path=image_path,
                                  batch_size=batch_size)
        val_gen = DataGenerator(x_val,
                                t_val,
                                image_path=image_path,
                                batch_size=batch_size)

        callbacks = [
            tf.keras.callbacks.ModelCheckpoint(save_name,
                                               monitor='val_loss',
                                               verbose=1,
                                               save_best_only=True,
                                               mode='min')
        ]

        self.history = self.model.fit_generator(
            train_gen,
            len(train_gen),
            epochs=epochs,  # honor the `epochs` argument instead of a hardcoded 30
            validation_data=val_gen,
            validation_steps=len(val_gen),
            callbacks=callbacks,
        )

    def evaluate(
        self,
        x_test,
        t_test,
        batch_size,
        image_path,
    ):
        test_gen = DataGenerator(x_test,
                                 t_test,
                                 image_path=image_path,
                                 batch_size=batch_size)

        preds = self.model.predict_generator(
            test_gen,
            len(test_gen),
        )
        idx = np.array([0, 1, 2, 3, 4])
        acc1 = accuracy_score(np.argmax(t_test, axis=1),
                              np.argmax(preds, axis=1))
        acc2 = accuracy_score(np.argmax(t_test, axis=1),
                              np.sum(preds * idx, axis=1).astype(np.int32))

        cm = confusion_matrix(np.argmax(t_test, axis=1),
                              np.argmax(preds, axis=1))
        print(acc1, acc2, cm)
        return (acc1, acc2, cm)
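Note that fit_generator and predict_generator are deprecated in TF 2.x, where Model.fit and Model.predict accept generators directly; a sketch of the equivalent calls inside the two methods above, reusing the names they build:

# in TrainerV2.train
self.history = self.model.fit(train_gen,
                              steps_per_epoch=len(train_gen),
                              epochs=epochs,
                              validation_data=val_gen,
                              validation_steps=len(val_gen),
                              callbacks=callbacks)

# in TrainerV2.evaluate
preds = self.model.predict(test_gen, steps=len(test_gen))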
Example #12
dropout_rate = 0.2
initializer = 'he_normal'
weight_decay = 5e-4
regularizer = l2(weight_decay)

# training parameters
epochs = 200
batch_size = 32
learning_rate = 0.01
max_learning_rate = 0.1
clr = OneCycleLR(num_samples=X_train.shape[0],
                 batch_size=batch_size,
                 max_lr=max_learning_rate)
chk = ModelCheckpoint(filepath='results/wrn1028',
                      save_weights_only=True,
                      monitor='val_loss',
                      mode='min',
                      save_best_only=True)

# fit the model
model = WideResNet(width, depth, classes, filters, input_shape, activation,
                   dropout_rate, initializer, regularizer).get_model()
model.compile(optimizer=SGD(lr=learning_rate),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(generator.flow(X_train, Y_train, batch_size=batch_size),
          epochs=epochs,
          verbose=2,  # batch size is set on the generator, not on fit()
          validation_data=(X_test, Y_test),
          callbacks=[clr, chk])
Example #13
    elif backbone_network == 'ResNext':
        from models import ResNext

        model = ResNext(n_layers=n_layers,
                        n_groups=opt.n_groups,
                        dataset=opt.dataset,
                        attention=opt.attention_module,
                        group_size=opt.group_size)

    elif backbone_network == 'WideResNet':
        from models import WideResNet

        model = WideResNet(n_layers=n_layers,
                           widening_factor=opt.widening_factor,
                           dataset=opt.dataset,
                           attention=opt.attention_module,
                           group_size=opt.group_size)

    model = nn.DataParallel(model).to(device)

    criterion = nn.CrossEntropyLoss()

    if dataset_name in ['CIFAR10', 'CIFAR100']:
        optim = torch.optim.SGD(model.parameters(),
                                lr=opt.lr,
                                momentum=opt.momentum,
                                weight_decay=opt.weight_decay)

        milestones = [150, 225]
Example #14
    parser.add_argument('--perturb_steps', type=int, default=20,
                    help='iterations for pgd attack (default pgd20)')
    parser.add_argument('--model_name', type=str, default="")
    parser.add_argument('--model_path', type=str, default="./models/weights/model-wideres-pgdHE-wide10.pt")
    parser.add_argument('--gpu_id', type=str, default="0")
    return parser.parse_args()


if __name__=='__main__':
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id   #多卡机设置使用的gpu卡号
    gpu_num = max(len(args.gpu_id.split(',')), 1)
    device = torch.device('cuda')
    if args.model_name!="":
        model = get_model_for_attack(args.model_name).to(device)   # pick the model to attack based on model_name
        model = nn.DataParallel(model, device_ids=[i for i in range(gpu_num)])
    else:
        # defense task: change to your model here
        model = WideResNet()
        model.load_state_dict(torch.load('models/weights/wideres34-10-pgdHE.pt'))
        model = nn.DataParallel(model, device_ids=[i for i in range(gpu_num)])
    # attack task: change to your attack function here
    # here is an attack baseline: PGD attack
    attack = PGDAttack(args.step_size, args.epsilon, args.perturb_steps)
    model.eval()
    test_loader = get_test_cifar(args.batch_size)
    natural_acc, robust_acc, distance = eval_model_with_attack(model, test_loader, attack, args.epsilon, device)
    print(f"Natural Acc: {natural_acc:.5f}, Robust acc: {robust_acc:.5f}, distance:{distance:.5f}")
Example #15
def main(args):
    # writer = SummaryWriter('./runs/CIFAR_100_exp')
     
    train_transform = transforms.Compose([transforms.Pad(4, padding_mode='reflect'),
                                          transforms.RandomRotation(15),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomCrop(32),
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.5071, 0.4867, 0.4408),(0.2675,0.2565,0.2761))])
    
    test_transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize((0.5071, 0.4867, 0.4408),(0.2675,0.2565,0.2761))])
    
    train_dataset = datasets.CIFAR100('./dataset', train=True, transform=train_transform, download=True)
    test_dataset = datasets.CIFAR100('./dataset', train=False, transform=test_transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    
    Teacher = WideResNet(depth=args.teacher_depth, num_classes=100, widen_factor=args.teacher_width_factor, drop_rate=0.3)
    Teacher.cuda()
    Teacher.eval()
    
    teacher_weight_path = path.join(args.teacher_root_path, 'model_best.pth.tar')
    t_load = torch.load(teacher_weight_path)['state_dict']
    Teacher.load_state_dict(t_load)
    
    Student = WideResNet(depth=args.student_depth, num_classes=100, widen_factor=args.student_width_factor, drop_rate=0.0)
    Student.cuda()
    
    cudnn.benchmark = True
    
    optimizer = torch.optim.SGD(Student.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
    opt_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=2e-1)
    
    criterion = nn.CrossEntropyLoss()
    
    best_acc = 0
    best_acc5 = 0
    best_flag = False
    
    for epoch in range(args.total_epochs):
        for iter_, data in enumerate(train_loader):
            images, labels = data
            images, labels = images.cuda(), labels.cuda()
            t_outs, *t_acts = Teacher(images)
            s_outs, *s_acts = Student(images)
            
            cls_loss = criterion(s_outs, labels)
            
            """
            statistical matching and AdaIN losses
            """
            
            if args.aux_flag==0:
                aux_loss_1 = SM_Loss(t_acts[2], s_acts[2]) # group conv2
            else:
                aux_loss_1 = 0
                for i in range(3):
                    aux_loss_1 += SM_Loss(t_acts[i], s_acts[i])
                    
            F_hat = AdaIN(t_acts[2], s_acts[2])
            interim_out_q = Teacher.bn1(F_hat)
            interim_out_q = Teacher.relu(interim_out_q)
            interim_out_q = F.avg_pool2d(interim_out_q, 8)
            interim_out_q = interim_out_q.view(-1, Teacher.last_ch)
            q = Teacher.fc(interim_out_q)
            
            aux_loss_2 = torch.mean(torch.pow(t_outs-q, 2))
            
            total_loss = cls_loss + aux_loss_1 + aux_loss_2
            
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
    
        top1, top5 = evaluator(test_loader, Student)
        
        if top1 > best_acc:
            best_acc = top1
            best_acc5 = top5
            best_flag = True    
        if best_flag:
            state = {'epoch':epoch+1, 'state_dict':Student.state_dict(), 'optimizer': optimizer.state_dict()}       
            save_ckpt(state, is_best=best_flag, root_path = args.student_weight_path)
            best_flag = False
        
        opt_scheduler.step()
        
        # writer.add_scalar('acc/top1', top1, epoch)
        # writer.add_scalar('acc/top5', top5, epoch)
        # writer.close()
        
        
    print("Best top 1 acc: {}".format(best_acc))
    print("Best top 5 acc: {}".format(best_acc5))    
Example #16
def load_paper_settings(args):

    WRN_path = os.path.join(args.data_path, 'WRN28-4_21.09.pt')
    Pyramid_path = os.path.join(args.data_path, 'pyramid200_mixup_15.6.tar')

    if args.paper_setting == 'a':
        teacher = WRN.WideResNet(depth=28, widen_factor=4, num_classes=100)
        state = torch.load(WRN_path, map_location={'cuda:0': 'cpu'})['model']
        teacher.load_state_dict(state)
        student = WRN.WideResNet(depth=16, widen_factor=4, num_classes=100)

    elif args.paper_setting == 'b':
        teacher = WRN.WideResNet(depth=28, widen_factor=4, num_classes=100)
        state = torch.load(WRN_path, map_location={'cuda:0': 'cpu'})['model']
        teacher.load_state_dict(state)
        student = WRN.WideResNet(depth=28, widen_factor=2, num_classes=100)

    elif args.paper_setting == 'c':
        teacher = WRN.WideResNet(depth=28, widen_factor=4, num_classes=100)
        state = torch.load(WRN_path, map_location={'cuda:0': 'cpu'})['model']
        teacher.load_state_dict(state)
        student = WRN.WideResNet(depth=16, widen_factor=2, num_classes=100)

    elif args.paper_setting == 'd':
        teacher = WRN.WideResNet(depth=28, widen_factor=4, num_classes=100)
        state = torch.load(WRN_path, map_location={'cuda:0': 'cpu'})['model']
        teacher.load_state_dict(state)
        student = RN.ResNet(depth=56, num_classes=100)

    elif args.paper_setting == 'e':
        teacher = PYN.PyramidNet(depth=200,
                                 alpha=240,
                                 num_classes=100,
                                 bottleneck=True)
        state = torch.load(Pyramid_path, map_location={'cuda:0':
                                                       'cpu'})['state_dict']
        from collections import OrderedDict
        new_state = OrderedDict()
        for k, v in state.items():
            name = k[7:]  # remove 'module.' of dataparallel
            new_state[name] = v
        teacher.load_state_dict(new_state)
        student = WRN.WideResNet(depth=28, widen_factor=4, num_classes=100)

    elif args.paper_setting == 'f':
        teacher = PYN.PyramidNet(depth=200,
                                 alpha=240,
                                 num_classes=100,
                                 bottleneck=True)
        state = torch.load(Pyramid_path, map_location={'cuda:0':
                                                       'cpu'})['state_dict']
        from collections import OrderedDict
        new_state = OrderedDict()
        for k, v in state.items():
            name = k[7:]  # remove 'module.' of dataparallel
            new_state[name] = v
        teacher.load_state_dict(new_state)
        student = PYN.PyramidNet(depth=110,
                                 alpha=84,
                                 num_classes=100,
                                 bottleneck=False)

    else:
        # raise instead of printing, since falling through would hit a NameError
        raise ValueError(f'undefined paper setting: {args.paper_setting}')

    return teacher, student, args
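A hypothetical invocation of load_paper_settings; the data path is a placeholder, and paper_setting selects one of the teacher/student pairs above:

from argparse import Namespace

args = Namespace(data_path='./pretrained', paper_setting='a')  # placeholder path
teacher, student, args = load_paper_settings(args)
teacher.eval()  # the teacher is typically kept frozen for distillation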
Example #17
parser.add_argument('--regu',
                    type=str,
                    default='no',
                    help='type of regularization. Possible values are: '
                         'no: no regularization; '
                         'random-svd: employ random-svd in regularization')

if __name__ == "__main__":
    args = parser.parse_args()
    # create model
    n_classes = 10 if args.dataset == 'cifar10' else 100
    if args.model == 'resnet':
        net = resnet110(num_classes=n_classes)
    elif args.model == 'wideresnet':
        net = WideResNet(depth=28,
                         widen_factor=10,
                         dropRate=0.3,
                         num_classes=n_classes)
    elif args.model == 'resnext':
        net = CifarResNeXt(cardinality=8,
                           depth=29,
                           base_width=64,
                           widen_factor=4,
                           nlabels=n_classes)
    else:
        raise Exception('Invalid model name')
    # create optimizer
    optimizer = torch.optim.SGD(net.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)
Example #18
def main():
    args = parser.parse_args()
    args.best_top1 = 0.
    args.best_top5 = 0.

    if args.local_rank != -1:
        args.gpu = args.local_rank
        torch.distributed.init_process_group(backend='nccl')
        args.world_size = torch.distributed.get_world_size()
    else:
        args.gpu = 0
        args.world_size = 1

    args.device = torch.device('cuda', args.gpu)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARNING)

    logger.warning(f"Process rank: {args.local_rank}, "
                   f"device: {args.device}, "
                   f"distributed training: {bool(args.local_rank != -1)}, "
                   f"16-bits training: {args.amp}")

    logger.info(dict(args._get_kwargs()))

    if args.local_rank in [-1, 0]:
        args.writer = SummaryWriter(f"results/{args.name}")

    if args.seed is not None:
        set_seed(args)

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[
        args.dataset](args)

    if args.local_rank == 0:
        torch.distributed.barrier()

    train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
    labeled_loader = DataLoader(labeled_dataset,
                                sampler=train_sampler(labeled_dataset),
                                batch_size=args.batch_size,
                                num_workers=args.workers,
                                drop_last=True)

    unlabeled_loader = DataLoader(unlabeled_dataset,
                                  sampler=train_sampler(unlabeled_dataset),
                                  batch_size=args.batch_size * args.mu,
                                  num_workers=args.workers,
                                  drop_last=True)

    test_loader = DataLoader(test_dataset,
                             sampler=SequentialSampler(test_dataset),
                             batch_size=args.batch_size,
                             num_workers=args.workers)

    if args.dataset == "cifar10":
        depth, widen_factor = 28, 2
    elif args.dataset == 'cifar100':
        depth, widen_factor = 28, 8

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    # test dropout
    teacher_model = WideResNet(num_classes=args.num_classes,
                               depth=depth,
                               widen_factor=widen_factor,
                               dropout=0,
                               dense_dropout=args.dense_dropout)
    student_model = WideResNet(num_classes=args.num_classes,
                               depth=depth,
                               widen_factor=widen_factor,
                               dropout=0,
                               dense_dropout=args.dense_dropout)

    if args.local_rank == 0:
        torch.distributed.barrier()

    teacher_model.to(args.device)
    student_model.to(args.device)
    avg_student_model = None
    if args.ema > 0:
        avg_student_model = ModelEMA(student_model, args.ema)

    criterion = create_loss_fn(args)

    no_decay = ['bn']
    teacher_parameters = [{
        'params': [
            p for n, p in teacher_model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in teacher_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    student_parameters = [{
        'params': [
            p for n, p in student_model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in student_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    t_optimizer = optim.SGD(
        teacher_parameters,
        lr=args.lr,
        momentum=args.momentum,
        # weight_decay=args.weight_decay,
        nesterov=args.nesterov)
    s_optimizer = optim.SGD(
        student_parameters,
        lr=args.lr,
        momentum=args.momentum,
        # weight_decay=args.weight_decay,
        nesterov=args.nesterov)

    t_scheduler = get_cosine_schedule_with_warmup(t_optimizer,
                                                  args.warmup_steps,
                                                  args.total_steps)
    s_scheduler = get_cosine_schedule_with_warmup(s_optimizer,
                                                  args.warmup_steps,
                                                  args.total_steps,
                                                  args.student_wait_steps)

    t_scaler = amp.GradScaler(enabled=args.amp)
    s_scaler = amp.GradScaler(enabled=args.amp)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info(f"=> loading checkpoint '{args.resume}'")
            loc = f'cuda:{args.gpu}'
            checkpoint = torch.load(args.resume, map_location=loc)
            args.best_top1 = checkpoint['best_top1'].to(torch.device('cpu'))
            args.best_top5 = checkpoint['best_top5'].to(torch.device('cpu'))
            if not (args.evaluate or args.finetune):
                args.start_step = checkpoint['step']
                t_optimizer.load_state_dict(checkpoint['teacher_optimizer'])
                s_optimizer.load_state_dict(checkpoint['student_optimizer'])
                t_scheduler.load_state_dict(checkpoint['teacher_scheduler'])
                s_scheduler.load_state_dict(checkpoint['student_scheduler'])
                t_scaler.load_state_dict(checkpoint['teacher_scaler'])
                s_scaler.load_state_dict(checkpoint['student_scaler'])
                model_load_state_dict(teacher_model,
                                      checkpoint['teacher_state_dict'])
                if avg_student_model is not None:
                    model_load_state_dict(avg_student_model,
                                          checkpoint['avg_state_dict'])

            else:
                if checkpoint['avg_state_dict'] is not None:
                    model_load_state_dict(student_model,
                                          checkpoint['avg_state_dict'])
                else:
                    model_load_state_dict(student_model,
                                          checkpoint['student_state_dict'])

            logger.info(
                f"=> loaded checkpoint '{args.resume}' (step {checkpoint['step']})"
            )
        else:
            logger.info(f"=> no checkpoint found at '{args.resume}'")

    if args.local_rank != -1:
        teacher_model = nn.parallel.DistributedDataParallel(
            teacher_model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
        student_model = nn.parallel.DistributedDataParallel(
            student_model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    if args.finetune:
        del t_scaler, t_scheduler, t_optimizer, teacher_model, unlabeled_loader
        del s_scaler, s_scheduler, s_optimizer
        finetune(args, labeled_loader, test_loader, student_model, criterion)
        return

    if args.evaluate:
        del t_scaler, t_scheduler, t_optimizer, teacher_model, unlabeled_loader, labeled_loader
        del s_scaler, s_scheduler, s_optimizer
        evaluate(args, test_loader, student_model, criterion)
        return

    teacher_model.zero_grad()
    student_model.zero_grad()
    train_loop(args, labeled_loader, unlabeled_loader, test_loader,
               teacher_model, student_model, avg_student_model, criterion,
               t_optimizer, s_optimizer, t_scheduler, s_scheduler, t_scaler,
               s_scaler)
    return
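get_cosine_schedule_with_warmup is a project helper, not a library call; a minimal sketch of the usual LambdaLR recipe it likely follows, with the signature guessed from the call sites above (an optional wait period, linear warmup, then cosine decay):

import math

from torch.optim.lr_scheduler import LambdaLR


def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps,
                                    wait_steps=0, num_cycles=0.5):
    def lr_lambda(step):
        if step < wait_steps:
            return 0.0  # lr stays at zero while the student waits
        if step < wait_steps + warmup_steps:
            return (step - wait_steps) / max(1.0, warmup_steps)  # linear warmup
        # progress runs from 0 to 1 over the remaining steps
        progress = (step - wait_steps - warmup_steps) / max(
            1.0, total_steps - wait_steps - warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))
    return LambdaLR(optimizer, lr_lambda)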
Example #19
def main():
    global args, best_prec1
    args = parser.parse_args()

    # ensuring reproducibility
    SEED = 42
    torch.manual_seed(SEED)
    torch.backends.cudnn.benchmark = False

    kwargs = {'num_workers': 1, 'pin_memory': True}
    device = torch.device("cuda")

    num_epochs_transient = 2
    num_epochs_steady = 7
    perc_to_remove = 10

    torch.manual_seed(SEED)

    # create model
    model = WideResNet(args.layers,
                       10,
                       args.widen_factor,
                       dropRate=args.droprate).to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.learning_rate,
                                 weight_decay=args.weight_decay)

    # instantiate loaders
    train_loader = get_data_loader(args.data_dir, args.batch_size, **kwargs)
    test_loader = get_test_loader(args.data_dir, 128, **kwargs)

    tic = time.time()
    seen_losses = None
    for epoch in range(1, 3):
        if epoch == 1:
            seen_losses = train_transient(model,
                                          device,
                                          train_loader,
                                          optimizer,
                                          epoch,
                                          track=True)
        else:
            train_transient(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader, epoch)

    for epoch in range(3, 4):
        seen_losses = [v for sublist in seen_losses for v in sublist]
        sorted_loss_idx = sorted(range(len(seen_losses)),
                                 key=lambda k: seen_losses[k][1],
                                 reverse=True)
        removed = sorted_loss_idx[-int((perc_to_remove / 100) *
                                       len(sorted_loss_idx)):]
        sorted_loss_idx = sorted_loss_idx[:-int((perc_to_remove / 100) *
                                                len(sorted_loss_idx))]
        to_add = list(
            np.random.choice(removed,
                             int(0.33 * len(sorted_loss_idx)),
                             replace=False))
        sorted_loss_idx = sorted_loss_idx + to_add
        sorted_loss_idx.sort()
        weights = [seen_losses[idx][1] for idx in sorted_loss_idx]
        train_loader = get_weighted_loader(args.data_dir, 64 * 2, weights,
                                           **kwargs)
        seen_losses = train_steady_state(model, device, train_loader,
                                         optimizer, epoch)
        test(model, device, test_loader, epoch)

    for epoch in range(4, 8):
        train_transient(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader, epoch)
    toc = time.time()
    print("Time Elapsed: {}s".format(toc - tic))