Beispiel #1
0
def main(argv):
    """Build a WideResNet inside a fresh TF graph and train it.

    Args:
        argv: Positional command-line arguments; ignored, because every
            setting is read from ``FLAGS``.
    """
    del argv  # Unused: configuration comes from FLAGS.
    graph = tf.Graph()
    with graph.as_default():
        net = WideResNet(
            hw=FLAGS.hw,
            n_filters=FLAGS.n_filters,
            repeat=FLAGS.repeat,
            n_classes=FLAGS.n_classes,
        )
        train(net.ops_dict(FLAGS.wd))
    return
    def load_model(self, is_cuda, load_dir=None, load_name=None, mode=None):
        """Create a WideResNet (depth 28, widen factor 10, dropRate 0.0, 10 classes).

        Args:
            is_cuda: When truthy, wrap the network in ``torch.nn.DataParallel``
                and move it to the GPU.
            load_dir: Checkpoint directory passed to ``self.load_checkpoint``.
            load_name: Checkpoint name passed to ``self.load_checkpoint``.
            mode: Weights are restored only when this equals ``TEST`` and
                both ``load_dir`` and ``load_name`` are truthy.

        Returns:
            The (possibly wrapped and/or restored) network.
        """
        net = WideResNet(depth=28, num_classes=10, widen_factor=10, dropRate=0.0)

        if is_cuda:
            net = torch.nn.DataParallel(net).cuda()
            print(">>> SENDING MODEL TO GPU...")

        # Short-circuits exactly like the original chained `and`.
        wants_restore = bool(load_dir and load_name) and mode == TEST
        if wants_restore:
            net = self.load_checkpoint(net, load_dir, load_name)
            print(">>> LOADING PRE-TRAINED MODEL...")

        return net
def getNetwork(args):
    """Instantiate the network selected by ``args.net_type``.

    Supported values are ``'resnet'`` and ``'wide-resnet'``. ``num_classes``
    is read from the enclosing module scope.

    Args:
        args: Parsed options. Reads ``net_type`` and ``depth``; for
            ``'wide-resnet'`` also ``widen_factor`` and ``dropout``.

    Returns:
        ``(net, file_name)``, where ``file_name`` encodes the architecture,
        e.g. ``'resnet-50'`` or ``'wide-resnet-28x10'``.

    Exits the process with status 1 for any other ``net_type``.
    """
    if args.net_type == 'resnet':
        net = ResNet(args.depth, num_classes)
        file_name = 'resnet-' + str(args.depth)
    elif args.net_type == 'wide-resnet':
        net = WideResNet(args.depth, args.widen_factor, args.dropout,
                         num_classes)
        file_name = 'wide-resnet-' + str(args.depth) + 'x' + str(
            args.widen_factor)
    else:
        # 'densenet' previously hit a bare `pass` and then crashed with
        # UnboundLocalError at the return; it has no implementation here, so
        # it is reported as unsupported. Exit non-zero so callers/shells see
        # the failure (status 0 means success).
        print('Error : Network should be either [ResNet / Wide_ResNet]')
        sys.exit(1)

    return net, file_name
Beispiel #4
0
def save_tflite():
    """Convert the WideResNet Keras model to a TFLite file on disk.

    The model is built on a fixed ``(1, 64, 64, 3)`` input, loads weights from
    ``weights.28-3.73.hdf5``, and is converted with default optimisations.
    When ``FLAGS.mode == 'full'`` the converter also:
      - enables select TF ops and custom ops;
      - uses ``uint8`` inference input/output types;
      - is given ``representative_data_gen`` as its representative dataset.

    Returns:
        Path of the written ``.tflite`` file.
    """
    output = 'agender.tflite'
    input_layer = tf.keras.layers.Input(batch_size=1, shape=[64, 64, 3])
    agender = WideResNet(input_layer)
    model = tf.keras.Model(input_layer, agender)
    model.load_weights('weights.28-3.73.hdf5')

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if FLAGS.mode == 'full':
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
        ]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.uint8
        converter.allow_custom_ops = True
        converter.representative_dataset = representative_data_gen
    # NOTE(review): this is a lexicographic string comparison, so for
    # example '2.10.0' >= '2.2.0' is False, and TF >= 2.10 keeps the new
    # converter. It is left unchanged because forcing the legacy converter on
    # recent TF releases may break conversion; confirm intent before changing.
    if tf.__version__ >= '2.2.0':
        converter.experimental_new_converter = False
    tflite_model = converter.convert()
    # The context manager guarantees the file is flushed and closed.
    with open(output, 'wb') as fh:
        fh.write(tflite_model)
    logging.info("tflite model is saved at {}".format(output))
    return output
def main():
    """Adversarially train a WideResNet (depth 28, widen factor 2) on CIFAR-10.

    The parser defines no arguments, so every hyper-parameter below is
    hard-coded onto the parsed namespace. The classifier is wrapped in
    ``AttackerModel`` and ``DataParallel``, moved to CUDA, and passed with
    the loaders, SGD optimizer and MultiStepLR schedule to ``train_model``.
    """
    parser = argparse.ArgumentParser(
        description='PGD based adversarial training')
    args = parser.parse_args()

    # Model options
    args.adv_train = True

    # Training options
    args.dataset = 'cifar10'
    args.batch_size = 128
    args.max_epoch = 200
    args.lr = 0.1
    args.lr_step = 0.1
    args.lr_milestones = [100, 150]
    args.log_gap = 5

    # Attack options
    args.random_start = True
    args.step_size = 2.0 / 255
    args.epsilon = 8.0 / 255
    args.num_steps = 7
    args.targeted = False

    # Miscellaneous
    args.data_path = '~/datasets/CIFAR10'
    args.result_path = './results/classifier'
    args.tensorboard_path = './results/classifier/tensorboard/train'
    args.model_save_path = osp.join(args.result_path, 'model.latest')
    args.model_best_path = osp.join(args.result_path, 'model.best')

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.result_path, exist_ok=True)

    pprint(vars(args))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, 4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_set = datasets.CIFAR10(root=args.data_path,
                                 train=True,
                                 download=True,
                                 transform=transform_train)
    val_set = datasets.CIFAR10(root=args.data_path,
                               train=False,
                               download=True,
                               transform=transform_test)

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=2)
    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=2)

    classifier = WideResNet(depth=28, num_classes=10, widen_factor=2)
    model = AttackerModel(classifier, vars(args))
    model = torch.nn.DataParallel(model)
    model = model.cuda()

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=2e-4)
    schedule = optim.lr_scheduler.MultiStepLR(optimizer,
                                              milestones=args.lr_milestones,
                                              gamma=args.lr_step)

    writer = SummaryWriter(args.tensorboard_path)
    try:
        train_model(args, train_loader, val_loader, model, optimizer,
                    schedule, writer)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()
