Code Example #1
def main2(args):
    best_prec1 = 0.0

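    # run cuDNN deterministically unless nondeterminism is explicitly requested via cudaNoise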
    torch.backends.cudnn.deterministic = not args.cudaNoise

    torch.manual_seed(time.time())

    if args.init != "None":
        args.name = "lrnet_%s" % args.init

    if args.tensorboard:
        configure(f"runs/{args.name}")

    dstype = nondigits(args.dataset)
    if dstype == "cifar":
        means = [125.3, 123.0, 113.9]
        stds = [63.0, 62.1, 66.7]
    elif dstype == "imgnet":
        means = [123.3, 118.1, 108.0]
        stds = [54.1, 52.6, 53.2]
    else:
        # fallback (e.g. "cinic"): without this branch, `means`/`stds` would be
        # undefined below; the CINIC-10 loaders build their own Normalize anyway
        means = [125.3, 123.0, 113.9]
        stds = [63.0, 62.1, 66.7]

    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in means],
        std=[x / 255.0 for x in stds],
    )

    writer = SummaryWriter(log_dir="runs/%s" % args.name, comment=str(args))
    args.classes = onlydigits(args.dataset)

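    # standard CIFAR augmentation: 4-pixel reflection pad, random 32x32 crop, horizontal flip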
    if args.augment:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                              mode="reflect").squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transforms.Compose(
            [transforms.ToTensor(), normalize])

    if args.cutout:
        transform_train.transforms.append(
            Cutout(n_holes=args.n_holes, length=args.length))

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {"num_workers": 1, "pin_memory": True}

    assert dstype in ["cifar", "cinic", "imgnet"]

    if dstype == "cifar":
        train_loader = torch.utils.data.DataLoader(
            datasets.__dict__[args.dataset.upper()]("../data",
                                                    train=True,
                                                    download=True,
                                                    transform=transform_train),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        val_loader = torch.utils.data.DataLoader(
            datasets.__dict__[args.dataset.upper()]("../data",
                                                    train=False,
                                                    transform=transform_test),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    elif dstype == "cinic":
        cinic_directory = "%s/cinic10" % args.dir
        cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        cinic_std = [0.24205776, 0.23828046, 0.25874835]
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(cinic_directory + '/train',
                                             transform=transforms.Compose([
                                                 transforms.ToTensor(),
                                                 transforms.Normalize(
                                                     mean=cinic_mean,
                                                     std=cinic_std)
                                             ])),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        print("Using CINIC10 dataset")
        val_loader = torch.utils.data.DataLoader(
            torchvision.datasets.ImageFolder(cinic_directory + '/valid',
                                             transform=transforms.Compose([
                                                 transforms.ToTensor(),
                                                 transforms.Normalize(
                                                     mean=cinic_mean,
                                                     std=cinic_std)
                                             ])),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    elif dstype == "imgnet":
        print("Using converted imagenet")
        train_loader = torch.utils.data.DataLoader(
            IMGNET("%s" % args.dir,
                   train=True,
                   transform=transform_train,
                   target_transform=None,
                   classes=args.classes),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
        val_loader = torch.utils.data.DataLoader(
            IMGNET("%s" % args.dir,
                   train=False,
                   transform=transform_test,
                   target_transform=None,
                   classes=args.classes),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs,
        )
    else:
        print("Error matching dataset %s" % dstype)

    ##print("main bn:")
    ##print(args.batchnorm)
    ##print("main fixup:")
    ##print(args.fixup)

    if args.prune:
        pruner_state = getPruneMask(args)
        if pruner_state is None:
            print("Failed to prune network, aborting")
            return None

    if args.arch.lower() == "constnet":
        model = WideResNet(
            args.layers,
            args.classes,
            args.widen_factor,
            droprate=args.droprate,
            use_bn=args.batchnorm,
            use_fixup=args.fixup,
            varnet=args.varnet,
            noise=args.noise,
            lrelu=args.lrelu,
            sigmaW=args.sigmaW,
            init=args.init,
            dropl1=args.dropl1,
        )
    elif args.arch.lower() == "leakynet":
        model = LRNet(
            args.layers,
            args.classes,
            args.widen_factor,
            droprate=args.droprate,
            use_bn=args.batchnorm,
            use_fixup=args.fixup,
            varnet=args.varnet,
            noise=args.noise,
            lrelu=args.lrelu,
            sigmaW=args.sigmaW,
            init=args.init,
        )
    else:
        print("arch %s is not supported" % args.arch)
        return None

    # draw(args, model)  # disabled: the drawing dependency has a complex installation

    param_num = sum([p.data.nelement() for p in model.parameters()])

    print(f"Number of model parameters: {param_num}")

    if torch.cuda.device_count() > 1:
        # args.device is assumed to encode a device range such as "0-3":
        # its first and third characters are parsed as start/end GPU ids
        start = int(args.device[0])
        end = int(args.device[2]) + 1
        torch.cuda.set_device(start)
        dev_list = ["cuda:%d" % i for i in range(start, end)]
        model = torch.nn.DataParallel(model, device_ids=dev_list)

    model = model.cuda()

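    # positive args.freeze: freeze the first blocks (counting 'scale' params as block
    # markers) from args.freeze_start on; negative: freeze the last blocks instead.
    # Residual convs ('conv_res') and the classifier ('fc') are never frozen.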
    if args.freeze > 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['scale'], name.split('.')):
                cnt = cnt + 1
                if cnt == args.freeze:
                    break

            if cnt >= args.freeze_start:
                if not intersection(['conv_res', 'fc'], name.split('.')):
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    elif args.freeze < 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['scale'], name.split('.')):
                cnt = cnt + 1

            if cnt > args.layers - 3 + args.freeze - 1:
                if not intersection(['conv_res', 'fc'], name.split('.')):
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

    if args.res_freeze > 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['conv_res'], name.split('.')):
                cnt = cnt + 1
                if cnt > args.res_freeze_start:
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)
                if cnt >= args.res_freeze:
                    break
    elif args.res_freeze < 0:
        cnt = 0
        for name, param in model.named_parameters():
            if intersection(['conv_res'], name.split('.')):
                cnt = cnt + 1
                if cnt > 3 + args.res_freeze:
                    param.requires_grad = False
                    print("Freezing Block: %s" % name)

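    # retraining a pruned net: load the weights of the source run named by args.prune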
    if args.prune:
        if args.prune_epoch >= 100:
            weightsFile = "runs/%s-net/checkpoint.pth.tar" % args.prune
        else:
            weightsFile = "runs/%s-net/model_epoch_%d.pth.tar" % (
                args.prune, args.prune_epoch)

        if os.path.isfile(weightsFile):
            print(f"=> loading checkpoint {weightsFile}")
            checkpoint = torch.load(weightsFile)
            model.load_state_dict(checkpoint["state_dict"])
            print(
                f"=> loaded checkpoint '{weightsFile}' (epoch {checkpoint['epoch']})"
            )
        else:
            if args.prune_epoch == 0:
                print(f"=> No source data, Restarting network from scratch")
            else:
                print(f"=> no checkpoint found at {weightsFile}, aborting...")
                return None

    else:
        if args.resume:
            tarfile = "runs/%s-net/checkpoint.pth.tar" % args.resume
            if os.path.isfile(tarfile):
                print(f"=> loading checkpoint {args.resume}")
                checkpoint = torch.load(tarfile)
                args.start_epoch = checkpoint["epoch"]
                best_prec1 = checkpoint["best_prec1"]
                model.load_state_dict(checkpoint["state_dict"])
                print(
                    f"=> loaded checkpoint '{tarfile}' (epoch {checkpoint['epoch']})"
                )
            else:
                print(f"=> no checkpoint found at {tarfile}, aborting...")
                return None

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().cuda()

    if args.optimizer.lower() == 'sgd':
        optimizer = torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            nesterov=args.nesterov,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer.lower() == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay)
    else:
        # without this guard an unknown optimizer name would leave
        # `optimizer` undefined and crash later in the training loop
        print("optimizer %s is not supported" % args.optimizer)
        return None

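    # re-apply the saved prunhild mask so retraining starts from the sparsified weights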
    if args.prune and pruner_state is not None:
        cutoff_retrain = prunhild.cutoff.LocalRatioCutoff(args.cutoff)
        params_retrain = get_params_for_pruning(args, model)
        pruner_retrain = prunhild.pruner.CutoffPruner(params_retrain,
                                                      cutoff_retrain)
        pruner_retrain.load_state_dict(pruner_state)
        pruner_retrain.prune(update_state=False)
        pruned_weights_count = count_pruned_weights(params_retrain,
                                                    args.cutoff)
        params_left = param_num - pruned_weights_count
        print("Pruned %d weights, New model size:  %d/%d (%d%%)" %
              (pruned_weights_count, params_left, param_num,
               int(100 * params_left / param_num)))

    else:
        pruner_retrain = None

    if args.eval:
        best_prec1 = validate(args, val_loader, model, criterion, 0, None)
    else:

        if args.varnet:
            save_checkpoint(
                args,
                {
                    "epoch": 0,
                    "state_dict": model.state_dict(),
                    "best_prec1": 0.0,
                },
                True,
            )
            best_prec1 = 0.0

        turns_above_50 = 0

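        # main loop: anneal the lr, train one epoch, validate, log correlations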
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(args, optimizer, epoch + 1)
            train(args, train_loader, model, criterion, optimizer, epoch,
                  pruner_retrain, writer)

            prec1 = validate(args, val_loader, model, criterion, epoch, writer)
            correlation.measure_correlation(model, epoch, writer=writer)

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            if args.savenet:
                save_checkpoint(
                    args,
                    {
                        "epoch": epoch + 1,
                        "state_dict": model.state_dict(),
                        "best_prec1": best_prec1,
                    },
                    is_best,
                )
            if args.symmetry_break:
                if prec1 > 50.0:
                    turns_above_50 += 1
                    if turns_above_50 > 3:
                        return epoch

    writer.close()

    print("Best accuracy: ", best_prec1)
    return best_prec1
Code Example #2
File: train.py Project: empennage98/ML2018SPRING
    if len(sys.argv) == 2:
        fp_train = sys.argv[1]
    else:
        print('Usage:')
        print('    python3 train.py [training data]')
        sys.exit(1)  # exit so fp_train is never used uninitialized
    ### Load data ###
    print('Loading data ...')

    train_loader = load_data(fp_train)

    print('Done!')

    ### Building model ###

    model = WideResNet()
    model.cuda()

    optimizer = optim.SGD(model.parameters(),
                          lr=.1,
                          momentum=.9,
                          nesterov=True)
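    # multiply the lr by 0.2 at epochs 60, 120 and 180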
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [60, 120, 180], 0.2)

    num_epochs = 180

    for epoch in range(num_epochs):
        train(model, optimizer, epoch, train_loader)
        # step the schedule after the epoch (since PyTorch 1.1,
        # scheduler.step() must follow the optimizer updates)
        scheduler.step()

    # fp_model (the checkpoint path prefix) is not defined in this excerpt;
    # a hypothetical default is assumed here
    fp_model = 'model'
    torch.save(model.state_dict(), fp_model + '.' + str(epoch + 1) + '.pt')
Code Example #3
def main():
    if not torch.cuda.is_available():
        device = torch.device('cpu')
    else:
        torch.cuda.set_device(args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        device = torch.device("cuda")

    criterion = nn.CrossEntropyLoss().to(device)

    model = WideResNet(depth=40, num_classes=10, widen_factor=2, dropRate=0.3)
    model = model.to(device)
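    # print a per-layer summary for a 3x32x32 input (likely torchsummary's summary())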
    summary(model, (3, 32, 32))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
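    # cosine-anneal the lr from args.learning_rate down to args.learning_rate_min over args.epochs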
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        float(args.epochs),
        eta_min=args.learning_rate_min,
        last_epoch=-1)

    train_transform, valid_transform = data_transforms_cifar(args)
    trainset = dset.CIFAR10(root=args.data_dir,
                            train=True,
                            download=False,
                            transform=train_transform)
    train_queue = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=8)
    valset = dset.CIFAR10(root=args.data_dir,
                          train=False,
                          download=False,
                          transform=valid_transform)
    valid_queue = torch.utils.data.DataLoader(valset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=8)

    best_acc = 0.0
    for epoch in range(args.epochs):
        t1 = time.time()

        # train
        train(args,
              epoch,
              train_queue,
              device,
              model,
              criterion=criterion,
              optimizer=optimizer)
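        # lr actually used this epoch, read before stepping the scheduler
        # (newer PyTorch versions prefer scheduler.get_last_lr())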
        lr = scheduler.get_lr()[0]
        scheduler.step()

        # validate
        val_top1, val_top5, val_obj = validate(val_data=valid_queue,
                                               device=device,
                                               model=model)
        if val_top1 > best_acc:
            best_acc = val_top1
        t2 = time.time()

        print(
            '\nval: loss={:.6}, top1={:.6}, top5={:.6}, lr: {:.8}, time: {:.4}'
            .format(val_obj, val_top1, val_top5, lr, t2 - t1))
        print('Best Top1 Acc: {:.6}'.format(best_acc))
Code Example #4
File: train.py Project: valilenk/fixup
def main():
    global args, best_prec1
    args = parser.parse_args()

    if args.tensorboard:
        configure(f"runs/{args.name}")

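    # CIFAR-10 per-channel statistics, given on the 0-255 scale and rescaled to [0, 1]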
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
    )

    if args.augment:
        transform_train = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Lambda(
                    lambda x: F.pad(
                        x.unsqueeze(0), (4, 4, 4, 4), mode="reflect"
                    ).squeeze()
                ),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
    else:
        transform_train = transforms.Compose([transforms.ToTensor(), normalize])

    if args.cutout:
        transform_train.transforms.append(
            Cutout(n_holes=args.n_holes, length=args.length)
        )

    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    kwargs = {"num_workers": 1, "pin_memory": True}
    assert args.dataset in ("cifar10", "cifar100")

    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()](
            "../data", train=True, download=True, transform=transform_train
        ),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()](
            "../data", train=False, transform=transform_test
        ),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )

    model = WideResNet(
        args.layers,
        args.dataset == "cifar10" and 10 or 100,
        args.widen_factor,
        droprate=args.droprate,
        use_bn=args.batchnorm,
        use_fixup=args.fixup,
    )

    param_num = sum([p.data.nelement() for p in model.parameters()])
    print(f"Number of model parameters: {param_num}")

    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model = model.cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print(f"=> loading checkpoint {args.resume}")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})")
        else:
            print(f"=> no checkpoint found at {args.resume}")

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        nesterov=args.nesterov,
        weight_decay=args.weight_decay,
    )

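    # train/validate each epoch; checkpoint every epoch and keep the best top-1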
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)
        train(train_loader, model, criterion, optimizer, epoch)

        prec1 = validate(val_loader, model, criterion, epoch)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_prec1": best_prec1,
            },
            is_best,
        )

    print("Best accuracy: ", best_prec1)