def main():
    """Train a WideResNet (depth 40, widen factor 2, dropRate 0.3) on CIFAR-10.

    All settings come from the module-level ``args``. Uses SGD with cosine
    annealing. Each epoch trains, steps the scheduler, validates, and prints
    the validation metrics, the learning rate used for that epoch, and the
    best top-1 accuracy so far.
    """
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        device = torch.device("cuda")
    else:
        device = torch.device('cpu')

    criterion = nn.CrossEntropyLoss().to(device)

    model = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.3)
    model = model.to(device)
    summary(model, (3, 32, 32))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        float(args.epochs),
        eta_min=args.learning_rate_min,
        last_epoch=-1)

    train_transform, valid_transform = data_transforms_cifar(args)
    trainset = dset.CIFAR10(root=args.data_dir,
                            train=True,
                            download=False,
                            transform=train_transform)
    train_queue = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=8)
    valset = dset.CIFAR10(root=args.data_dir,
                          train=False,
                          download=False,
                          transform=valid_transform)
    valid_queue = torch.utils.data.DataLoader(valset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=8)

    best_acc = 0.0
    for epoch in range(args.epochs):
        t1 = time.time()

        # train
        train(args,
              epoch,
              train_queue,
              device,
              model,
              criterion=criterion,
              optimizer=optimizer)
        # `get_lr()` is only meant to be called inside `step()`. For
        # CosineAnnealingLR it returns the next recursive value, not the LR
        # in use, so read the LR this epoch actually trained with.
        lr = scheduler.get_last_lr()[0]
        scheduler.step()

        # validate
        val_top1, val_top5, val_obj = validate(val_data=valid_queue,
                                               device=device,
                                               model=model)
        if val_top1 > best_acc:
            best_acc = val_top1
        t2 = time.time()

        print(
            '\nval: loss={:.6}, top1={:.6}, top5={:.6}, lr: {:.8}, time: {:.4}'
            .format(val_obj, val_top1, val_top5, lr, t2 - t1))
        print('Best Top1 Acc: {:.6}'.format(best_acc))
                       batch_size=args.batch_size,
                       shuffle=False,
                       num_workers=4)

# Take only the first batch from `te_loader` and pass it through
# `tensor2cuda` (presumably moves tensors to the GPU — confirm).
# NOTE(review): if `te_loader` yields nothing, `data`/`label` are never bound
# and the `attack.perturb(...)` call below fails with NameError.
for data, label in te_loader:

    data, label = tensor2cuda(data), tensor2cuda(label)

    break

# Result accumulators; not used in the visible part of this script.
adv_list = []
pred_list = []

# NOTE(review): gradient-based attacks need autograd. This only works if
# `FastGradientSignUntargeted.perturb` re-enables gradients internally
# (e.g. via `torch.enable_grad()`); confirm.
with torch.no_grad():

    # WideResNet: depth 34, widen factor 10, dropRate 0.0, 10 classes.
    model = WideResNet(depth=34, num_classes=10, widen_factor=10, dropRate=0.0)

    # Restore weights from the checkpoint path given on the command line.
    load_model(model, args.load_checkpoint)

    if torch.cuda.is_available():
        model.cuda()

    # NOTE(review): `model.eval()` is never called here. Unless `load_model`
    # does it, any BatchNorm/Dropout layers run in training mode during the
    # attack.
    attack = FastGradientSignUntargeted(model,
                                        max_epsilon,
                                        args.alpha,
                                        min_val=0,
                                        max_val=1,
                                        max_iters=args.k,
                                        _type=perturbation_type)

    # The meaning of the 'mean' and False arguments is defined by `perturb`
    # (not visible here).
    adv_data = attack.perturb(data, label, 'mean', False)
Beispiel #8
0
def main(args):
    """Search augmentation sub-policies with a learned controller.

    For each of ``args.search_epochs`` rounds:
      1. sample a policy dict from the controller;
      2. build data loaders with that policy;
      3. train a fresh WideResNet and measure validation accuracy;
      4. update the controller from that accuracy;
      5. save the controller state and the sampled policy to
         ``./models/<epoch>.pt.tar``.
    """
    import os

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)

    # torch.save does not create directories. Create the output directory up
    # front so a missing folder fails fast, not after a full round of
    # CNN training.
    os.makedirs('./models', exist_ok=True)

    # controller
    controller = Controller(args).to(device)
    controller_optimizer = torch.optim.SGD(controller.parameters(),
                                           args.controller_lr,
                                           momentum=0.9)
    # NOTE(review): `baseline` is never reassigned in this function, so
    # train_controller always receives None. If it computes a moving-average
    # baseline, it presumably needs to return it so it can be captured here.
    # Confirm.
    baseline = None

    # search
    for epoch in range(args.search_epochs):
        print('-' * 50)
        print('{} th search'.format(epoch + 1))
        print('-' * 50)

        # sample subpolicy
        print('*' * 30)
        print('sample subpolicy')
        print('*' * 30)
        controller.eval()
        policy_dict = controller.sample()
        policy_provider = Policy(args, policy_dict)
        for p in policy_dict:
            print(p)

        # get dataset
        train_queue, len_train, valid_queue, len_val, train_transform = get_data_loader(
            args, policy_provider)

        # train cnn
        print('*' * 30)
        print('train cnn')
        print('*' * 30)
        model = WideResNet(depth=args.layers,
                           num_classes=10,
                           widen_factor=args.widening_factor,
                           dropRate=args.dropout).to(device)
        val_acc = train_cnn(args, model, device, train_queue, len_train,
                            valid_queue, len_val)

        # train controller
        print('*' * 30)
        print('train controller')
        print('*' * 30)
        train_controller(args, controller, controller_optimizer, val_acc,
                         baseline)

        # save
        state = {
            'args': args,
            'best_acc': val_acc,
            'controller_state': controller.state_dict(),
            'policy_dict': policy_dict
        }
        torch.save(state, './models/{}.pt.tar'.format(epoch))
Beispiel #9
0
def main2(args):
    """Train or evaluate a WideResNet/LRNet variant; return the best precision.

    Pipeline:
      1. Build CIFAR, CINIC-10 or converted-ImageNet loaders.
      2. Construct the model selected by ``args.arch`` and optionally freeze
         blocks.
      3. Restore weights, either from a pruning source run or from
         ``args.resume``, and re-apply the pruning mask when ``args.prune``
         is set.
      4. Validate once (``args.eval``) or run the training loop.

    Returns:
        - The best validation precision seen.
        - ``None`` when setup fails: pruning failure, unsupported
          architecture or optimizer, or a missing checkpoint.
        - The current epoch index when ``args.symmetry_break`` stops training
          on the fourth epoch with ``prec1 > 50.0``. The counter is never
          reset, so these epochs need not be consecutive.
    """
    best_prec1 = 0.0

    torch.backends.cudnn.deterministic = not args.cudaNoise

    torch.manual_seed(time.time())

    if args.init != "None":
        args.name = "lrnet_%s" % args.init

    if args.tensorboard:
        configure(f"runs/{args.name}")

    dstype = nondigits(args.dataset)
    # Validate before the dataset type is used to choose normalisation stats.
    # Previously this check ran after `normalize` and so was unreachable for
    # unknown datasets.
    assert dstype in ["cifar", "cinic", "imgnet"]

    cinic_mean = [0.47889522, 0.47227842, 0.43047404]
    cinic_std = [0.24205776, 0.23828046, 0.25874835]

    # Per-channel mean/std on the 0-255 scale; rescaled into `normalize`.
    if dstype == "cifar":
        means = [125.3, 123.0, 113.9]
        stds = [63.0, 62.1, 66.7]
    elif dstype == "cinic":
        # The CINIC loaders below normalise with cinic_mean/cinic_std
        # directly. Defining these only keeps `normalize` from raising
        # NameError for CINIC, as it did before.
        means = [x * 255.0 for x in cinic_mean]
        stds = [x * 255.0 for x in cinic_std]
    else:  # imgnet
        means = [123.3, 118.1, 108.0]
        stds = [54.1, 52.6, 53.2]

    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in means],
        std=[x / 255.0 for x in stds],
    )

    args.classes = onlydigits(args.dataset)

    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                              mode="reflect").squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose(
            [transforms.ToTensor(), normalize])

    if args.cutout:
        transform_train.transforms.append(
            Cutout(n_holes=args.n_holes, length=args.length))

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {"num_workers": 1, "pin_memory": True}

    if dstype == "cifar":
        train_loader = torch.utils.data.DataLoader(
            datasets.__dict__[args.dataset.upper()]("../data",
                                                    train=True,
                                                    download=True,
                                                    transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        val_loader = torch.utils.data.DataLoader(
            datasets.__dict__[args.dataset.upper()]("../data",
                                                    train=False,
                                                    transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    elif dstype == "cinic":
        cinic_directory = "%s/cinic10" % args.dir
        cinic_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cinic_mean, std=cinic_std)
        ])
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(cinic_directory + '/train',
                                             transform=cinic_transform),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        print("Using CINIC10 dataset")
        val_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(cinic_directory + '/valid',
                                             transform=cinic_transform),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    else:  # imgnet
        print("Using converted imagenet")
        train_loader = torch.utils.data.DataLoader(
            IMGNET("%s" % args.dir,
                   train=True,
                   transform=transform_train,
                   target_transform=None,
                   classes=args.classes),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        val_loader = torch.utils.data.DataLoader(
            IMGNET("%s" % args.dir,
                   train=False,
                   transform=transform_test,
                   target_transform=None,
                   classes=args.classes),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )

    if args.prune:
        pruner_state = getPruneMask(args)
        if pruner_state is None:
            print("Failed to prune network, aborting")
            return None

    if args.arch.lower() == "constnet":
        model = WideResNet(
            args.layers,
            args.classes,
            args.widen_factor,
            droprate=args.droprate,
            use_bn=args.batchnorm,
            use_fixup=args.fixup,
            varnet=args.varnet,
            noise=args.noise,
            lrelu=args.lrelu,
            sigmaW=args.sigmaW,
            init=args.init,
            dropl1=args.dropl1,
        )
    elif args.arch.lower() == "leakynet":
        model = LRNet(
            args.layers,
            args.classes,
            args.widen_factor,
            droprate=args.droprate,
            use_bn=args.batchnorm,
            use_fixup=args.fixup,
            varnet=args.varnet,
            noise=args.noise,
            lrelu=args.lrelu,
            sigmaW=args.sigmaW,
            init=args.init,
        )
    else:
        print("arch %s is not supported" % args.arch)
        return None

    param_num = sum([p.data.nelement() for p in model.parameters()])

    print(f"Number of model parameters: {param_num}")

    if torch.cuda.device_count() > 1:
        # Characters 0 and 2 of args.device are read as the first and last
        # GPU index (a "start?end" string, e.g. "0-3" presumably).
        start = int(args.device[0])
        end = int(args.device[2]) + 1
        torch.cuda.set_device(start)
        dev_list = []
        for i in range(start, end):
            dev_list.append("cuda:%d" % i)
        model = torch.nn.DataParallel(model, device_ids=dev_list)

    model = model.cuda()

    if args.freeze > 0:
        # Freeze parameters from the `freeze_start`-th 'scale' marker up to,
        # but excluding, the `freeze`-th marker. conv_res/fc are never frozen.
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['scale'], name.split('.')):
                cnt = cnt + 1
                if cnt == args.freeze:
                    break

            if cnt >= args.freeze_start:
                if not intersection(['conv_res', 'fc'], name.split('.')):
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    elif args.freeze < 0:
        # Negative values freeze counting back from the end of the network.
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['scale'], name.split('.')):
                cnt = cnt + 1

            if cnt > args.layers - 3 + args.freeze - 1:
                if not intersection(['conv_res', 'fc'], name.split('.')):
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    if args.res_freeze > 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['conv_res'], name.split('.')):
                cnt = cnt + 1
                if cnt > args.res_freeze_start:
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)
                if cnt >= args.res_freeze:
                    break
    elif args.res_freeze < 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['conv_res'], name.split('.')):
                cnt = cnt + 1
                if cnt > 3 + args.res_freeze:
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    if args.prune:
        if args.prune_epoch >= 100:
            weightsFile = "runs/%s-net/checkpoint.pth.tar" % args.prune
        else:
            weightsFile = "runs/%s-net/model_epoch_%d.pth.tar" % (
                args.prune, args.prune_epoch)

        if os.path.isfile(weightsFile):
            print(f"=> loading checkpoint {weightsFile}")
            checkpoint = torch.load(weightsFile)
            model.load_state_dict(checkpoint["state_dict"])
            print(
                f"=> loaded checkpoint '{weightsFile}' (epoch {checkpoint['epoch']})"
            )
        else:
            if args.prune_epoch == 0:
                print(f"=> No source data, Restarting network from scratch")
            else:
                print(f"=> no checkpoint found at {weightsFile}, aborting...")
                return None

    else:
        if args.resume:
            tarfile = "runs/%s-net/checkpoint.pth.tar" % args.resume
            if os.path.isfile(tarfile):
                print(f"=> loading checkpoint {args.resume}")
                checkpoint = torch.load(tarfile)
                args.start_epoch = checkpoint["epoch"]
                best_prec1 = checkpoint["best_prec1"]
                model.load_state_dict(checkpoint["state_dict"])
                print(
                    f"=> loaded checkpoint '{tarfile}' (epoch {checkpoint['epoch']})"
                )
            else:
                print(f"=> no checkpoint found at {tarfile}, aborting...")
                return None

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().cuda()

    if args.optimizer.lower() == 'sgd':
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            nesterov=args.nesterov,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer.lower() == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay)
    else:
        # Previously fell through and failed later with NameError on
        # `optimizer`; report it like an unsupported arch instead.
        print("optimizer %s is not supported" % args.optimizer)
        return None

    # `pruner_state` only exists when args.prune is set; `and` short-circuits.
    if args.prune and pruner_state is not None:
        cutoff_retrain = prunhild.cutoff.LocalRatioCutoff(args.cutoff)
        params_retrain = get_params_for_pruning(args, model)
        pruner_retrain = prunhild.pruner.CutoffPruner(params_retrain,
                                                      cutoff_retrain)
        pruner_retrain.load_state_dict(pruner_state)
        pruner_retrain.prune(update_state=False)
        pruned_weights_count = count_pruned_weights(params_retrain,
                                                    args.cutoff)
        params_left = param_num - pruned_weights_count
        print("Pruned %d weights, New model size:  %d/%d (%d%%)" %
              (pruned_weights_count, params_left, param_num,
               int(100 * params_left / param_num)))

    else:
        pruner_retrain = None

    # Create the writer only once setup has succeeded, so the early
    # `return None` paths above no longer leak it. Because `log_dir` is given,
    # `comment` has no effect on the directory name.
    writer = SummaryWriter(log_dir="runs/%s" % args.name, comment=str(args))
    try:
        if args.eval:
            best_prec1 = validate(args, val_loader, model, criterion, 0, None)
        else:

            if args.varnet:
                save_checkpoint(
                    args,
                    {
                        "epoch": 0,
                        "state_dict": model.state_dict(),
                        "best_prec1": 0.0,
                    },
                    True,
                )
                best_prec1 = 0.0

            turns_above_50 = 0

            for epoch in range(args.start_epoch, args.epochs):
                adjust_learning_rate(args, optimizer, epoch + 1)
                train(args, train_loader, model, criterion, optimizer, epoch,
                      pruner_retrain, writer)

                prec1 = validate(args, val_loader, model, criterion, epoch,
                                 writer)
                correlation.measure_correlation(model, epoch, writer=writer)

                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)

                if args.savenet:
                    save_checkpoint(
                        args,
                        {
                            "epoch": epoch + 1,
                            "state_dict": model.state_dict(),
                            "best_prec1": best_prec1,
                        },
                        is_best,
                    )
                if args.symmetry_break:
                    if prec1 > 50.0:
                        turns_above_50 += 1
                        if turns_above_50 > 3:
                            return epoch
    finally:
        # Also runs on the symmetry-break early return, which previously left
        # the writer unclosed.
        writer.close()

    print("Best accuracy: ", best_prec1)
    return best_prec1
def main():
    """Train a WideResNet (depth 28, width 2) together with an `ema_model` copy.

    `ema_model` starts with the same weights as `model`. Every 16 epochs,
    counted from the start or resume epoch, both networks are checkpointed
    under ``.logs/<dataset>@<labelled_examples>/checkpoints``. Each epoch
    validates and tests `ema_model`. When ``args['tensorboard']`` is set,
    train/validation/test scalars are written.
    """
    args = vars(get_args())
    dir_path = os.path.dirname(os.path.realpath(__file__))
    if args['config_path'] is not None and os.path.exists(os.path.join(dir_path, args['config_path'])):
        args = load_config(args)
    start_epoch = 0
    log_path = f'.logs/{args["dataset"]}@{args["labelled_examples"]}'
    ckpt_dir = f'{log_path}/checkpoints'

    datasetX, datasetU, val_dataset, test_dataset, num_classes = fetch_dataset(args, log_path)

    model = WideResNet(num_classes, depth=28, width=2)
    model.build(input_shape=(None, 32, 32, 3))
    # `learning_rate` is the supported keyword; `lr` is deprecated and is
    # rejected by newer Keras releases.
    optimizer = tf.keras.optimizers.Adam(learning_rate=args['learning_rate'])
    model_ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=model)
    manager = tf.train.CheckpointManager(model_ckpt, f'{ckpt_dir}/model', max_to_keep=3)

    ema_model = WideResNet(num_classes, depth=28, width=2)
    ema_model.build(input_shape=(None, 32, 32, 3))
    ema_model.set_weights(model.get_weights())
    ema_ckpt = tf.train.Checkpoint(step=tf.Variable(0), net=ema_model)
    ema_manager = tf.train.CheckpointManager(ema_ckpt, f'{ckpt_dir}/ema', max_to_keep=3)

    if args['resume']:
        model_ckpt.restore(manager.latest_checkpoint)
        # Restore the EMA network from its own checkpoint directory. It was
        # previously restored from the *model* checkpoint, which silently
        # replaced the EMA weights with the raw model weights.
        ema_ckpt.restore(ema_manager.latest_checkpoint)
        model_ckpt.step.assign_add(1)
        ema_ckpt.step.assign_add(1)
        start_epoch = int(model_ckpt.step)
        print(f'Restored @ epoch {start_epoch} from {manager.latest_checkpoint} and {ema_manager.latest_checkpoint}')

    train_writer = None
    if args['tensorboard']:
        train_writer = tf.summary.create_file_writer(f'{log_path}/train')
        val_writer = tf.summary.create_file_writer(f'{log_path}/validation')
        test_writer = tf.summary.create_file_writer(f'{log_path}/test')

    # assigning args used in functions wrapped with tf.function to tf.constant/tf.Variable to avoid memory leaks
    args['T'] = tf.constant(args['T'])
    args['beta'] = tf.Variable(0., shape=())
    for epoch in range(start_epoch, args['epochs']):
        xe_loss, l2u_loss, total_loss, accuracy = train(datasetX, datasetU, model, ema_model, optimizer, epoch, args)
        val_xe_loss, val_accuracy = validate(val_dataset, ema_model, epoch, args, split='Validation')
        test_xe_loss, test_accuracy = validate(test_dataset, ema_model, epoch, args, split='Test')

        if (epoch - start_epoch) % 16 == 0:
            model_save_path = manager.save(checkpoint_number=int(model_ckpt.step))
            ema_save_path = ema_manager.save(checkpoint_number=int(ema_ckpt.step))
            print(f'Saved model checkpoint for epoch {int(model_ckpt.step)} @ {model_save_path}')
            print(f'Saved ema checkpoint for epoch {int(ema_ckpt.step)} @ {ema_save_path}')

        model_ckpt.step.assign_add(1)
        ema_ckpt.step.assign_add(1)

        step = args['val_iteration'] * (epoch + 1)
        if args['tensorboard']:
            with train_writer.as_default():
                tf.summary.scalar('xe_loss', xe_loss.result(), step=step)
                tf.summary.scalar('l2u_loss', l2u_loss.result(), step=step)
                tf.summary.scalar('total_loss', total_loss.result(), step=step)
                tf.summary.scalar('accuracy', accuracy.result(), step=step)
            with val_writer.as_default():
                tf.summary.scalar('xe_loss', val_xe_loss.result(), step=step)
                tf.summary.scalar('accuracy', val_accuracy.result(), step=step)
            with test_writer.as_default():
                tf.summary.scalar('xe_loss', test_xe_loss.result(), step=step)
                tf.summary.scalar('accuracy', test_accuracy.result(), step=step)

    if args['tensorboard']:
        for writer in [train_writer, val_writer, test_writer]:
            writer.flush()
def main():
    """Train a WideResNet (depth 28, width 2) with an `ema_model` copy, or tune it.

    When ``args['mode'] == 'tuning'``, the training ingredients are handed to
    ``Bayesian_Optimization``. Otherwise the regular loop runs:
      - every 16 epochs (counted from the start or resume epoch), both
        networks are checkpointed;
      - each epoch validates and tests `ema_model`;
      - when ``args['tensorboard']`` is set, train/validation/test scalars
        are written.

    Several objects are published as module globals, presumably for use by
    ``Bayesian_Optimization``.
    """
    global datasetX, datasetU, val_dataset, model, ema_model, optimizer, epoch, args
    args = vars(get_args())
    epoch = args['epochs']
    start_epoch = 0
    record_path = f'.logs/{args["dataset"]}@{args["labelled_examples"]}'
    ckpt_dir = f'{record_path}/checkpoints'
    datasetX, datasetU, val_dataset, test_dataset, num_classes = preprocess_dataset(args, record_path)

    model = WideResNet(num_classes, depth=28, width=2)
    model.build(input_shape=(None, 32, 32, 3))
    # `learning_rate` is the supported keyword; `lr` is deprecated and is
    # rejected by newer Keras releases.
    optimizer = tf.keras.optimizers.Adam(learning_rate=args['learning_rate'])
    model_ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=model)
    manager = tf.train.CheckpointManager(model_ckpt, f'{ckpt_dir}/model', max_to_keep=3)

    ema_model = WideResNet(num_classes, depth=28, width=2)
    ema_model.build(input_shape=(None, 32, 32, 3))
    ema_model.set_weights(model.get_weights())
    ema_ckpt = tf.train.Checkpoint(step=tf.Variable(0), net=ema_model)
    ema_manager = tf.train.CheckpointManager(ema_ckpt, f'{ckpt_dir}/ema', max_to_keep=3)

    if args['resume']:
        model_ckpt.restore(manager.latest_checkpoint)
        # Restore the EMA network from its own checkpoint directory, not from
        # the model checkpoint (which overwrote the EMA weights).
        ema_ckpt.restore(ema_manager.latest_checkpoint)
        model_ckpt.step.assign_add(1)
        ema_ckpt.step.assign_add(1)
        start_epoch = int(model_ckpt.step)
        print(f'Restored @ epoch {start_epoch} from {manager.latest_checkpoint} and {ema_manager.latest_checkpoint}')

    train_writer = None
    if args['tensorboard']:
        train_writer = tf.summary.create_file_writer(f'{record_path}/train')
        val_writer = tf.summary.create_file_writer(f'{record_path}/validation')
        test_writer = tf.summary.create_file_writer(f'{record_path}/test')

    # Wrap values read inside tf.function-compiled code in TF objects.
    args['T'] = tf.constant(args['T'])
    args['beta'] = tf.Variable(0., shape=())

    if args['mode']=='tuning':
        params=[datasetX, datasetU, val_dataset, model, ema_model, optimizer, epoch, args]
        Bayesian_Optimization(params)
    else:
        for epoch in range(start_epoch, args['epochs']):
            xe_loss, l2u_loss, total_loss, accuracy = train(datasetX, datasetU, model, ema_model, optimizer, epoch,
                                                            args)
            val_xe_loss, val_accuracy = validate(val_dataset, ema_model, epoch, args, split='Validation')
            test_xe_loss, test_accuracy = validate(test_dataset, ema_model, epoch, args, split='Test')

            if (epoch - start_epoch) % 16 == 0:
                model_save_path = manager.save(checkpoint_number=int(model_ckpt.step))
                ema_save_path = ema_manager.save(checkpoint_number=int(ema_ckpt.step))
                print(f'Saved model checkpoint for epoch {int(model_ckpt.step)} @ {model_save_path}')
                print(f'Saved ema checkpoint for epoch {int(ema_ckpt.step)} @ {ema_save_path}')

            model_ckpt.step.assign_add(1)
            ema_ckpt.step.assign_add(1)

            step = args['val_iteration'] * (epoch + 1)
            if args['tensorboard']:
                with train_writer.as_default():
                    tf.summary.scalar('xe_loss', xe_loss.result(), step=step)
                    tf.summary.scalar('l2u_loss', l2u_loss.result(), step=step)
                    tf.summary.scalar('total_loss', total_loss.result(), step=step)
                    tf.summary.scalar('accuracy', accuracy.result(), step=step)
                with val_writer.as_default():
                    tf.summary.scalar('xe_loss', val_xe_loss.result(), step=step)
                    tf.summary.scalar('accuracy', val_accuracy.result(), step=step)
                with test_writer.as_default():
                    tf.summary.scalar('xe_loss', test_xe_loss.result(), step=step)
                    tf.summary.scalar('accuracy', test_accuracy.result(), step=step)

    if args['tensorboard']:
        for writer in [train_writer, val_writer, test_writer]:
            writer.flush()
Beispiel #12
0
def main():
    """Train a WideResNet on CIFAR-10/100, checkpointing after every epoch.

    Command-line options are parsed with the module-level ``parser`` into the
    global ``args``. The best validation precision seen so far is tracked in
    the global ``best_prec1``. It is restored from ``--resume`` when a
    checkpoint is given; otherwise it is presumably initialised at module
    level (it is read before being assigned in the epoch loop).

    Raises:
        ValueError: if ``args.dataset`` is neither ``"cifar10"`` nor
            ``"cifar100"``.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if args.dataset not in ("cifar10", "cifar100"):
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected 'cifar10' or 'cifar100'"
        )
    num_classes = 10 if args.dataset == "cifar10" else 100

    if args.tensorboard:
        configure(f"runs/{args.name}")

    # Per-channel mean/std written on the 0-255 scale, rescaled to 0-1.
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
    )

    if args.augment:
        # Reflect-pad 4 px on every side, then random 32x32 crop + h-flip.
        transform_train = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Lambda(
                    lambda x: F.pad(
                        x.unsqueeze(0), (4, 4, 4, 4), mode="reflect"
                    ).squeeze()
                ),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
    else:
        transform_train = transforms.Compose([transforms.ToTensor(), normalize])

    if args.cutout:
        # Appended last, so Cutout operates on the already-normalised tensor.
        transform_train.transforms.append(
            Cutout(n_holes=args.n_holes, length=args.length)
        )

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {"num_workers": 1, "pin_memory": True}
    dataset_cls = datasets.__dict__[args.dataset.upper()]  # CIFAR10 / CIFAR100

    train_loader = torch.utils.data.DataLoader(
        dataset_cls("../data", train=True, download=True, transform=transform_train),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )
    # Evaluation does not depend on sample order; keep it deterministic.
    val_loader = torch.utils.data.DataLoader(
        dataset_cls("../data", train=False, transform=transform_test),
        batch_size=args.batch_size,
        shuffle=False,
        **kwargs,
    )

    model = WideResNet(
        args.layers,
        num_classes,
        args.widen_factor,
        droprate=args.droprate,
        use_bn=args.batchnorm,
        use_fixup=args.fixup,
    )

    param_num = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {param_num}")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model = model.cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(f"=> loading checkpoint {args.resume}")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})")
        else:
            print(f"=> no checkpoint found at {args.resume}")

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        nesterov=args.nesterov,
        weight_decay=args.weight_decay,
    )

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)
        train(train_loader, model, criterion, optimizer, epoch)

        prec1 = validate(val_loader, model, criterion, epoch)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_prec1": best_prec1,
            },
            is_best,
        )

    print("Best accuracy: ", best_prec1)
Beispiel #13
0
        'image': x_train,
        'label': y_train
    })

# One-hot encode the test labels. tf.squeeze(axis=1) requires that axis to have
# size 1, i.e. integer labels shaped (N, 1), giving (N, 1, 10) -> (N, 10).
y_test = tf.one_hot(y_test, depth=10, dtype=tf.float32)
y_test = tf.squeeze(y_test, axis=1)
# Rescale pixels to [-1, 1] (assuming raw values in [0, 255] — TODO confirm).
x_test = x_test / 255
x_test = x_test * 2 - 1
cifar10_test_dataset = tf.data.Dataset.from_tensor_slices({
    'image': x_test,
    'label': y_test,
})

trainX, trainU, validation = split_dataset(cifar10_train_dataset, 4000, 5000, 10)
#%%
model = WideResNet(10, depth=28, width=2)
model.build(input_shape=(None, 32, 32, 3))
# Use `learning_rate`, not the deprecated `lr` alias. Newer Keras optimizers
# either drop `lr` (TF 2.11+, silently falling back to 0.001) or reject it
# (Keras 3).
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
# model_ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=model)
# manager = tf.train.CheckpointManager(model_ckpt, f'{ckpt_dir}/model', max_to_keep=3)

ema_model = WideResNet(10, depth=28, width=2)
ema_model.build(input_shape=(None, 32, 32, 3))
# The EMA model starts as an exact copy of the freshly built model's weights.
ema_model.set_weights(model.get_weights())
# ema_ckpt = tf.train.Checkpoint(step=tf.Variable(0), net=ema_model)
# ema_manager = tf.train.CheckpointManager(ema_ckpt, f'{ckpt_dir}/ema', max_to_keep=3)

#%%
def train(trainX, trainU, model, ema_model, optimizer, epoch):
    """Training step for one epoch (presumably) — body truncated in this excerpt.

    NOTE(review): as written, this only creates the two running-mean metrics
    below. It never uses its arguments and returns None. The rest of the body
    appears to have been cut off; confirm against the full source.
    """
    xe_loss_avg = tf.keras.metrics.Mean()  # presumably running mean of the cross-entropy loss
    l2u_loss_avg = tf.keras.metrics.Mean()  # presumably running mean of the unlabeled (L2) loss
Beispiel #14
0
# Both flags write to the same destination `args.test`: `--test` stores True,
# `--train` stores False.
command.add_argument('--test', action='store_true', dest='test')
command.add_argument('--train', action='store_false', dest='test')


if __name__ == '__main__':
    args = parser.parse_args()
    # Use the GPU only when one is available AND the user enabled it.
    cuda = torch.cuda.is_available() and args.cuda
    # NOTE(review): `train_dataset` is not used below — the training branch
    # (for `args.test == False`) seems to be missing from this excerpt.
    train_dataset = TRAIN_DATASETS[args.dataset]
    test_dataset = TEST_DATASETS[args.dataset]
    # Per-dataset metadata; its 'size', 'channels' and 'classes' entries are
    # passed positionally to the model below.
    dataset_config = DATASET_CONFIGS[args.dataset]

    # instantiate the model instance.
    wrn = WideResNet(
        args.dataset,
        dataset_config['size'],
        dataset_config['channels'],
        dataset_config['classes'],
        total_block_number=args.total_block_number,
        widen_factor=args.widen_factor,
    )

    # prepare cuda if needed (Module.cuda() moves the parameters in place).
    if cuda:
        wrn.cuda()

    # run the given command.
    if args.test:
        # Load a checkpoint from args.model_dir (best=True presumably selects
        # the best-scoring one), then evaluate it on the test split.
        utils.load_checkpoint(wrn, args.model_dir, best=True)
        utils.validate(
            wrn, test_dataset, test_size=args.test_size,
            cuda=cuda, verbose=True
        )
Beispiel #15
0
# `--train` stores False into `args.test` (its `--test` counterpart is
# presumably registered just above this excerpt).
main_command.add_argument('--train', action='store_false', dest='test')

if __name__ == '__main__':
    args = parser.parse_args()
    # Use the GPU only when one is available AND the user enabled it.
    cuda = torch.cuda.is_available() and args.cuda
    train_dataset = TRAIN_DATASETS[args.dataset]
    test_dataset = TEST_DATASETS[args.dataset]
    # Per-dataset metadata; its 'size', 'channels' and 'classes' entries are
    # passed positionally to the model below.
    dataset_config = DATASET_CONFIGS[args.dataset]

    # instantiate the model instance.
    wrn = WideResNet(
        args.dataset,
        dataset_config['size'],
        dataset_config['channels'],
        dataset_config['classes'],
        total_block_number=args.total_block_number,
        widen_factor=args.widen_factor,
        dropout_prob=args.dropout_prob,
        baseline_strides=args.baseline_strides,
        baseline_channels=args.baseline_channels,
        split_sizes=args.split_sizes,
    )

    # initialize the weights (via the project helper utils.xavier_initialize).
    utils.xavier_initialize(wrn)

    # prepare cuda if needed (Module.cuda() moves the parameters in place).
    if cuda:
        wrn.cuda()

    # run the given command.
    # NOTE(review): the excerpt ends here — the body of this `if` and any
    # training branch are missing, so this fragment is not valid Python as-is.
    if args.test:
Beispiel #16
0
    # Exactly one positional argument is expected: the training-data path.
    if len(sys.argv) == 2:
        fp_train = sys.argv[1]
    else:
        print('Usage:')
        print('    python3 train.py [training data]')
        # Nothing below can run without a data path (fp_train would be
        # unbound), so stop with a non-zero exit status.
        sys.exit(1)

    ### Load data ###
    print('Loading data ...')

    train_loader = load_data(fp_train)

    print('Done!')

    ### Building model ###

    model = WideResNet()
    model.cuda()

    optimizer = optim.SGD(model.parameters(),
                          lr=.1,
                          momentum=.9,
                          nesterov=True)
    # Multiply the learning rate by 0.2 at epochs 60, 120 and 180.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [60, 120, 180], 0.2)

    num_epochs = 180

    for epoch in range(num_epochs):
        train(model, optimizer, epoch, train_loader)
        # Step the schedule after the epoch's optimizer updates, as required
        # since PyTorch 1.1. Stepping first would apply every milestone one
        # epoch early and trigger an ordering warning.
        scheduler.step()

    # NOTE(review): `fp_model` is not defined in this excerpt; presumably a
    # module-level output-path prefix.
    torch.save(model.state_dict(), f"{fp_model}.{num_epochs}.pt")
Beispiel #17
0
def _parse_device_range(device):
    """Return ``(start, stop)`` GPU indices from a range spec such as ``"0-3"``.

    The spec's second number is inclusive, so ``stop`` is that number plus
    one. For string specs, the first two integers are used, so multi-digit
    ids like ``"10-11"`` work. Anything else falls back to reading positions
    0 and 2 of ``device``, which matches the original fixed-position parsing.
    """
    import re

    if isinstance(device, str):
        numbers = re.findall(r"\d+", device)
        if len(numbers) >= 2:
            return int(numbers[0]), int(numbers[1]) + 1
    return int(device[0]), int(device[2]) + 1


def getPruneMask(args):
    """Prune a trained WideResNet checkpoint and return the pruning mask.

    Loads ``runs/<args.prune>-net/checkpoint.pth.tar`` into a WideResNet
    built from ``args``. When more than one GPU is visible, the model is
    wrapped in DataParallel over the range in ``args.device`` (e.g. ``"0-3"``).
    It is then pruned with a ``prunhild`` local-ratio cutoff of
    ``args.cutoff``.

    Returns:
        ``pruner.state_dict()`` (passed through ``randomize_mask`` when
        ``args.randomize_mask`` is set), or ``None`` if the checkpoint file
        does not exist.
    """
    baseTar = "runs/%s-net/checkpoint.pth.tar" % args.prune
    if not os.path.isfile(baseTar):
        print(f"=> no checkpoint found at {baseTar}")
        return None

    # When onlydigits(args.prune_classes) yields 0, fall back to args.classes.
    classes = onlydigits(args.prune_classes)
    if classes == 0:
        classes = args.classes

    fullModel = WideResNet(
        args.layers,
        classes,
        args.widen_factor,
        droprate=args.droprate,
        use_bn=args.batchnorm,
        use_fixup=args.fixup,
        varnet=args.varnet,
        noise=args.noise,
        lrelu=args.lrelu,
        sigmaW=args.sigmaW,
    )

    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        start, end = _parse_device_range(args.device)
        torch.cuda.set_device(start)
        dev_list = [f"cuda:{i}" for i in range(start, end)]
        fullModel = torch.nn.DataParallel(fullModel, device_ids=dev_list)

    fullModel = fullModel.cuda()

    print(f"=> loading checkpoint {baseTar}")

    checkpoint = torch.load(baseTar)
    fullModel.load_state_dict(checkpoint["state_dict"])

    # --- Pruning setup ---
    cutoff = prunhild.cutoff.LocalRatioCutoff(args.cutoff)
    # don't prune the final bias weights
    params = get_params_for_pruning(args, fullModel)

    print(params)

    pruner = prunhild.pruner.CutoffPruner(params, cutoff, prune_online=True)
    pruner.prune()

    print(f"=> loaded checkpoint '{baseTar}' (epoch {checkpoint['epoch']})")

    if multi_gpu:
        # Release cached allocator memory on each GPU in the range. The
        # torch.cuda.device context restores the previously selected device
        # afterwards. Callers therefore keep `start` as the current device,
        # rather than being left on the last GPU of the range.
        for i in range(start, end):
            with torch.cuda.device(i):
                torch.cuda.empty_cache()

    mask = pruner.state_dict()
    if args.randomize_mask:
        mask = randomize_mask(mask, args.cutoff)

    return mask
Beispiel #18
0
def _build_attack(mode, model, args, logger):
    """Construct the adversarial attack named by ``mode``.

    Args:
        mode: ``'FGSM'`` builds a ``FastGradientSignUntargeted`` attack on
            ``model``. It uses ``args.epsilon``, ``args.alpha``, ``args.k``
            iterations and ``args.perturbation_type``, with min_val=0 and
            max_val=1. ``'CW'`` builds an untargeted
            ``carlini_wagner_L2.L2Adversary``.
        model: network the FGSM attack is bound to.
        args: parsed command-line options.
        logger: logger handed to the attack.

    Raises:
        NotImplementedError: for any other ``mode``, instead of leaving the
            attack unset.
    """
    if mode == 'FGSM':
        return FastGradientSignUntargeted(model,
                                          args.epsilon,
                                          args.alpha,
                                          min_val=0,
                                          max_val=1,
                                          max_iters=args.k,
                                          _type=args.perturbation_type,
                                          logger=logger)
    if mode == 'CW':
        return carlini_wagner_L2.L2Adversary(targeted=False,
                                             confidence=0.0,
                                             search_steps=10,
                                             optimizer_lr=5e-4,
                                             logger=logger)
    raise NotImplementedError(f"Unsupported adversarial attack mode: {mode!r}")


def main(args):
    """Set up folders, a WideResNet and train/test attacks, then run ``args.todo``.

    The model is a WideResNet with depth 34 and widen factor 1. The log and
    model folders are stored on ``args`` as ``log_folder`` / ``model_folder``.
    ``args.todo == 'train'`` trains on CIFAR-10 and evaluates on the CIFAR-10
    test split during training. ``'test'`` is currently a no-op; any other
    value raises ``NotImplementedError``.

    NOTE(review): another ``def main(args)`` appears later in this file and
    shadows this definition at import time.

    Raises:
        NotImplementedError: for an unknown ``args.todo``,
            ``args.adv_train_mode`` or ``args.adv_test_mode``.
    """
    save_folder = '%s_%s' % (args.dataset, args.affix)

    log_folder = os.path.join(args.log_root, save_folder)
    model_folder = os.path.join(args.model_root, save_folder)

    makedirs(log_folder)
    makedirs(model_folder)

    setattr(args, 'log_folder', log_folder)
    setattr(args, 'model_folder', model_folder)

    logger = create_logger(log_folder, args.todo, 'info')

    print_args(args, logger)

    # Using a WideResNet model
    model = WideResNet(depth=34, num_classes=10, widen_factor=1, dropRate=0.0)
    flop, param = get_model_infos(model, (1, 3, 32, 32))
    logger.info('Model Info: FLOP = {:.2f} M, Params = {:.2f} MB'.format(
        flop, param))

    # Configure the train and test attack modes (unknown modes raise here,
    # before any training work starts).
    train_attack = _build_attack(args.adv_train_mode, model, args, logger)
    test_attack = _build_attack(args.adv_test_mode, model, args, logger)

    if torch.cuda.is_available():
        model.cuda()

    trainer = Trainer(args, logger, train_attack, test_attack)

    if args.todo == 'train':
        # Zero-pad 4 px on every side, then random 32x32 crop + h-flip.
        transform_train = tv.transforms.Compose([
            tv.transforms.ToTensor(),
            tv.transforms.Lambda(lambda x: F.pad(
                x.unsqueeze(0),
                (4, 4, 4, 4), mode='constant', value=0).squeeze()),
            tv.transforms.ToPILImage(),
            tv.transforms.RandomCrop(32),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
        ])
        tr_dataset = tv.datasets.CIFAR10(args.data_root,
                                         train=True,
                                         transform=transform_train,
                                         download=True)

        tr_loader = DataLoader(tr_dataset,
                               batch_size=args.batch_size,
                               shuffle=True,
                               num_workers=4)

        # evaluation during training
        te_dataset = tv.datasets.CIFAR10(args.data_root,
                                         train=False,
                                         transform=tv.transforms.ToTensor(),
                                         download=True)

        te_loader = DataLoader(te_dataset,
                               batch_size=args.batch_size,
                               shuffle=False,
                               num_workers=4)

        trainer.train(model, tr_loader, te_loader, args.adv_train)
    elif args.todo == 'test':
        pass
    else:
        raise NotImplementedError
def main(args):
    """Prepare run folders, a WideResNet and an FGSM attack, then run ``args.todo``.

    The model is a WideResNet with depth 34 and widen factor 10. The resolved
    log/model directories are stored on ``args`` as ``log_folder`` and
    ``model_folder``. ``'train'`` trains on CIFAR-10 (pad-and-crop + flip
    augmentation) and evaluates on the CIFAR-10 test split while training.
    ``'test'`` does nothing; any other value raises ``NotImplementedError``.
    """
    run_name = '%s_%s' % (args.dataset, args.affix)
    log_dir = os.path.join(args.log_root, run_name)
    ckpt_dir = os.path.join(args.model_root, run_name)

    makedirs(log_dir)
    makedirs(ckpt_dir)

    # Make the resolved folders visible to downstream consumers of `args`.
    args.log_folder = log_dir
    args.model_folder = ckpt_dir

    logger = create_logger(log_dir, args.todo, 'info')
    print_args(args, logger)

    model = WideResNet(depth=34, num_classes=10, widen_factor=10, dropRate=0.0)
    attack = FastGradientSignUntargeted(
        model,
        args.epsilon,
        args.alpha,
        min_val=0,
        max_val=1,
        max_iters=args.k,
        _type=args.perturbation_type,
    )

    if torch.cuda.is_available():
        model.cuda()

    trainer = Trainer(args, logger, attack)

    # Guard clauses: only 'train' does any further work.
    if args.todo == 'test':
        return
    if args.todo != 'train':
        raise NotImplementedError

    def make_loader(is_train, transform, shuffle):
        # One CIFAR-10 split wrapped in a DataLoader with the shared settings.
        split = tv.datasets.CIFAR10(args.data_root,
                                    train=is_train,
                                    transform=transform,
                                    download=True)
        return DataLoader(split,
                          batch_size=args.batch_size,
                          shuffle=shuffle,
                          num_workers=4)

    # Zero-pad 4 px on every side, then random 32x32 crop + horizontal flip.
    zero_pad = tv.transforms.Lambda(
        lambda img: F.pad(img.unsqueeze(0), (4, 4, 4, 4),
                          mode='constant', value=0).squeeze())
    augment = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        zero_pad,
        tv.transforms.ToPILImage(),
        tv.transforms.RandomCrop(32),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
    ])

    train_loader = make_loader(True, augment, shuffle=True)
    # Held-out split used for evaluation while training.
    eval_loader = make_loader(False, tv.transforms.ToTensor(), shuffle=False)

    trainer.train(model, train_loader, eval_loader, args.adv_train)
Beispiel #20
0
def main():
    """Train a WideResNet plus an EMA copy on CIFAR-10 and record test accuracy.

    Relies on the module-level settings ``numClasses``, ``labeledExamples``,
    ``lr`` and ``epochs``. After every epoch, the EMA model's test accuracy is
    printed and ``results.txt`` is overwritten with the latest result.
    """
    model = WideResNet(numClasses, depth=28, width=2)
    emaModel = WideResNet(numClasses, depth=28, width=2)

    (X_train, Y_train), U_train, (X_test, Y_test) = load_CIFAR_10(labeledExamples=labeledExamples)
    model.build(input_shape=(None, 32, 32, 3))
    emaModel.build(input_shape=(None, 32, 32, 3))

    X_train = tf.data.Dataset.from_tensor_slices({'image': X_train, 'label': Y_train})
    X_test = tf.data.Dataset.from_tensor_slices({'image': X_test, 'label': Y_test})
    U_train = tf.data.Dataset.from_tensor_slices(U_train)

    # Use `learning_rate`, not the deprecated `lr` alias. Newer Keras
    # optimizers either drop `lr` (TF 2.11+, silently using the 0.001
    # default) or reject it (Keras 3).
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    # The EMA model starts as an exact copy of the trained model.
    emaModel.set_weights(model.get_weights())

    T = tf.constant(0.5)  # passed to train(); presumably the sharpening temperature
    beta = tf.Variable(0., shape=())

    for epoch in range(epochs):
        train(X_train, U_train, model, emaModel, optimizer, epoch, T, beta)
        testAccuracy = validate(X_test, emaModel).result()
        print("Epoch: {} and test accuracy: {}".format(epoch, testAccuracy))

        # Opened with 'w' each epoch, so the file always holds the latest result.
        with open('results.txt', 'w', encoding='utf-8') as f:
            f.write("num_label={}, accuracy={}, epoch={}".format(labeledExamples, testAccuracy, epoch))