Code Example #1
File: main_dxy.py  Project: jss367/fastai-pythonic
def main():
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path, "log_seed_{}.txt".format(args.manualSeed)), "w"
    )
    print_log("save path : {}".format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace("\n", " ")), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == "cifar10":
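        # per-channel mean/std of the training images, given on the 0-255 scale and rescaled to [0, 1]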
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == "cifar100":
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    train_transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )

    if args.dataset == "cifar10":
        train_data = dset.CIFAR10(
            args.data_path, train=True, transform=train_transform, download=True
        )
        test_data = dset.CIFAR10(
            args.data_path, train=False, transform=test_transform, download=True
        )
        num_classes = 10
    elif args.dataset == "cifar100":
        train_data = dset.CIFAR100(
            args.data_path, train=True, transform=train_transform, download=True
        )
        test_data = dset.CIFAR100(
            args.data_path, train=False, transform=test_transform, download=True
        )
        num_classes = 100
    elif args.dataset == "svhn":
        train_data = dset.SVHN(
            args.data_path, split="train", transform=train_transform, download=True
        )
        test_data = dset.SVHN(
            args.data_path, split="test", transform=test_transform, download=True
        )
        num_classes = 10
    elif args.dataset == "stl10":
        train_data = dset.STL10(
            args.data_path, split="train", transform=train_transform, download=True
        )
        test_data = dset.STL10(
            args.data_path, split="test", transform=test_transform, download=True
        )
        num_classes = 10
    elif args.dataset == "imagenet":
        assert False, "Do not finish imagenet code"
    else:
        assert False, "Do not support dataset : {}".format(args.dataset)

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    # Init model, criterion, and optimizer
    # net = models.__dict__[args.arch](num_classes).cuda()
    net = SENet34()

    # define loss function (criterion) and optimizer
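    # NOTE: F.nll_loss expects log-probabilities, so the network's forward pass must end in log_softmax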
    criterion = F.nll_loss
    optimizer = torch.optim.SGD(
        net.parameters(),
        state["learning_rate"],
        momentum=state["momentum"],
        weight_decay=state["decay"],
        nesterov=True,
    )

    if args.use_cuda:
        net.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint["recorder"]
            args.start_epoch = checkpoint["epoch"]
            net.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                ),
                log,
            )
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> do not use any checkpoint for model", log)

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(
            optimizer, epoch, args.gammas, args.schedule
        )

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch)
        )
        need_time = "[Need: {:02d}:{:02d}:{:02d}]".format(
            need_hour, need_mins, need_secs
        )

        print_log(
            "\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]".format(
                time_string(), epoch, args.epochs, need_time, current_learning_rate
            )
            + " [Best : Accuracy={:.2f}, Error={:.2f}]".format(
                recorder.max_accuracy(False), 100 - recorder.max_accuracy(False)
            ),
            log,
        )

        # train for one epoch
        train_acc, train_los = train(
            train_loader, net, criterion, optimizer, epoch, log
        )

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, criterion, log)
        is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": net.state_dict(),
                "recorder": recorder,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            args.save_path,
            "checkpoint.pth.tar",
        )

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, "curve.png"))

    log.close()
Code Example #2
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("CS Pruning Rate: {}".format(args.prune_rate_cs), log)
    print_log("GM Pruning Rate: {}".format(args.prune_rate_gm), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("use pretrain: {}".format(args.use_pretrain), log)
    print_log("Pretrain path: {}".format(args.pretrain_path), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
        num_classes = 10
    else:
        assert False, 'Unsupported dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)
    train_prune_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_prune_size, shuffle=False,
                                               num_workers=args.workers, pin_memory=True)
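    # a separate, non-shuffled loader with its own batch size that is passed to the pruning Mask below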
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.workers, pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    if args.use_pretrain:
        pretrain = torch.load(args.pretrain_path)
        if args.use_state_dict:
            net.load_state_dict(pretrain['state_dict'])
        else:
            net = pretrain['state_dict']

    recorder = RecorderMeter(args.epochs)

    mdlIdx2ConvIdx = [] # module index to conv filter index
    for index1, layr in enumerate(net.modules()):
        if isinstance(layr, torch.nn.Conv2d):
            mdlIdx2ConvIdx.append(index1)

    prmIdx2ConvIdx = [] # parameter index to conv filter index
    for index2, item in enumerate(net.parameters()):
        if len(item.size()) == 4:
            prmIdx2ConvIdx.append(index2)

    # set index of last layer depending on the known architecture
    if args.arch == 'resnet20':
        args.layer_end = 54
    elif args.arch == 'resnet56':
        args.layer_end = 162
    elif args.arch == 'resnet110':
        args.layer_end = 324
    else:
        pass  # unknown architecture, use the input value

    # asymptotic schedule
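    # cmpAsymptoticSchedule returns per-epoch pruning quantities (indexed by epoch below);
    # keep_rate_cs turns the CS compress rate into a keep ratio for the mask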
    total_pruning_rate = args.prune_rate_gm + args.prune_rate_cs
    compress_rates_total, scalling_factors, compress_rates_cs, compress_rates_fpgm, e2 =\
        cmpAsymptoticSchedule(theta3=total_pruning_rate, e3=args.epochs-1, tau=args.tau, theta_cs_final=args.prune_rate_cs, scaling_attn=args.scaling_attenuation) # tau=8.
    keep_rate_cs = 1. - compress_rates_cs

    if args.use_zero_scaling:
        scalling_factors = np.zeros(scalling_factors.shape)

    m = Mask(net, train_prune_loader, mdlIdx2ConvIdx, prmIdx2ConvIdx, scalling_factors, keep_rate_cs, compress_rates_fpgm, args.max_iter_cs)
    m.set_curr_epoch(0)
    m.set_epoch_cs(args.epoch_apply_cs)
    m.init_selected_filts()
    m.init_length()

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

    m.model = net
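    # build and apply the epoch-0 mask, then re-validate to measure the accuracy drop caused by the initial pruning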
    m.init_mask(keep_rate_cs[0], compress_rates_fpgm[0], scalling_factors[0])
    #    m.if_zero()
    m.do_mask()
    m.do_similar_mask()
    net = m.model
    #    m.if_zero()
    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):

        current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, m)
        
        # prune: re-apply the mask every args.epoch_prune epochs and on the final epoch
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = net
            m.set_curr_epoch(epoch)
            # m.if_zero()
            m.init_mask(keep_rate_cs[epoch], compress_rates_fpgm[epoch], scalling_factors[epoch])
            m.do_mask()
            m.do_similar_mask()
            # m.if_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()
            if epoch == args.epochs - 1:
                m.if_zero()

        val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2, val_acc_2)

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': net,
            'recorder': recorder,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    log.close()
Code Example #3
def main_worker(gpu, ngpus_per_node, args):
    global best_acc
    args.gpu = gpu
    assert args.gpu is not None
    print("Use GPU: {} for training".format(args.gpu))

    log = open(
        os.path.join(
            args.save_path,
            'log_seed{}{}.txt'.format(args.manualSeed,
                                      '_eval' if args.evaluate else '')), 'w')
    log = (log, args.gpu)
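    # bundle the file handle with the GPU id; print_log presumably uses it to write from a single rank only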

    net = models.__dict__[args.arch](pretrained=True)
    disable_dropout(net)
    net = to_bayesian(net, args.psi_init_range)
    net.apply(unfreeze)
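    # project-specific helpers: dropout is disabled and the pretrained net is converted to a Bayesian
    # counterpart whose 'psi' parameters are optimized separately below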

    print_log("Python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("PyTorch  version : {}".format(torch.__version__), log)
    print_log("CuDNN  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log(
        "Number of parameters: {}".format(
            sum([p.numel() for p in net.parameters()])), log)
    print_log(str(args), log)

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url + ":" +
                                args.dist_port,
                                world_size=args.world_size,
                                rank=args.rank)
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        net = torch.nn.parallel.DistributedDataParallel(net,
                                                        device_ids=[args.gpu])
    else:
        torch.cuda.set_device(args.gpu)
        net = net.cuda(args.gpu)

    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

    mus, psis = [], []
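    # split parameters by name: 'psi' parameters get their own PsiSGD optimizer, everything else ('mu') uses plain SGD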
    for name, param in net.named_parameters():
        if 'psi' in name: psis.append(param)
        else: mus.append(param)
    mu_optimizer = SGD(mus,
                       args.learning_rate,
                       args.momentum,
                       weight_decay=args.decay,
                       nesterov=(args.momentum > 0.0))

    psi_optimizer = PsiSGD(psis,
                           args.learning_rate,
                           args.momentum,
                           weight_decay=args.decay,
                           nesterov=(args.momentum > 0.0))

    recorder = RecorderMeter(args.epochs)
    if args.resume:
        if args.resume == 'auto':
            args.resume = os.path.join(args.save_path, 'checkpoint.pth.tar')
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume,
                                    map_location='cuda:{}'.format(args.gpu))
            recorder = checkpoint['recorder']
            recorder.refresh(args.epochs)
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(
                checkpoint['state_dict'] if args.distributed else {
                    k.replace('module.', ''): v
                    for k, v in checkpoint['state_dict'].items()
                })
            mu_optimizer.load_state_dict(checkpoint['mu_optimizer'])
            psi_optimizer.load_state_dict(checkpoint['psi_optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log(
                "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                    args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log("=> do not use any checkpoint for the model", log)

    cudnn.benchmark = True

    train_loader, ood_train_loader, test_loader, adv_loader, \
        fake_loader, adv_loader2 = load_dataset_ft(args)
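    # PsiSGD is given the training-set size, presumably to scale a KL/regularization term per example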
    psi_optimizer.num_data = len(train_loader.dataset)

    if args.evaluate:
        evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net,
                 criterion, args, log, 20, 100)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_los = -1

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
            ood_train_loader.sampler.set_epoch(epoch)
        cur_lr, cur_slr = adjust_learning_rate(mu_optimizer, psi_optimizer,
                                               epoch, args)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f} {:6.4f}]'.format(
                                    time_string(), epoch, args.epochs, need_time, cur_lr, cur_slr) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        train_acc, train_los = train(train_loader, ood_train_loader, net,
                                     criterion, mu_optimizer, psi_optimizer,
                                     epoch, args, log)
        val_acc, val_los = 0, 0
        recorder.update(epoch, train_los, train_acc, val_acc, val_los)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        if args.gpu == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'recorder': recorder,
                    'mu_optimizer': mu_optimizer.state_dict(),
                    'psi_optimizer': psi_optimizer.state_dict(),
                }, False, args.save_path, 'checkpoint.pth.tar')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'log.png'))

    evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net, criterion,
             args, log, 20, 100)

    log[0].close()
Code Example #4
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log("Compress Rate: {}".format(args.rate), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        assert False, 'ImageNet code is not finished'
    else:
        assert False, 'Unsupported dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    net_ref = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    net_ref = torch.nn.DataParallel(net_ref, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            net_ref = checkpoint['state_dict']
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    ###################################################################################################################
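    # re-initialize each conv layer with a Gaussian (He-style, std = sqrt(2/n), with n derived from the
    # reference net's non-zero weight count), then multiply by the reference mask so pruned positions stay zero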
    for m, m_ref in zip(net.modules(), net_ref.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m_ref.weight.data.abs().clone()
            mask = weight_copy.gt(0).float().cuda()
            n = mask.sum() / float(m.in_channels)
            m.weight.data.normal_(0, math.sqrt(2. / n))
            m.weight.data.mul_(mask)
    ###################################################################################################################

    if args.evaluate:
        time1 = time.time()
        validate(test_loader, net, criterion, log)
        time2 = time.time()
        print('function took %0.3f ms' % ((time2 - time1) * 1000.0))
        return

    m = Mask(net)

    m.init_length()

    comp_rate = args.rate
    print("-" * 10 + "one epoch begin" + "-" * 10)
    print("the compression rate now is %f" % comp_rate)

    val_acc_1, val_los_1 = validate(test_loader, net_ref, criterion, log)

    print(" accu before is: %.3f %%" % val_acc_1)

    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)
    print(" accu after is: %s %%" % val_acc_2)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        num_parameters = get_conv_zero_param(net)
        print_log('Zero parameters: {}'.format(num_parameters), log)
        num_parameters = sum([param.nelement() for param in net.parameters()])
        print_log('Parameters: {}'.format(num_parameters), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log)

        # evaluate on validation set
        val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_1,
                                  val_acc_1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net,
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    log.close()
Code Example #5
File: main.py  Project: syt2/mnist
def main(arch=None):

    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(os.path.join(args.save_path, '{}.txt'.format('log')), 'w')

    if args.tensorboard is None:
        writer = SummaryWriter(args.save_path)
    else:
        writer = SummaryWriter(args.tensorboard)

    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("use cuda: {}".format(args.use_cuda), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Init data loader
    train_loader = dataset.mnistDataLoader(args.train_dir, True,
                                           args.train_batch_size, True,
                                           args.workers)
    test_loader = dataset.mnistDataLoader(args.test_dir, False,
                                          args.test_batch_size, False,
                                          args.workers)
    num_classes = 10
    input_size = (1, 28, 28)
    net = arch(num_classes)
    print_log("=> network:\n {}".format(net), log)
    summary = model_summary(net, input_size)
    print_log(summary, log)

    writer.add_graph(net, torch.rand([1, 1, 28, 28]))
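    # add_graph traces the model once with a dummy 1x1x28x28 input so TensorBoard can render the graph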

    if args.ngpu > 1:
        net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)

    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log("=> not use any checkpoint for model", log)

    if args.evaluate:
        checkpoint = torch.load(args.save_path + '/model_best.pth.tar')
        net.load_state_dict(checkpoint['state_dict'])
        time1 = time.time()
        validate(test_loader, net, criterion, log, writer, embedding=True)
        time2 = time.time()
        print('validate function took %0.3f ms' % ((time2 - time1) * 1000.0))
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(args.learning_rate,
                                                     optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate) \
            + ' [Best : Accuracy={:.2f}]'.format(recorder.max_accuracy(False)), log)
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     log)
        val_acc, val_los = validate(test_loader, net, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                  val_acc)
        if is_best:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'recorder': recorder,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.save_path, 'checkpoint.pth.tar')
            print('save ckpt done!')

        writer.add_scalar('Train/loss', train_los, epoch)
        writer.add_scalar('Train/acc', train_acc, epoch)
        writer.add_scalar('Test/acc', val_acc, epoch)
        for name, param in net.named_parameters():
            writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)

        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    save_checkpoint(
        {
            'state_dict': net.state_dict(),
            'recorder': recorder,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save_path, 'model.pth.tar')
    print('save model done!')

    checkpoint = torch.load(args.save_path + '/model_best.pth.tar')
    net.load_state_dict(checkpoint['state_dict'])
    time1 = time.time()
    validate(test_loader, net, criterion, log, writer, embedding=True)
    time2 = time.time()
    print_log('validate function took %0.3f ms' % ((time2 - time1) * 1000.0),
              log)

    log.close()
    writer.close()
Code Example #6
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)

    # used for file names, etc
    time_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log = open(
        os.path.join(
            args.save_path,
            'log_seed_{0}_{1}.txt'.format(args.manualSeed, time_stamp)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    writer = SummaryWriter()

    #   # Data transforms
    # mean = [0.5071, 0.4867, 0.4408]
    # std = [0.2675, 0.2565, 0.2761]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    #[transforms.CenterCrop(32), transforms.ToTensor(),
    # transforms.Normalize(mean, std)])
    #)
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'imagenet':
        assert False, 'Did not finish imagenet code'
    else:
        assert False, 'Does not support dataset : {}'.format(args.dataset)

    #step_sizes = 2500
    step_sizes = args.alinit
    indices = [l for l in range(0, 50000)]

    annot_indices = [
    ]  # indices which are added to the training pool, list as we store it for all steps
    unannot_indices = [
        indices
    ]  # indices which have not been added to the training pool

    selections = random.sample(range(0, len(unannot_indices[-1])), step_sizes)
    temp = list(np.asarray(unannot_indices[-1])[selections])
    annot_indices.append(temp)

    unannot_indices.append(
        list(set(unannot_indices[-1]) - set(annot_indices[-1])))
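    # after the initial random draw, annot_indices[-1] is the current labelled pool and unannot_indices[-1] the remaining unlabelled pool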

    labelled_dset = torch.utils.data.Subset(train_data, annot_indices[-1])
    unlabelled_dset = torch.utils.data.Subset(train_data, unannot_indices[-1])

    #train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
    #                       num_workers=args.workers, pin_memory=True)
    labelled_loader = torch.utils.data.DataLoader(labelled_dset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=args.workers,
                                                  pin_memory=True)

    #unlabelled_loader = torch.utils.data.DataLoader(unlabelled_dset, batch_size=args.batch_size, shuffle=True,
    #num_workers=args.workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    #torch.save(net, 'net.pth')
    #init_net = torch.load('net.pth')
    #net.load_my_state_dict(init_net.state_dict())
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    #optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.005, nesterov=False)
    optimizer = torch.optim.Adadelta(
        net.parameters(),
        lr=0.1,
        rho=0.9,
        eps=1e-3,  # momentum=state['momentum'],
        weight_decay=0.001)

    print_log("=> Seed '{}'".format(args.manualSeed), log)
    print_log("=> dataset mean and std '{} - {}'".format(str(mean), str(std)),
              log)

    states_settings = {'optimizer': optimizer.state_dict()}

    print_log("=> optimizer '{}'".format(states_settings), log)
    # 50k,95k,153k,195k,220k
    milestones = [100, 190, 306, 390, 440, 540]
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> did not use any checkpoint for {} model".format(args.arch),
            log)

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    al_steps = int(50000 / args.alinit)

    curr_al_step = 0
    dump_data = []

    for (al_step, epoch) in [(a, b) for a in range(al_steps)
                             for b in range(args.start_epoch, args.epochs)]:
        print(" Current AL_step and epoch " + str((al_step, epoch)))
        if (al_step != curr_al_step):

            #These return scores of datapoints in unlabelled dataset according to their indices
            #indices of the data points(w.r.t to the original indexing from 1 to 50000) in the
            #unlabelled dataset
            curr_al_step = al_step
            #Resetting the learning rate scheduler
            scheduler = lr_scheduler.MultiStepLR(optimizer,
                                                 milestones,
                                                 gamma=0.1)

            scores_unlabelled = score(unlabelled_dset, net, criterion)
            indices_sorted = np.argsort(scores_unlabelled)

            #Greedy Sampling
            temp_selections = indices_sorted[-1 * args.alinit:]
            selections = np.asarray(list(
                unlabelled_dset.indices))[temp_selections].tolist()
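            # map positions within the unlabelled subset back to the original dataset indices before adding them to the labelled pool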

            annot_indices.append(selections)

            unannot_indices.append(
                list(set(unannot_indices[-1]) - set(annot_indices[-1])))

            labelled_dset = torch.utils.data.Subset(train_data,
                                                    annot_indices[-1])
            labelled_loader = torch.utils.data.DataLoader(
                labelled_dset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            unlabelled_dset = torch.utils.data.Subset(train_data,
                                                      unannot_indices[-1])

            indices_data = [annot_indices, unannot_indices]
            filehandler = open("indices.pickle", "wb")
            pickle.dump(indices_data, filehandler)
            filehandler.close()

        #current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)
        current_learning_rate = float(scheduler.get_lr()[-1])
        #print('lr:',current_learning_rate)

        scheduler.step()

        #adjust_learning_rate(optimizer, epoch)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:.6f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        #train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log)
        train_acc, train_los = train(labelled_loader, net, criterion,
                                     optimizer, epoch, log)

        # evaluate on validation set
        #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
        val_acc, val_los = validate(test_loader, net, criterion, log)
        is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                  val_acc)

        dump_data.append(([al_step, epoch], [train_acc,
                                             train_los], [val_acc, val_los]))
        if (epoch % 50 == 0):
            filehandler = open("accuracy.pickle", "wb")
            pickle.dump(dump_data, filehandler)
            filehandler.close()

        if epoch == 180:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': net.state_dict(),
                    'recorder': recorder,
                    'optimizer': optimizer.state_dict(),
                }, False, args.save_path,
                'checkpoint_{0}_{1}.pth.tar'.format(epoch,
                                                    time_stamp), time_stamp)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_path,
            'checkpoint_{0}.pth.tar'.format(time_stamp), time_stamp)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(
            os.path.join(
                args.save_path,
                'training_plot_{0}_{1}.png'.format(args.manualSeed,
                                                   time_stamp)))

    writer.close()
    log.close()
Code Example #7
File: main_dxy.py  Project: Henley13/imagenet-fast
def main():
  if not os.path.isdir(args.save_path): os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Init dataset
  if not os.path.isdir(args.data_path):
    os.makedirs(args.data_path)

  if args.dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif args.dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  else:
    assert False, "Unknow dataset : {}".format(args.dataset)

  train_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
     transforms.Normalize(mean, std)])
  test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)])

  if args.dataset == 'cifar10':
    train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'cifar100':
    train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 100
  elif args.dataset == 'svhn':
    train_data = dset.SVHN(args.data_path, split='train', transform=train_transform, download=True)
    test_data = dset.SVHN(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'stl10':
    train_data = dset.STL10(args.data_path, split='train', transform=train_transform, download=True)
    test_data = dset.STL10(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'imagenet':
    assert False, 'ImageNet code is not finished'
  else:
    assert False, 'Unsupported dataset : {}'.format(args.dataset)

  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                         num_workers=args.workers, pin_memory=True)
  test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, pin_memory=True)

  # Init model, criterion, and optimizer
  #net = models.__dict__[args.arch](num_classes).cuda()
  net = SENet34()

  # define loss function (criterion) and optimizer
  criterion = F.nll_loss
  optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=True)

  if args.use_cuda: net.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint
  if args.resume:
    if os.path.isfile(args.resume):
      print_log("=> loading checkpoint '{}'".format(args.resume), log)
      checkpoint = torch.load(args.resume)
      recorder = checkpoint['recorder']
      args.start_epoch = checkpoint['epoch']
      net.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
    else:
      print_log("=> no checkpoint found at '{}'".format(args.resume), log)
  else:
    print_log("=> do not use any checkpoint for model", log)

  if args.evaluate:
    validate(test_loader, net, criterion, log)
    return

  # Main loop
  start_time = time.time()
  epoch_time = AverageMeter()
  for epoch in range(args.start_epoch, args.epochs):
    current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

    need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
    need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

    # train for one epoch
    train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log)

    # evaluate on validation set
    val_acc,   val_los   = validate(test_loader, net, criterion, log)
    is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

    save_checkpoint({
      'epoch': epoch + 1,
      'state_dict': net.state_dict(),
      'recorder': recorder,
      'optimizer' : optimizer.state_dict(),
    }, is_best, args.save_path, 'checkpoint.pth.tar')

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    recorder.plot_curve( os.path.join(args.save_path, 'curve.png') )

  log.close()
Code Example #8
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
        num_classes = 100

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.workers, pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer
    #net = vgg(dataset=args.dataset, depth=19)
    net = models.__dict__[args.arch](num_classes)
    
    net_small = models.__dict__[args.arch_small](num_classes)

    print_log("=>small network :\n {}".format(net_small), log)

    check=torch.load('baseline_newres/cifar10_resnet110/best.resnet110.pth.tar')
    #check=torch.load('finetune/cifar10_resnet56_0.7_f/best.resnet56.pth.tar')
    net.load_state_dict(check['state_dict'])

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()
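    # criterion_s compares the small net's outputs against the large net's (distillation-style); args.loss picks L1/MSE/SmoothL1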
    if args.loss == 'L1':
        criterion_s = torch.nn.L1Loss()
    elif args.loss == 'MSE':
        criterion_s = torch.nn.MSELoss()
    elif args.loss == 'SmoothL1':
        criterion_s = torch.nn.SmoothL1Loss()
    #criterion_s = torch.nn.KLDivLoss()

    optimizer = torch.optim.SGD(net_small.parameters(), state['learning_rate'], momentum=state['momentum'],weight_decay=state['decay'], nesterov=True)
#    optimizer = SGD_HT(net.parameters(), state['learning_rate'], momentum=state['momentum'],weight_decay=state['decay'], nesterov=True, HTrate=state['HTrate'])
#    optimizer = HSG(net.parameters(), state['learning_rate'], momentum=state['momentum'],weight_decay=state['decay'], nesterov=True)
#    optimizer = AHSG(net.parameters(), state['learning_rate'], momentum=state['momentum'],weight_decay=state['decay'], nesterov=True, v=state['learning_rate'])
#    optimizer = HSG_HT(net.parameters(), state['learning_rate'], momentum=state['momentum'],weight_decay=state['decay'], nesterov=True, HTrate=state['HTrate'])    
#    optimizer = AHSG_HT(net_small.parameters(), state['learning_rate'], momentum=state['momentum'],weight_decay=state['decay'], nesterov=True, HTrate=state['HTrate'], v=state['v'])
    if args.use_cuda:
        net.cuda()
        net_small.cuda()
        criterion.cuda()
        criterion_s.cuda()
        
#    L1_norm_resnet(net_small.parameters(),args.HTrate)

#    i=0
#    for name_new, param_new in net.named_parameters():
#        print("i: {}, name_new: {}, param_new.size: {}".format(i, name_new, param_new.size()))
#        i=i+1

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            if args.use_state_dict:
                net_small.load_state_dict(checkpoint['state_dict'])
            else:
                net_small = checkpoint['state_dict']
                
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        time1 = time.time()
        validate(test_loader, net_small, criterion, log)
        time2 = time.time()
        print('function took %0.3f ms' % ((time2 - time1) * 1000.0))
        return

    validate_net(test_loader, net, criterion, log)

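    # Copy the shared classification head ('classifier2') from the pretrained full
    # network into net_small, so the small network starts from the teacher's final
    # classifier weights and bias.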
    for name_nor, param_nor in net.named_parameters():
        if name_nor == 'classifier2.weight':
            param_w_final = param_nor           
        elif name_nor == 'classifier2.bias':
            param_b_final = param_nor          

    for name_nor, param_nor in net_small.named_parameters():
        if name_nor == 'classifier2.weight':
            param_nor.data = param_w_final 
        elif name_nor == 'classifier2.bias':
            param_nor.data = param_b_final
    
    filename = os.path.join(args.save_path, 'checkpoint.{:}.pth.tar'.format(args.arch))
    bestname = os.path.join(args.save_path, 'best.{:}.pth.tar'.format(args.arch))
    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    
    best_prec1=0.
    for epoch in range(args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)
        #current_v = adjust_v(optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}] [v={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate, args.v) \
                                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, net_small, criterion, criterion_s, optimizer, epoch, log)

        # evaluate on validation set
        val_acc_1,   val_los_1   = validate(test_loader, net_small, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_1, val_acc_1)
        best_prec1 = max(val_acc_1,best_prec1)


        '''save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': net,
            'recorder': recorder,
            'optimizer' : optimizer.state_dict(),
        }, is_best, args.save_path, 'checkpoint.pth.tar')'''
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch_small,
            'state_dict': net_small.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best, filename, bestname)
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    #torch.save(net_small.state_dict(),os.path.join(args.save_path,"cifar10_resnet110_pretrained_parameters.pth"))


    log.close()
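The train call above takes both networks plus a second criterion, so the small model is fitted to the large model's outputs as well as to the labels. Below is a minimal sketch of such a step; the signature is the one used above, but the body, the unweighted sum of the two losses, and the returned (accuracy, loss) pair are assumptions rather than the project's actual implementation.

def train(train_loader, net, net_small, criterion, criterion_s, optimizer, epoch, log):
    # Sketch only: the teacher 'net' is kept fixed while the student 'net_small'
    # is trained against both the labels and the teacher's logits.
    net.eval()
    net_small.train()
    loss_sum, correct, total = 0.0, 0, 0
    for inputs, targets in train_loader:
        if args.use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        with torch.no_grad():
            teacher_logits = net(inputs)
        student_logits = net_small(inputs)
        # Assumed weighting: plain sum of the label loss and the matching loss.
        loss = criterion(student_logits, targets) + criterion_s(student_logits, teacher_logits)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        correct += student_logits.argmax(dim=1).eq(targets).sum().item()
        total += targets.size(0)
    return 100.0 * correct / total, loss_sum / max(1, len(train_loader))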
Code example #9
0
File: main.py Project: jozhang97/WaveApp
def main():
  # Init logger
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Data loading code
  # Any other preprocessings? http://pytorch.org/audio/transforms.html
  sample_length = 10000
  scale = transforms.Scale()
  padtrim = transforms.PadTrim(sample_length)
  downmix = transforms.DownmixMono()
  transforms_audio = transforms.Compose([
    scale, padtrim, downmix
  ])

  if not os.path.isdir(args.data_path):
    os.makedirs(args.data_path)
  train_dir = os.path.join(args.data_path, 'train')
  val_dir = os.path.join(args.data_path, 'val')

  #Choose dataset to use
  if args.dataset == 'arctic':
    # TODO No ImageFolder equivalent for audio. Need to create a Dataset manually
    train_dataset = Arctic(train_dir, transform=transforms_audio, download=True)
    val_dataset = Arctic(val_dir, transform=transforms_audio, download=True)
    num_classes = 4
  elif args.dataset == 'vctk':
    train_dataset = dset.VCTK(train_dir, transform=transforms_audio, download=True)
    val_dataset = dset.VCTK(val_dir, transform=transforms_audio, download=True)
    num_classes = 10
  elif args.dataset == 'yesno':
    train_dataset = dset.YESNO(train_dir, transform=transforms_audio, download=True)
    val_dataset = dset.YESNO(val_dir, transform=transforms_audio, download=True)
    num_classes = 2
  else:
    assert False, 'Dataset is incorrect'

  train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    # pin_memory=True, # What is this?
    # sampler=None     # What is this?
  )
  val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)


  #Feed in respective model file to pass into model (alexnet.py)
  print_log("=> creating model '{}'".format(args.arch), log)
  # Init model, criterion, and optimizer
  # net = models.__dict__[args.arch](num_classes)
  net = AlexNet(num_classes)
  #
  print_log("=> network :\n {}".format(net), log)

  # net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

  # define loss function (criterion) and optimizer
  criterion = torch.nn.CrossEntropyLoss()

  # Define stochastic gradient descent as optimizer (run backprop on random small batch)
  optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=True)

  #Sets use for GPU if available
  if args.use_cuda:
    net.cuda()
    criterion.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint
  # Need the same Python version that the checkpoint was saved with
  if args.resume:
    if os.path.isfile(args.resume):
      print_log("=> loading checkpoint '{}'".format(args.resume), log)
      if args.ngpu == 0:
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
      else:
        checkpoint = torch.load(args.resume)

      recorder = checkpoint['recorder']
      args.start_epoch = checkpoint['epoch']
      net.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
    else:
      print_log("=> no checkpoint found at '{}'".format(args.resume), log)
  else:
    print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

  if args.evaluate:
    validate(val_loader, net, criterion, 0, log, val_dataset)
    return

  # Main loop
  start_time = time.time()
  epoch_time = AverageMeter()

  # Training occurs here
  for epoch in range(args.start_epoch, args.epochs):
    current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

    need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
    need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

    print("One epoch")
    # train for one epoch
    # Call to train (note that our previous net is passed into the model argument)
    train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, train_dataset)

    # evaluate on validation set
    #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
    val_acc,   val_los   = validate(val_loader, net, criterion, epoch, log, val_dataset)
    is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

    save_checkpoint({
      'epoch': epoch + 1,
      'arch': args.arch,
      'state_dict': net.state_dict(),
      'recorder': recorder,
      'optimizer' : optimizer.state_dict(),
    }, is_best, args.save_path, 'checkpoint.pth.tar')

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    recorder.plot_curve( os.path.join(args.save_path, 'curve.png') )

  log.close()
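Each of these training loops adjusts the learning rate through adjust_learning_rate(optimizer, epoch, gammas, schedule). The helper itself is not part of this listing; the following is a minimal step-decay sketch with that signature, assuming the base rate is supplied (here as a default argument) and that each gamma is applied once its schedule epoch has been reached.

def adjust_learning_rate(optimizer, epoch, gammas, schedule, base_lr=0.1):
    # Sketch: multiply the base rate by gamma for every milestone already passed.
    assert len(gammas) == len(schedule), "length of gammas and schedule should be equal"
    lr = base_lr
    for (gamma, step) in zip(gammas, schedule):
        if epoch >= step:
            lr = lr * gamma
        else:
            break
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr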
Code example #10
0
    start_time = time.time()
    epoch_time = AverageMeter()
    print("start epoch is", args.start_epoch)
    print("epoch is", args.epochs)
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
           + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log)

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, criterion, log)
        is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                  val_acc, args.epochs)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
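The progress banner printed at the top of every epoch relies on two small helpers, convert_secs2time and time_string, which are called in these snippets but never defined in them. A plausible sketch of both, with the bodies assumed:

import time

def convert_secs2time(epoch_time):
    # Split a duration in seconds into (hours, minutes, seconds).
    need_hour = int(epoch_time / 3600)
    need_mins = int((epoch_time - 3600 * need_hour) / 60)
    need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins)
    return need_hour, need_mins, need_secs

def time_string():
    # Timestamp used at the start of each log line.
    return time.strftime('[%Y-%m-%d %H:%M:%S]', time.localtime(time.time()))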
Code example #11
0
def main(args):
  # Init logger
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Init dataset
  if not os.path.isdir(args.data_path):
    os.makedirs(args.data_path)

  if args.dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif args.dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  elif args.dataset == 'imagenet32x32':
    mean = [x / 255 for x in [122.7, 116.7, 104.0]] 
    std = [x / 255 for x in [66.4, 64.6, 68.4]]
  elif args.dataset == 'svhn':
    pass
  else:
    assert False, "Unknown dataset : {}".format(args.dataset)

  if args.dataset == 'cifar10' or args.dataset == 'cifar100' or args.dataset == 'imagenet32x32':
    train_transform = transforms.Compose(
      [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
      transforms.Normalize(mean, std)])
    test_transform = transforms.Compose(
      [transforms.ToTensor(), transforms.Normalize(mean, std)])

  if args.dataset == 'cifar10':
    train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'cifar100':
    train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 100
  elif args.dataset == 'svhn':
    def target_transform(target):
      return int(target[0])-1
    train_data = dset.SVHN(args.data_path, split='train', transform=transforms.Compose(
        [transforms.ToTensor(),]), download=True, target_transform=target_transform)
    extra_data = dset.SVHN(args.data_path, split='extra', transform=transforms.Compose(
        [transforms.ToTensor(),]), download=True, target_transform=target_transform)
    train_data.data = np.concatenate([train_data.data, extra_data.data])
    train_data.labels = np.concatenate([train_data.labels, extra_data.labels])
    print(train_data.data.shape, train_data.labels.shape)
    test_data = dset.SVHN(args.data_path, split='test', transform=transforms.Compose([transforms.ToTensor(),]), download=True, target_transform=target_transform)
    num_classes = 10
  elif args.dataset == 'imagenet32x32':
    train_data = IMAGENET32X32(args.data_path, train=True, transform=train_transform, download=True)
    test_data = IMAGENET32X32(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 1000
  else:
    assert False, 'Unsupported dataset : {}'.format(args.dataset)

  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                         num_workers=args.workers, pin_memory=True)
  test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, pin_memory=True)
  M_loader = torch.utils.data.DataLoader(train_data, batch_size=8, shuffle=True,
                         num_workers=args.workers, pin_memory=True)

  print_log("=> creating model '{}'".format(args.arch), log)
  # Init model, criterion, and optimizer
  net = models.__dict__[args.arch](num_classes=num_classes)

  #print_log("=> network:\n {}".format(net), log)

  net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

  # define loss function (criterion) and optimizer
  criterion = torch.nn.CrossEntropyLoss()

  """
  params_skip = []
  params_noskip = []
  skip_lists = ['bn', 'bias']
  for name, param in net.named_parameters():
    if any(name in skip_name for skip_name in skip_lists):
      params_skip.append(param)
    else:
      params_noskip.append(param)
  param_lrs = [{'params':params_skip, 'lr':state['learning_rate']},
		{'params':params_noskip, 'lr':state['learning_rate']}]
  param_lrs = []
  params = []
  names = []
  layers = [3,] + [54,]*3 + [2,]
  for i, (name, param) in enumerate(net.named_parameters()):
    params.append(param)
    names.append(name)
    if len(params) == layers[0]:
      param_dict = {'params': params, 'lr':state['learning_rate']}
      param_lrs.append(param_dict)
      params = []
      names = []
      layers.pop(0)
      
  """ 
  skip_lists = ['bn', 'bias']
  skip_idx = []
  for idx, (name, param) in enumerate(net.named_parameters()):
    if any(skip_name in name for skip_name in skip_lists):
      skip_idx.append(idx)

  param_lrs = net.parameters()
  
  if args.lars:
    optimizer = LARSOptimizer(param_lrs, state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=False, steps=state['steps'], eta=state['eta'], skip_idx=skip_idx)
  else:
    optimizer = optim.SGD(param_lrs, state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=False)

  if args.use_cuda:
    net.cuda()
    criterion.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint

  avg_norm = []
  if args.lw: 
    for param in net.parameters():
      avg_norm.append(0)

  # Main loop
  print_log('Epoch  Train_Prec@1  Train_Prec@5  Train_Loss  Test_Prec@1  Test_Prec@5  Test_Loss  Best_Prec@1  Time', log)
  for epoch in range(args.start_epoch, args.epochs):

    # train for one epoch
    start_time = time.time()
    train_top1, train_top5, train_loss = train(train_loader, M_loader, net, criterion, optimizer, epoch, log, args, avg_norm)
    training_time = time.time() - start_time

    # evaluate on validation set
    val_top1, val_top5, val_loss = validate(test_loader, net, criterion, log, args)
    recorder.update(epoch, train_loss, train_top1, val_loss, val_top1)

    print('{epoch:d}        {train_top1:.3f}      {train_top5:.3f}     {train_loss:.3f}      {test_top1:.3f}      {test_top5:.3f}    {test_loss:.3f}    {best_top1:.3f}      {time:.3f} '.format(epoch=epoch, time=training_time, train_top1=train_top1, train_top5=train_top5, train_loss=train_loss, test_top1=val_top1, test_top5=val_top5, test_loss=val_loss, best_top1=recorder.max_accuracy(False)))


  log.close()
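When args.lars is set, this example builds a LARSOptimizer over the parameters, passing skip_idx for BatchNorm and bias parameters so they are excluded from layer-wise scaling. That optimizer is not reproduced in this listing; the class below is only a simplified sketch of the LARS idea (trust ratio eta * ||w|| / (||g|| + wd * ||w||) on top of SGD with momentum), with the steps argument dropped and skip_idx reduced to a per-group skip flag.

import torch

class SimpleLARS(torch.optim.Optimizer):
    """Simplified layer-wise adaptive rate scaling (illustration only)."""

    def __init__(self, params, lr, momentum=0.9, weight_decay=0.0, eta=0.001):
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay,
                        eta=eta, skip=False)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                w_norm = torch.norm(p)
                g_norm = torch.norm(grad)
                if group['skip'] or w_norm == 0 or g_norm == 0:
                    local_lr = 1.0  # BN / bias parameters keep the plain rate
                else:
                    local_lr = group['eta'] * w_norm / (g_norm + group['weight_decay'] * w_norm)
                update = grad.add(p, alpha=group['weight_decay']).mul(group['lr'] * local_lr)
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(p)
                buf = state['momentum_buffer']
                buf.mul_(group['momentum']).add_(update)
                p.add_(buf, alpha=-1.0)

With this sketch, the skip_idx collected above would correspond to putting the BN and bias parameters into a group created with skip=True.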
Code example #12
0
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("CS Pruning Rate: {}".format(args.prune_rate_cs), log)
    print_log("GM Pruning Rate: {}".format(args.prune_rate_gm), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("use pretrain: {}".format(args.use_pretrain), log)
    print_log("Pretrain path: {}".format(args.pretrain_path), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    print_log("=> creating model '{}'".format(args.arch), log)

    num_classes = 1000 # number of imagenet classes

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    if args.use_pretrain:
        pretrain = torch.load(args.pretrain_path)
        if args.use_state_dict:
            net.load_state_dict(pretrain['state_dict'])
        else:
            net = pretrain['state_dict']

    recorder = RecorderMeter(args.epochs)

    traindir = os.path.join(args.data_path, 'train')
    valdir = os.path.join(args.data_path, 'val')

    # create loaders for testing and pruning
    img_size = 32
    batch_range = list(range(1, 11))
    random.shuffle(batch_range)
    imnetdata = imgnet_utils.load_databatch(traindir, batch_range[0], img_size=img_size)
    X_train_prune = imnetdata['X_train']
    Y_train_prune = imnetdata['Y_train']
    mean_image = imnetdata['mean']
    num_batches_pruner = 0
    for ib in batch_range[1:1 + num_batches_pruner]:
        print('train batch for prune loader: {}'.format(ib))
        imnetdata = imgnet_utils.load_databatch(traindir, ib, img_size=img_size)
        X_train_prune = np.concatenate((X_train_prune, imnetdata['X_train']), axis=0)
        Y_train_prune = np.concatenate((Y_train_prune, imnetdata['Y_train']), axis=0)

    del imnetdata

    train_data_prune_loader = torch.utils.data.TensorDataset(
        torch.cat([torch.FloatTensor(X_train_prune), torch.FloatTensor(X_train_prune[:,:,:,::-1].copy())], dim=0),
        torch.cat([torch.LongTensor(Y_train_prune), torch.LongTensor(Y_train_prune)], dim=0))
    train_prune_loader = torch.utils.data.DataLoader(train_data_prune_loader, batch_size=args.batch_prune_size,
                                                     shuffle=True, num_workers=args.workers, pin_memory=True)

    del X_train_prune, Y_train_prune

    # create test loader
    imnetdata = imgnet_utils.load_validation_data(valdir, mean_image=mean_image, img_size=img_size)
    X_test = imnetdata['X_test']
    Y_test = imnetdata['Y_test']

    del imnetdata

    test_data = torch.utils.data.TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(Y_test))
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.workers, pin_memory=True)

    del X_test, Y_test

    mdlIdx2ConvIdx = [] # module index to conv filter index
    for index1, layr in enumerate(net.modules()):
        if isinstance(layr, torch.nn.Conv2d):
            mdlIdx2ConvIdx.append(index1)

    prmIdx2ConvIdx = [] # parameter index to conv filter index
    for index2, item in enumerate(net.parameters()):
        if len(item.size()) == 4:
            prmIdx2ConvIdx.append(index2)

    # set index of last layer depending on the known architecture
    if args.arch == 'resnet20':
        args.layer_end = 54
    elif args.arch == 'resnet56':
        args.layer_end = 162
    elif args.arch == 'resnet110':
        args.layer_end = 324
    else:
        pass  # unknown architecture, use the input value

    # asymptotic schedule
    total_pruning_rate = args.prune_rate_gm + args.prune_rate_cs
    compress_rates_total, scalling_factors, compress_rates_cs, compress_rates_fpgm, e2 =\
        cmpAsymptoticSchedule(theta3=total_pruning_rate, e3=args.epochs-1, tau=args.tau, theta_cs_final = args.prune_rate_cs, scaling_attn = args.scaling_attenuation) # tau=8.
    keep_rate_cs = 1. - compress_rates_cs

    if args.use_zero_scaling:
        scalling_factors = np.zeros(scalling_factors.shape)

    m = Mask(net, train_prune_loader, mdlIdx2ConvIdx, prmIdx2ConvIdx, scalling_factors, keep_rate_cs, compress_rates_fpgm, args.max_iter_cs)
    m.set_curr_epoch(0)
    m.set_epoch_cs(args.epoch_apply_cs)
    m.init_selected_filts()
    m.init_length()

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

    m.model = net
    m.init_mask(keep_rate_cs[0], compress_rates_fpgm[0], scalling_factors[0])
    #    m.if_zero()
    m.do_mask()
    m.do_similar_mask()
    net = m.model
    #    m.if_zero()
    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):

        current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc_i = torch.zeros([11, ], dtype=torch.float32)
        train_los_i = torch.zeros([11, ], dtype=torch.float32)
        idx = 0
        for ib in batch_range:
            print('train batch: {}'.format(ib))

            imnetdata = imgnet_utils.load_databatch(traindir, ib, img_size=img_size)
            X_train = imnetdata['X_train']
            Y_train = imnetdata['Y_train']

            del imnetdata

            train_data = torch.utils.data.TensorDataset(
                torch.cat([torch.FloatTensor(X_train), torch.FloatTensor(X_train[:, :, :, ::-1].copy())], dim=0),
                torch.cat([torch.LongTensor(Y_train), torch.LongTensor(Y_train)], dim=0))
            train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                                       num_workers=args.workers, pin_memory=True)

            del X_train, Y_train

            train_acc_i[idx], train_los_i[idx] = train(train_loader, net, criterion, optimizer, epoch, log, m)

            idx += 1
        train_acc = torch.mean(train_acc_i)
        train_los = torch.mean(train_los_i)

        # evaluate on validation set
        val_acc_1,   val_los_1   = validate(test_loader, net, criterion, log)
        print('Before: val_acc_1: {}, val_los_1: {}'.format(val_acc_1, val_los_1))

        #train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, m)
        
        # evaluate on validation set
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = net
            m.set_curr_epoch(epoch)
            # m.if_zero()
            m.init_mask(keep_rate_cs[epoch], compress_rates_fpgm[epoch], scalling_factors[epoch])
            m.do_mask()
            m.do_similar_mask()
            # m.if_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()
            if epoch == args.epochs - 1:
                m.if_zero()

        val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)
        print('After: val_acc_2: {}, val_los_2: {}'.format(val_acc_2, val_los_2))

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2, val_acc_2)

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': net,
            'recorder': recorder,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    log.close()
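The Mask object carries the pruning state in this example: init_mask selects filters for the current epoch from the scaling factors and compression rates, do_mask and do_similar_mask zero them out, and the masked model is written back into net. Its implementation is not part of this listing, so the helper below is only a rough illustration of the underlying mechanism, zeroing the smallest-L2-norm fraction of filters in every Conv2d layer; the function name and the keep_rate argument are assumptions, not the project's Mask API.

import torch

def zero_small_filters(net, keep_rate):
    # Illustration: keep the top `keep_rate` fraction of filters (by L2 norm) in each
    # Conv2d layer and zero the remaining ones, which is what a pruning mask amounts to.
    with torch.no_grad():
        for module in net.modules():
            if isinstance(module, torch.nn.Conv2d):
                weight = module.weight                                # (out_ch, in_ch, kh, kw)
                norms = weight.view(weight.size(0), -1).norm(p=2, dim=1)
                n_keep = max(1, int(keep_rate * weight.size(0)))
                mask = torch.zeros_like(norms)
                mask[torch.topk(norms, n_keep).indices] = 1.0
                weight.mul_(mask.view(-1, 1, 1, 1))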
Code example #13
0
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("Compress Rate: {}".format(args.rate), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknown dataset : {}".format(args.dataset)

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        assert False, 'ImageNet code is not finished'
    else:
        assert False, 'Unsupported dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    #net = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)
    print_log("=> network :\n {}".format(net), log)

    #    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    #    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],weight_decay=state['decay'], nesterov=True)
    #    optimizer = SGD_HT(net.parameters(), state['learning_rate'], momentum=state['momentum'],weight_decay=state['decay'], nesterov=True, HTrate=state['HTrate'])
    #    optimizer = HSG(net.parameters(), state['learning_rate'], momentum=state['momentum'],weight_decay=state['decay'], nesterov=True)
    #    optimizer = AHSG(net.parameters(), state['learning_rate'], momentum=state['momentum'],weight_decay=state['decay'], nesterov=True, v=state['learning_rate'])
    #    optimizer = HSG_HT(net.parameters(), state['learning_rate'], momentum=state['momentum'],weight_decay=state['decay'], nesterov=True, HTrate=state['HTrate'])
    optimizer = AHSG_HT(net.parameters(),
                        state['learning_rate'],
                        momentum=state['momentum'],
                        weight_decay=state['decay'],
                        nesterov=True,
                        HTrate=state['HTrate'],
                        v=state['v'])
    if args.use_cuda:
        net.cuda()
        criterion.cuda()


#    L1_norm_resnet(net.parameters(),args.HTrate)

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            #            recorder = checkpoint['recorder']
            #            args.start_epoch = checkpoint['epoch']
            #            if args.use_state_dict:
            net.load_state_dict(checkpoint['state_dict'])
            #            else:
            #                net = checkpoint['state_dict']

            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        time1 = time.time()
        validate(test_loader, net, criterion, log)
        time2 = time.time()
        print('function took %0.3f ms' % ((time2 - time1) * 1000.0))
        return

    filename = os.path.join(args.save_path,
                            'checkpoint.{:}.pth.tar'.format(args.arch))
    bestname = os.path.join(args.save_path,
                            'best.{:}.pth.tar'.format(args.arch))
    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    best_prec1 = 0.
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)
        #current_v = adjust_v(optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log)

        # evaluate on validation set
        val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_1,
                                  val_acc_1)
        #print('val_acc_1',val_acc_1)
        best_prec1 = max(val_acc_1, best_prec1)
        #print('best_prec1',best_prec1)
        '''save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': net,
            'recorder': recorder,
            'optimizer' : optimizer.state_dict(),
        }, is_best, args.save_path, 'checkpoint.pth.tar')'''
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, filename, bestname)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        #recorder.plot_curve( os.path.join(args.save_path, 'curve.png') )

    log.close()
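Instead of plain SGD, this example uses AHSG_HT with a hard-thresholding rate (HTrate); the commented lines list the related SGD_HT/HSG/AHSG variants. None of those optimizers are defined in this listing, so the snippet below only illustrates the hard-thresholding operation itself: after an ordinary optimizer step, keep the largest-magnitude fraction of each weight tensor and zero the rest. The function name, the per-tensor application, and the 1 - HTrate keep rate are assumptions, not the project's AHSG_HT.

import torch

def hard_threshold_(param, keep_rate):
    # Keep the `keep_rate` fraction of entries with the largest magnitude, zero the rest.
    with torch.no_grad():
        flat = param.view(-1)
        n_keep = max(1, int(keep_rate * flat.numel()))
        threshold = flat.abs().topk(n_keep).values.min()
        param.mul_((param.abs() >= threshold).to(param.dtype))

# Usage sketch after each optimizer step (assuming HTrate is the fraction to remove):
#   optimizer.step()
#   for p in net.parameters():
#       if p.dim() == 4:                      # convolution filters only
#           hard_threshold_(p, keep_rate=1.0 - args.HTrate)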
Code example #14
0
def main():
    global best_acc

    if not os.path.isdir(args.save_path): os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("PyTorch  version : {}".format(torch.__version__), log)
    print_log("CuDNN  version : {}".format(torch.backends.cudnn.version()),
              log)

    if not os.path.isdir(args.data_path): os.makedirs(args.data_path)

    num_classes, train_loader, test_loader = load_dataset()
    net = load_model(num_classes, log)

    criterion = torch.nn.CrossEntropyLoss().cuda()

    params = group_weight_decay(net, state['decay'], ['coefficients'])
    optimizer = torch.optim.SGD(params,
                                state['learning_rate'],
                                momentum=state['momentum'],
                                nesterov=(state['momentum'] > 0.0))

    recorder = RecorderMeter(args.epochs)
    if args.resume:
        if args.resume == 'auto':
            args.resume = os.path.join(args.save_path, 'checkpoint.pth.tar')
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            recorder.refresh(args.epochs)
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log(
                "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                    args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_los = -1

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule, train_los)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log)

        val_acc, val_los = validate(test_loader, net, criterion, log)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_path, 'checkpoint.pth.tar')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(result_png_path)
    log.close()
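group_weight_decay(net, state['decay'], ['coefficients']) prepares the SGD parameter groups so that parameters whose names contain one of the listed substrings are trained without weight decay. A sketch with that signature, assuming the split is done by substring matching over named_parameters:

def group_weight_decay(net, weight_decay, skip_list=()):
    # Put parameters matching any substring in skip_list into a no-decay group.
    decay, no_decay = [], []
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        if any(skip_name in name for skip_name in skip_list):
            no_decay.append(param)
        else:
            decay.append(param)
    return [{'params': no_decay, 'weight_decay': 0.0},
            {'params': decay, 'weight_decay': weight_decay}]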
Code example #15
0
File: main.py Project: Coderx7/TF_Pytorch_testbed
def main():
  # Init logger
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)

  # used for file names, etc 
  time_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')   
  log = open(os.path.join(args.save_path, '{0}_{1}_{2}.txt'.format(args.arch, args.dataset, time_stamp)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Init dataset
  if not os.path.isdir(args.data_path):
    os.makedirs(args.data_path)


  if args.dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif args.dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  else:
    assert False, "Unknown dataset : {}".format(args.dataset)



  writer = SummaryWriter()


  train_transform = transforms.Compose(
    [ transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(), transforms.ToTensor(),
     transforms.Normalize(mean, std)])

  test_transform = transforms.Compose(
    [transforms.CenterCrop(32), transforms.ToTensor(), transforms.Normalize(mean, std)])


  if args.dataset == 'cifar10':
    train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'cifar100':
    train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 100
  elif args.dataset == 'svhn':
    train_data = dset.SVHN(args.data_path, split='train', transform=train_transform, download=True)
    train_data_ext = dset.SVHN(args.data_path, split='extra', transform=train_transform, download=True)
    test_data = dset.SVHN(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'stl10':
    train_data = dset.STL10(args.data_path, split='train', transform=train_transform, download=True)
    test_data = dset.STL10(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'imagenet':
    assert False, 'ImageNet code is not finished'
  else:
    assert False, 'Unsupported dataset : {}'.format(args.dataset)

  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                         num_workers=args.workers, pin_memory=True)
  test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, pin_memory=True)

  print_log("=> creating model '{}'".format(args.arch), log)
  # Init model, criterion, and optimizer
  net = models.__dict__[args.arch](num_classes)
  print_log("=> network :\n {}".format(net), log)


  net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

  # define loss function (criterion) and optimizer
  criterion = torch.nn.CrossEntropyLoss()

  #optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.005, nesterov=False)
  optimizer = torch.optim.Adadelta(net.parameters(), lr=0.1, rho=0.9, eps=1e-3, weight_decay=0.001)


  print_log("=> Seed '{}'".format(args.manualSeed), log)
  print_log("=> dataset mean and std '{} - {}'".format(str(mean), str(std)), log)
  
  states_settings = {
                     'optimizer': optimizer.state_dict()
                    }


  print_log("=> optimizer '{}'".format(states_settings), log)
  # milestones in iterations: 50k, 95k, 153k, 195k, 220k (see the epoch conversion after this example)
  milestones = [100, 190, 306, 390, 440, 500]
  scheduler = lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1)

  if args.use_cuda:
    net.cuda()
    criterion.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint
  if args.resume:
    if os.path.isfile(args.resume):
      print_log("=> loading checkpoint '{}'".format(args.resume), log)
      checkpoint = torch.load(args.resume)
      recorder = checkpoint['recorder']
      args.start_epoch = checkpoint['epoch']
      net.load_state_dict(checkpoint['state_dict'])
      scheduler.load_state_dict(checkpoint['scheduler'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
    else:
      print_log("=> no checkpoint found at '{}'".format(args.resume), log)
  else:
    print_log("=> did not use any checkpoint for {} model".format(args.arch), log)

  if args.evaluate:
    validate(test_loader, net, criterion, log)
    return

  # Main loop
  start_time = time.time()
  epoch_time = AverageMeter()

  for epoch in range(args.start_epoch, args.epochs):
    #current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)
    current_learning_rate = float(scheduler.get_lr()[-1])

    scheduler.step()

    #adjust_learning_rate(optimizer, epoch)

    need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
    need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:.6f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

    # train for one epoch
    train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log)

    # evaluate on validation set
    #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
    val_acc,   val_los   = validate(test_loader, net, criterion, log)
    is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)


    writer.add_scalar('training/loss', train_los, epoch)
    writer.add_scalar('training/acc', train_acc, epoch)

    writer.add_scalar('validation/loss', val_los, epoch)
    writer.add_scalar('validation/acc', val_acc, epoch)

   

    if epoch == 180:
        save_checkpoint({
          'epoch': epoch ,
          'arch': args.arch,
          'state_dict': net.state_dict(),
          'recorder': recorder,
          'optimizer' : optimizer.state_dict(),
          'scheduler' : scheduler.state_dict(),
        }, False, args.save_path, 'chkpt_{0}_{1}_{2}_{3}.pth.tar'.format(args.arch, args.dataset, epoch, time_stamp))

    save_checkpoint({
      'epoch': epoch + 1,
      'arch': args.arch,
      'state_dict': net.state_dict(),
      'recorder': recorder,
      'optimizer' : optimizer.state_dict(),
      'scheduler' : scheduler.state_dict(),
    }, is_best, args.save_path, 'chkpt_{0}_{1}_{2}.pth.tar'.format(args.arch, args.dataset, time_stamp))


    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    recorder.plot_curve( os.path.join(args.save_path, 'plot_{0}_{1}_{2}.png'.format(args.arch, args.dataset, time_stamp)) )

  writer.close()
  log.close()
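The milestone comment in the example above is given in iterations ('50k, 95k, 153k, 195k, 220k'), while MultiStepLR counts epochs, so the list [100, 190, 306, 390, 440, ...] is those iteration counts divided by the iterations per epoch. Assuming the 50,000 CIFAR training images are consumed in batches of 100 (500 iterations per epoch, an assumption inferred from the numbers), the conversion is:

iters_per_epoch = 50000 // 100                      # CIFAR train size / assumed batch size
milestones_iters = [50000, 95000, 153000, 195000, 220000]
milestones_epochs = [it // iters_per_epoch for it in milestones_iters]
print(milestones_epochs)                            # [100, 190, 306, 390, 440]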
Code example #16
0
def main_worker(gpu, ngpus_per_node, args):
    global best_acc
    args.gpu = gpu
    assert args.gpu is not None
    print("Use GPU: {} for training".format(args.gpu))

    log = open(os.path.join(args.save_path, 'log{}{}.txt'.format('_seed'+
                   str(args.manualSeed), '_eval' if args.evaluate else '')), 'w')
    log = (log, args.gpu)

    net = models.__dict__[args.arch](args, args.depth, args.wide, args.num_classes)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("PyTorch  version : {}".format(torch.__version__), log)
    print_log("CuDNN  version : {}".format(torch.backends.cudnn.version()), log)
    print_log("Number of parameters: {}".format(sum([p.numel() for p in net.parameters()])), log)
    print_log(str(args), log)

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url+":"+args.dist_port,
                                world_size=args.world_size, rank=args.rank)
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
    else:
        torch.cuda.set_device(args.gpu)
        net = net.cuda(args.gpu)

    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

    mus, vars = [], []
    for name, param in net.named_parameters():
        if 'log_sigma' in name: vars.append(param)
        else: assert(param.requires_grad); mus.append(param)
    optimizer = torch.optim.SGD([{'params': mus, 'weight_decay': args.decay}],
                args.batch_size * args.world_size / 128 * args.learning_rate,
                momentum=args.momentum, nesterov=(args.momentum > 0.0))
    if args.bayes:
        assert(len(mus) == len(vars))
        var_optimizer = VarSGD([{'params': vars, 'weight_decay': args.decay}],
                    args.batch_size * args.world_size / 128 * args.log_sigma_lr,
                    momentum=args.momentum, nesterov=(args.momentum > 0.0),
                    num_data = args.num_data)
    else:
        assert(len(vars) == 0)
        var_optimizer = NoneOptimizer()


    recorder = RecorderMeter(args.epochs)
    if args.resume:
        if args.resume == 'auto': args.resume = os.path.join(args.save_path, 'checkpoint.pth.tar')
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume, map_location='cuda:{}'.format(args.gpu))
            recorder = checkpoint['recorder']
            recorder.refresh(args.epochs)
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'] if args.distributed else {k.replace('module.', ''): v for k,v in checkpoint['state_dict'].items()})
            optimizer.load_state_dict(checkpoint['optimizer'])
            var_optimizer.load_state_dict(checkpoint['var_optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log("=> loaded checkpoint '{}' accuracy={} (epoch {})" .format(args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

    cudnn.benchmark = True

    train_loader, test_loader = load_dataset(args)

    if args.evaluate:
        validate(test_loader, net, criterion, args, log)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_los = -1

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
        cur_lr, cur_slr = adjust_learning_rate(optimizer, var_optimizer, epoch, args)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f} {:6.4f}]'.format(
                                    time_string(), epoch, args.epochs, need_time, cur_lr, cur_slr) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        train_acc, train_los = train(train_loader, net, criterion, optimizer, var_optimizer, epoch, args, log)
        val_acc, val_los   = validate(test_loader, net, criterion, args, log)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        if args.gpu == 0:
            save_checkpoint({
              'epoch': epoch + 1,
              'arch': args.arch,
              'state_dict': net.state_dict(),
              'recorder': recorder,
              'optimizer' : optimizer.state_dict(),
              'var_optimizer' : var_optimizer.state_dict(),
            }, is_best, args.save_path, 'checkpoint.pth.tar')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'log.png'))

    log[0].close()
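main_worker(gpu, ngpus_per_node, args) has the usual per-process signature for multi-GPU training, but the entry point that spawns it is not included in this listing. A typical launcher, given here only as an assumed sketch built on torch.multiprocessing.spawn:

import torch
import torch.multiprocessing as mp

def launch(args):
    # Assumed launcher: one process per visible GPU in the multiprocessing-distributed
    # case, otherwise a single direct call on the requested GPU.
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)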
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log("Norm Pruning Rate: {}".format(args.rate_norm), log)
    print_log("Distance Pruning Rate: {}".format(args.rate_dist), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("use pretrain: {}".format(args.use_pretrain), log)
    print_log("Pretrain path: {}".format(args.pretrain_path), log)
    print_log("Dist type: {}".format(args.dist_type), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknown dataset : {}".format(args.dataset)

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    # Init model, criterion, and optimizer
    net = mobilenetv2.MobileNetV2()
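    # The checkpoint 'ckpt.pth' is assumed to hold weights saved from a DataParallel-wrapped
    # model, so the 'module.' prefix is stripped from every key before loading.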
    model_loaded = torch.load('ckpt.pth')
    mobile = model_loaded['net'].copy()
    print(type(mobile))
    net.load_state_dict(
        {k.replace('module.', ''): v
         for k, v in mobile.items()})

    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(),
                                state['lr'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)

    if args.evaluate:
        time1 = time.time()
        validate(test_loader, net, criterion, log)
        time2 = time.time()
        print('function took %0.3f ms' % ((time2 - time1) * 1000.0))
        return

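    # Build the pruning helper: filters are masked both by norm (rate_norm) and by
    # geometric-median distance (rate_dist), as the prints below spell out.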
    m = Mask(net)
    m.init_length()
    print("-" * 10 + "one epoch begin" + "-" * 10)
    print("remaining ratio of pruning : Norm is %f" % args.rate_norm)
    print("reducing ratio of pruning : Distance is %f" % args.rate_dist)
    print("total remaining ratio is %f" % (args.rate_norm - args.rate_dist))

    validation_accurate_1, validation_loss_1 = validate(
        test_loader, net, criterion, log)

    print(" accu before is: %.3f %%" % validation_accurate_1)

    m.model = net

    m.init_mask(args.rate_norm, args.rate_dist, args.dist_type)
    #    m.if_zero()
    m.do_mask()
    m.do_similar_mask()
    net = m.model
    #    m.if_zero()
    if args.use_cuda:
        net = net.cuda()
    validation_accurate_2, validation_loss_2 = validate(
        test_loader, net, criterion, log)
    print(" accu after is: %s %%" % validation_accurate_2)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    small_filter_idx = []
    large_filter_idx = []

    for epoch in range(args.start_epoch, args.epochs):
        current_lr = adjust_lr(optimizer, epoch, args.gammas, args.schedule)

        required_hour, required_mins, required_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        required_time = '[required: {:02d}:{:02d}:{:02d}]'.format(
            required_hour, required_mins, required_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [lr={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                   required_time, current_lr) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_accurate, train_loss = train(train_loader, net, criterion,
                                           optimizer, epoch, log, m)

        # evaluate on validation set
        #         validation_accurate_1, validation_loss_1 = validate(test_loader, net, criterion, log)
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = net
            m.if_zero()
            m.init_mask(args.rate_norm, args.rate_dist, args.dist_type)
            m.do_mask()
            m.do_similar_mask()
            m.if_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()

        validation_accurate_2, validation_loss_2 = validate(
            test_loader, net, criterion, log)

        best = recorder.update(epoch, train_loss, train_accurate,
                               validation_loss_2, validation_accurate_2)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': net,
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    start = time.time()
    validation_accurate_2, validation_loss_2 = validate(
        test_loader, net, criterion, log)
    print("Acc=%.4f\n" % (validation_accurate_2))
    print(f"Run time: {(time.time() - start):.3f} s")
    torch.save(
        net,
        'geometric_median-%s%.2f.pth' % (args.prune_method, args.rate_dist))
    log.close()
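Most examples on this page lean on small utilities (print_log, AverageMeter, convert_secs2time) defined elsewhere in their respective projects. A minimal sketch consistent with how they are called above; the real implementations in each repository may differ in detail:

def print_log(print_string, log):
    # print to stdout and mirror the line into the open log file
    print("{}".format(print_string))
    log.write('{}\n'.format(print_string))
    log.flush()

class AverageMeter(object):
    """Tracks a running average (used above for per-epoch wall-clock time)."""
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def convert_secs2time(epoch_time):
    # split a duration in seconds into (hours, minutes, seconds)
    need_hour = int(epoch_time / 3600)
    need_mins = int((epoch_time - 3600 * need_hour) / 60)
    need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins)
    return need_hour, need_mins, need_secs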
Code Example #18
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)

    # Get timestamp
    time_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    # Random seed used
    print("Random Seed: {0}".format(args.manualSeed))
    # Python Version used
    print("python version : {}".format(sys.version.replace('\n', ' ')))
    # Torch Version used
    print("torch  version : {}".format(torch.__version__))
    # Cudnn Version used
    print("cudnn  version : {0}".format(torch.backends.cudnn.version()))

    # Path for the dataset. If not present, it is downloaded
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    # Preparing dataset
    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    # Additional dataset transforms like padding, crop and flipping of images
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # Loading dataset from the path to data
    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   target_transform=None,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  target_transform=None,
                                  download=True)
        num_classes = 100
    else:
        assert False, 'Does not support dataset : {}'.format(args.dataset)

    # Splitting the training dataset into train and val sets
    num_train = len(train_data)
    indices = list(range(num_train))
    split = 10000

    np.random.seed(args.manualSeed)
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # Loading the data into loaders
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               sampler=train_sampler,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(train_data,
                                             batch_size=args.batch_size,
                                             sampler=valid_sampler,
                                             num_workers=args.workers,
                                             pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    # Print the architecture of the model being used
    print("=> creating model '{}'".format(args.arch), log)

    # Initializing the model (uses the network's default weight initialization).
    net = simplenet(classes=100)

    # Using GPUs for the Neural Network
    if args.ngpu > 0:
        net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # Define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss()

    # The list of weights that are to be learnt while training.
    # This is the list of all the parameters; use it when all the layers are to be trained.
    # (Comment it out and uncomment the line below to train only the last layer.)
    pars = list(net.parameters())

    # This is the list of parameters in the last layer of the network; use it when only the
    # last layer is to be trained and the rest of the network is kept frozen.
    # pars = list(net.module.classifier.parameters())

    # Define the optimizer used in the learning algorithm (optimizer).
    # Note: `pars` must be defined before the optimizer is constructed.
    optimizer = torch.optim.Adadelta(pars,
                                     lr=0.1,
                                     rho=0.9,
                                     eps=1e-3,
                                     weight_decay=0)

    states_settings = {'optimizer': optimizer.state_dict()}

    # Epochs after which when the learning rate is to be changed
    milestones = [50, 70, 100]
    # Defines how the learning rate is changed at the above-mentioned epochs.
    # 'gamma' is the factor by which the learning rate is reduced
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1)

    # Put the network and the other data on the CUDA device
    if args.use_cuda:
        net.cuda()
        criterion.cuda()
        print('__Number CUDA Devices:', torch.cuda.device_count())

    # A structure to record different results while training and evaluating
    recorder = RecorderMeter(args.epochs)

    # If this argument is given, the network loads weights from the checkpoint file given in the arguments
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            # If no GPU is being used, uncomment the following line
            # checkpoint = torch.load(args.resume,map_location='cpu')

            # When loading from a checkpoint with only features saved, uncomment the following line else comment it
            # net.module.features.load_state_dict(checkpoint['state_dict'])

            # When loading from a checkpoint with all the parameters saved, uncomment the following line else comment it.
            net.load_state_dict(checkpoint['state_dict'])

            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

    else:
        print("=> did not use any checkpoint for {} model".format(args.arch))

    # Opening a log, None when not using it
    log = None

    # Evaluating a model on training, testing and validation sets
    if args.evaluate:
        print("Train data : ")
        acc, loss, tloss = validate(train_loader, net, criterion, log)
        print(loss, acc, tloss)
        print("Test data : ")
        acc, loss, tloss = validate(test_loader, net, criterion, log)
        print(loss, acc, tloss)
        print("Validation data : ")
        acc, loss, tloss = validate(val_loader, net, criterion, log)
        print(loss, acc, tloss)
        return

    # Main loop

    # Start timer
    start_time = time.time()
    # Structure to record time for each epoch
    epoch_time = AverageMeter()

    # Starts training from epoch 0
    args.start_epoch = 0

    best_loss = 1000000
    best_acc = 100000
    log = None

    # Loop for training for each epoch
    for epoch in range(args.start_epoch, args.epochs):

        # Get learning rate for the epoch
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)
        current_learning_rate = float(scheduler.get_lr()[-1])

        scheduler.step()

        # Calculate the time for the remaining number of epochs
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        # Train the network on the training data and return accuracy, training loss and the t-loss on the training dataset
        train_acc, train_los, tloss_train = train(train_loader, net, criterion,
                                                  optimizer, epoch, log)

        # Print after each epoch the remaining time, epoch number, max accuracy recorded on validation set so far and other details
        print('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:.6f}]'.
              format(time_string(), epoch, args.epochs, need_time,
                     current_learning_rate) +
              ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
                  recorder.max_accuracy(False), 100 -
                  recorder.max_accuracy(False)))

        # Evaluate on the validation data and update the recorded values
        val_acc, val_los, tloss_val = validate(val_loader, net, criterion, log)
        is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                  val_acc)

        # Save checkpoint after every 30 epochs
        if epoch % 30 == 29:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    # Save the whole network with the following line uncommented. If only the features are to be saved, comment it.
                    # 'state_dict': net.state_dict(),
                    # Save only the feature layers of the network with the following line uncommented. If the whole network is to be saved, comment it.
                    'state_dict': net.module.features.state_dict(),
                    'recorder': recorder,
                    'optimizer': optimizer.state_dict(),
                },
                False,
                args.save_path,
                # Name of the checkpoint file to be saved to
                'crossEntropy_full_features.ckpt',
                time_stamp)

        # measure elapsed time for one epoch
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(
            os.path.join(
                args.save_path,
                'training_plot_crossEntropy_{0}.png'.format(args.manualSeed)))

    # writer.close()
    # End loop

    # Evaluate and print the results on training, testing and validation data

    test_acc, test_los, tl = validate(train_loader, net, criterion, log)
    print("Train accuracy : ")
    print(test_acc)
    print(test_los)
    print(tl)

    test_acc, test_los, tl = validate(val_loader, net, criterion, log)
    print("Val accuracy : ")
    print(test_acc)
    print(test_los)
    print(tl)

    test_acc, test_los, tl = validate(test_loader, net, criterion, log)
    print("Test accuracy : ")
    print(test_acc)
    print(test_los)
    print(tl)
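Similarly, train() and validate() are project-specific helpers. A minimal validate() sketch matching the (top-1 accuracy %, average loss) return convention most of these examples use; several variants above take an extra use_cuda flag or return a third loss value:

import torch

def validate(val_loader, model, criterion, log=None):
    # evaluate top-1 accuracy (%) and average loss on a held-out loader
    # (assumes a CUDA device is available, like the examples above)
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.cuda(), targets.cuda()
            outputs = model(images)
            loss = criterion(outputs, targets)
            total_loss += loss.item() * targets.size(0)
            correct += outputs.argmax(dim=1).eq(targets).sum().item()
            total += targets.size(0)
    # `log` is accepted only for interface parity; the real helpers print via print_log
    return 100.0 * correct / total, total_loss / total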
Code Example #19
File: main.py Project: syt2/my_filter_pruning
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    if args.resume:
        if not os.path.isdir(args.resume):
            os.makedirs(args.resume)
    log = open(os.path.join(args.save_path, '{}.txt'.format(args.description)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("use cuda: {}".format(args.use_cuda), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)
    print_log("Compress Rate: {}".format(args.rate), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("description: {}".format(args.description), log)

    # Init data loader
    if args.dataset == 'cifar10':
        train_loader = dataset.cifar10DataLoader(True, args.batch_size, True, args.workers)
        test_loader = dataset.cifar10DataLoader(False, args.batch_size, False, args.workers)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_loader = dataset.cifar100DataLoader(True, args.batch_size, True, args.workers)
        test_loader = dataset.cifar100DataLoader(False, args.batch_size, False, args.workers)
        num_classes = 100
    elif args.dataset == 'imagenet':
        assert False, 'Do not finish imagenet code'
    else:
        assert False, 'Do not support dataset : {}'.format(args.dataset)

    # Init model
    if args.arch == 'cifarvgg16':
        net = models.vgg16_cifar(True, num_classes)
    elif args.arch == 'resnet32':
        net = models.resnet32(num_classes)
    elif args.arch == 'resnet56':
        net = models.resnet56(num_classes)
    elif args.arch == 'resnet110':
        net = models.resnet110(num_classes)
    else:
        assert False, 'Not finished'


    print_log("=> network:\n {}".format(net),log)
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)
    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume + 'checkpoint.pth.tar'):
            print_log("=> loading checkpoint '{}'".format(args.resume + 'checkpoint.pth.tar'), log)
            checkpoint = torch.load(args.resume + 'checkpoint.pth.tar')
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            if args.use_state_dict:
                net.load_state_dict(checkpoint['state_dict'])
            else:
                net = checkpoint['state_dict']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), log)

            if args.evaluate:
                time1 = time.time()
                validate(test_loader, net, criterion, args.use_cuda, log)
                time2 = time.time()
                print('validate function took %0.3f ms' % ((time2 - time1) * 1000.0))
                return
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> not use any checkpoint for {} model".format(args.description), log)

    if args.original_train:
        original_train.args.arch = args.arch
        original_train.args.dataset = args.dataset
        original_train.main()
        return

    comp_rate = args.rate
    m = mask.Mask(net, args.use_cuda)
    print("-" * 10 + "one epoch begin" + "-" * 10)
    print("the compression rate now is %f" % comp_rate)

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, args.use_cuda, log)
    print(" accu before is: %.3f %%" % val_acc_1)

    m.model = net
    print('before pruning')
    m.init_mask(comp_rate, args.last_index)
    m.do_mask()
    print('after pruning')
    m.print_weights_zero()
    net = m.model  # update net

    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, args.use_cuda, log)
    print(" accu after is: %.3f %%" % val_acc_2)
    #

    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(args.learning_rate, optimizer, epoch,
                                                     args.gammas, args.schedule)
        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(
                time_string(), epoch, args.epochs, need_time, current_learning_rate)
            + ' [Best : Accuracy={:.2f}]'.format(recorder.max_accuracy(False)), log)
        train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, args.use_cuda, log)
        validate(test_loader, net, criterion, args.use_cuda, log)
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = net
            print('before pruning')
            m.print_weights_zero()
            m.init_mask(comp_rate, args.last_index)
            m.do_mask()
            print('after pruning')
            m.print_weights_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()

        val_acc_2, val_los_2 = validate(test_loader, net, criterion, args.use_cuda, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2, val_acc_2)
        if args.resume:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': net,
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.resume, 'checkpoint.pth.tar')
        print('save ckpt done')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    torch.save(net, args.model_save)
    # torch.save(net, args.save_path)
    flops.print_model_param_nums(net)
    flops.count_model_param_flops(net, 32, False)
    log.close()
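adjust_learning_rate() is another assumed helper, and its signature varies between the examples (the one above passes the base learning rate explicitly; example #22 below also returns the current momentum). A typical step-decay sketch, written here with the base learning rate as an explicit first argument:

def adjust_learning_rate(base_lr, optimizer, epoch, gammas, schedule):
    # Step decay: each time `epoch` passes a milestone in `schedule`,
    # multiply the learning rate by the matching factor in `gammas`.
    lr = base_lr
    for gamma, step in zip(gammas, schedule):
        if epoch >= step:
            lr = lr * gamma
        else:
            break
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr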
Code Example #20
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log("Norm Pruning Rate: {}".format(args.rate_norm), log)
    print_log("Distance Pruning Rate: {}".format(args.rate_dist), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("use pretrain: {}".format(args.use_pretrain), log)
    print_log("Pretrain path: {}".format(args.pretrain_path), log)
    print_log("Dist type: {}".format(args.dist_type), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        assert False, 'Do not finish imagenet code'
    else:
        assert False, 'Do not support dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

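    # Load a pretrained baseline; if the given path does not exist, fall back to the
    # original author's hard-coded checkpoint directory below.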
    if args.use_pretrain:
        if os.path.isfile(args.pretrain_path):
            print_log(
                "=> loading pretrain model '{}'".format(args.pretrain_path),
                log)
        else:
            base_dir = '/data/yahe/cifar10_base/'
            # base_dir = '/data/uts521/yang/progress/cifar10_base/'
            whole_path = base_dir + 'cifar10_' + args.arch + '_base'
            args.pretrain_path = whole_path + '/checkpoint.pth.tar'
            print_log("Pretrain path: {}".format(args.pretrain_path), log)
        pretrain = torch.load(args.pretrain_path)
        if args.use_state_dict:
            net.load_state_dict(pretrain['state_dict'])
        else:
            net = pretrain['state_dict']

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            if args.use_state_dict:
                net.load_state_dict(checkpoint['state_dict'])
            else:
                net = checkpoint['state_dict']

            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        time1 = time.time()
        validate(test_loader, net, criterion, log)
        time2 = time.time()
        print('function took %0.3f ms' % ((time2 - time1) * 1000.0))
        return

    m = Mask(net, args)
    m.init_length()
    print("-" * 10 + "one epoch begin" + "-" * 10)
    print("remaining ratio of pruning : Norm is %f" % args.rate_norm)
    print("reducing ratio of pruning : Distance is %f" % args.rate_dist)
    print("total remaining ratio is %f" % (args.rate_norm - args.rate_dist))

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

    print(" accu before is: %.3f %%" % val_acc_1)

    m.model = net

    m.init_mask(args.rate_norm, args.rate_dist, args.dist_type)
    #    m.if_zero()
    m.do_mask()
    m.do_similar_mask()
    net = m.model
    #    m.if_zero()
    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)
    print(" accu after is: %s %%" % val_acc_2)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    small_filter_index = []
    large_filter_index = []

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log, m)

        # evaluate on validation set
        val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:  # prune every args.epoch_prune
            m.model = net
            m.if_zero()
            m.init_mask(args.rate_norm, args.rate_dist, args.dist_type)
            m.do_mask()
            m.do_similar_mask()
            m.if_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()

        val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2,
                                  val_acc_2)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net,
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    log.close()
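save_checkpoint() is usually just torch.save plus a copy of the best model. A common sketch; the 'model_best.pth.tar' name is the usual convention rather than something shown on this page, and some examples above pass extra arguments such as a timestamp or a log handle:

import os
import shutil
import torch

def save_checkpoint(state, is_best, save_path, filename):
    # always write the latest checkpoint ...
    filepath = os.path.join(save_path, filename)
    torch.save(state, filepath)
    # ... and keep a separate copy whenever this epoch set a new best accuracy
    if is_best:
        shutil.copyfile(filepath, os.path.join(save_path, 'model_best.pth.tar'))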
Code Example #21
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'cub_200':
        mean = [0.4856077, 0.49941534, 0.43237692]
        std = [0.23222743, 0.2277201, 0.26586822]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    if args.dataset == 'cub_200':
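        # Note: RandomSizedCrop and Scale are the older torchvision names for
        # RandomResizedCrop and Resize.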
        train_transform = transforms.Compose([
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10

    elif args.dataset == 'cub_200':
        train_data = MyImageFolder(args.data_path + 'train',
                                   transform=train_transform,
                                   data_cached=True)
        test_data = MyImageFolder(args.data_path + 'val',
                                  transform=test_transform,
                                  data_cached=True)
        num_classes = 200

    elif args.dataset == 'imagenet':
        assert False, 'Do not finish imagenet code'
    else:
        assert False, 'Do not support dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

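    # For the spatial-transform variant, the localization sub-network (net.loc) is trained
    # with a learning rate 1e-4 times smaller than the backbone's.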
    if args.arch == 'spatial_transform_resnext50':
        loc_params = net.loc.parameters()
        loc_params_id = list(map(id, loc_params))
        base_params = filter(lambda p: id(p) not in loc_params_id,
                             net.parameters())
        optimizer = torch.optim.SGD(
            [
                {
                    'params': base_params
                },
                {
                    'params': net.loc.parameters(),
                    'lr': 1e-4 * state['learning_rate']
                },
                #     {'params': net.stn.parameters()}
            ],
            state['learning_rate'],
            momentum=state['momentum'],
            weight_decay=state['decay'],
            nesterov=True)
        #optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
        #            weight_decay=state['decay'], nesterov=True)

    else:
        optimizer = torch.optim.SGD(net.parameters(),
                                    state['learning_rate'],
                                    momentum=state['momentum'],
                                    weight_decay=state['decay'],
                                    nesterov=True)

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            #      recorder = checkpoint['recorder']
            #      args.start_epoch = checkpoint['epoch']

            state_dict = checkpoint['state_dict']
            if args.arch == 'spatial_transform_resnext50':
                ckpt_weights = list(state_dict.values())[:-2]
                m_list = net.crop_descriptors
                loc = net.loc[0]
                m_dict_keys = list(loc.state_dict().keys())
                m_dict = dict(zip(m_dict_keys, ckpt_weights))
                loc.load_state_dict(m_dict)
                print('loaded loc net')
                for m in m_list:
                    m = m[0]
                    m_dict_keys = list(m.state_dict().keys())
                    m_dict = dict(zip(m_dict_keys, ckpt_weights))
                    m.load_state_dict(m_dict)
                    print('loaded one descriptor')
            else:

                model_dict = net.state_dict()
                from_ = list(state_dict.keys())
                to = list(model_dict.keys())

                for i, k in enumerate(from_):
                    if k not in ['module.fc.weight', 'module.fc.bias']:
                        model_dict[to[i]] = state_dict[k]

                net.load_state_dict(model_dict)
        # optimizer.load_state_dict(checkpoint['optimizer'])
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)
    loc_classifier_w = net.loc[1].fc_2.weight
    classifier_w = net.classifier.weight
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        # current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)
        current_learning_rate = 0
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log, loc_classifier_w,
                                     classifier_w)

        # evaluate on validation set
        #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
        val_acc, val_los = validate(test_loader, net, criterion, log)
        is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                  val_acc)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    log.close()
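RecorderMeter stores per-epoch losses and accuracies, reports the best validation accuracy, and plots curves. A stripped-down sketch of the interface the examples rely on (plotting omitted):

import numpy as np

class RecorderMeter(object):
    def __init__(self, total_epoch):
        self.total_epoch = total_epoch
        self.current_epoch = 0
        self.epoch_losses = np.zeros((total_epoch, 2), dtype=np.float32)    # [train, val]
        self.epoch_accuracy = np.zeros((total_epoch, 2), dtype=np.float32)  # [train, val]

    def update(self, idx, train_loss, train_acc, val_loss, val_acc):
        self.epoch_losses[idx, :] = (train_loss, val_loss)
        self.epoch_accuracy[idx, :] = (train_acc, val_acc)
        self.current_epoch = idx + 1
        # is_best: did this epoch reach the best validation accuracy seen so far?
        return self.max_accuracy(False) == val_acc

    def max_accuracy(self, istrain):
        if self.current_epoch <= 0:
            return 0.0
        column = 0 if istrain else 1
        return float(self.epoch_accuracy[:self.current_epoch, column].max())

    def plot_curve(self, save_path):
        # the real implementation draws the train/val curves with matplotlib
        # and writes the figure to save_path
        pass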
Code Example #22
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Init the tensorboard path and writer
    tb_path = os.path.join(args.save_path, 'tb_log',
                           'run_' + str(args.manualSeed))
    # logger = Logger(tb_path)
    writer = SummaryWriter(tb_path)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'mnist':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'imagenet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    if args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])  # here is actually the validation dataset
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])

    if args.dataset == 'mnist':
        train_data = dset.MNIST(args.data_path,
                                train=True,
                                transform=train_transform,
                                download=True)
        test_data = dset.MNIST(args.data_path,
                               train=False,
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        train_dir = os.path.join(args.data_path, 'train')
        test_dir = os.path.join(args.data_path, 'val')
        train_data = dset.ImageFolder(train_dir, transform=train_transform)
        test_data = dset.ImageFolder(test_dir, transform=test_transform)
        num_classes = 1000
    else:
        assert False, 'Do not support dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.attack_sample_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    if args.use_cuda:
        if args.ngpu > 1:
            net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    # separate the parameters thus param groups can be updated by different optimizer
    all_param = [
        param for name, param in net.named_parameters()
        if not 'step_size' in name
    ]

    step_param = [
        param for name, param in net.named_parameters() if 'step_size' in name
    ]

    if args.optimizer == "SGD":
        print("using SGD as optimizer")
        optimizer = torch.optim.SGD(all_param,
                                    lr=state['learning_rate'],
                                    momentum=state['momentum'],
                                    weight_decay=state['decay'],
                                    nesterov=True)

    elif args.optimizer == "Adam":
        print("using Adam as optimizer")
        optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad,
                                            net.parameters()),
                                     lr=state['learning_rate'],
                                     weight_decay=state['decay'])

    elif args.optimizer == "RMSprop":
        print("using RMSprop as optimizer")
        optimizer = torch.optim.RMSprop(filter(
            lambda param: param.requires_grad, net.parameters()),
                                        lr=state['learning_rate'],
                                        alpha=0.99,
                                        eps=1e-08,
                                        weight_decay=0,
                                        momentum=0)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)  # count number of epoches

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            if not (args.fine_tune):
                args.start_epoch = checkpoint['epoch']
                recorder = checkpoint['recorder']
                optimizer.load_state_dict(checkpoint['optimizer'])

            state_tmp = net.state_dict()
            if 'state_dict' in checkpoint.keys():
                state_tmp.update(checkpoint['state_dict'])
            else:
                state_tmp.update(checkpoint)

            net.load_state_dict(state_tmp)

            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, args.start_epoch), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    # update the step_size once the model is loaded. This is used for quantization.
    for m in net.modules():
        if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
            # simple step size update based on the pretrained model or weight init
            m.__reset_stepsize__()

    # block for quantizer optimization
    if args.optimize_step:
        optimizer_quan = torch.optim.SGD(step_param,
                                         lr=0.01,
                                         momentum=0.9,
                                         weight_decay=0,
                                         nesterov=True)

        for m in net.modules():
            if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
                for i in range(300):  # run 300 iterations to reduce quantization error
                    optimizer_quan.zero_grad()
                    weight_quan = quantize(m.weight, m.step_size,
                                           m.half_lvls) * m.step_size
                    loss_quan = F.mse_loss(weight_quan,
                                           m.weight,
                                           reduction='mean')
                    loss_quan.backward()
                    optimizer_quan.step()

        for m in net.modules():
            if isinstance(m, quan_Conv2d):
                print(m.step_size.data.item(),
                      (m.step_size.detach() * m.half_lvls).item(),
                      m.weight.max().item())

    # block for weight reset
    if args.reset_weight:
        for m in net.modules():
            if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
                m.__reset_weight__()
                # print(m.weight)

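    # Set up the Bit-Flip Attack (BFA) and keep an untouched deep copy of the network
    # so the attacked model can be compared against the clean one.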
    attacker = BFA(criterion, args.k_top)
    net_clean = copy.deepcopy(net)
    # weight_conversion(net)

    if args.enable_bfa:
        perform_attack(attacker, net, net_clean, train_loader, test_loader,
                       args.n_iter, log, writer)
        return

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer, epoch, args.gammas, args.schedule)
        # Display simulation time
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f}][M={:1.2f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate,
                                                                                   current_momentum) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log)

        # evaluate on validation set
        val_acc, _, val_los = validate(test_loader, net, criterion, log)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)
        is_best = val_acc >= recorder.max_accuracy(False)

        if args.model_only:
            checkpoint_state = {'state_dict': net.state_dict()}
        else:
            checkpoint_state = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }

        save_checkpoint(checkpoint_state, is_best, args.save_path,
                        'checkpoint.pth.tar', log)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

        # save addition accuracy log for plotting
        accuracy_logger(base_dir=args.save_path,
                        epoch=epoch,
                        train_accuracy=train_acc,
                        test_accuracy=val_acc)

        # ============ TensorBoard logging ============#

        ## Log the gradients distribution
        for name, param in net.named_parameters():
            name = name.replace('.', '/')
            writer.add_histogram(name + '/grad',
                                 param.grad.clone().cpu().data.numpy(),
                                 epoch + 1,
                                 bins='tensorflow')

        ## Log the weight distribution
        for name, module in net.named_modules():
            name = name.replace('.', '/')
            class_name = str(module.__class__).split('.')[-1].split("'")[0]

            if "Conv2d" in class_name or "Linear" in class_name:
                if module.weight is not None:
                    writer.add_histogram(
                        name + '/weight/',
                        module.weight.clone().cpu().data.numpy(),
                        epoch + 1,
                        bins='tensorflow')

        writer.add_scalar('loss/train_loss', train_los, epoch + 1)
        writer.add_scalar('loss/test_loss', val_los, epoch + 1)
        writer.add_scalar('accuracy/train_accuracy', train_acc, epoch + 1)
        writer.add_scalar('accuracy/test_accuracy', val_acc, epoch + 1)
    # ============ TensorBoard logging ============#

    log.close()
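The quantizer-optimization block above searches each layer's quantization step size by gradient descent on the mean-squared error between the quantized and full-precision weights. As a rough, self-contained sketch of the same idea, assuming a plain uniform quantizer (uniform_quantize, step, and half_lvls below are stand-in names, not the quan_Conv2d / quan_Linear API used by the script):

import torch
import torch.nn.functional as F

def uniform_quantize(w, step, half_lvls):
    # map weights to integer levels and clamp them to the representable range
    return torch.clamp(torch.round(w / step), -half_lvls, half_lvls)

torch.manual_seed(0)
weight = torch.randn(64, 32)                       # stand-in for a layer's weight
half_lvls = 127                                    # e.g. signed 8-bit levels
step = (weight.abs().max() / half_lvls).clone().detach().requires_grad_(True)

optimizer_quan = torch.optim.SGD([step], lr=0.01, momentum=0.9, nesterov=True)
for _ in range(300):
    optimizer_quan.zero_grad()
    weight_quan = uniform_quantize(weight, step, half_lvls) * step
    loss_quan = F.mse_loss(weight_quan, weight, reduction='mean')
    loss_quan.backward()
    optimizer_quan.step()
print(step.item(), (step.detach() * half_lvls).item(), weight.max().item())

Because the rounding itself has zero gradient, the step size is effectively tuned through the outer multiplication, which mirrors how the loop above drives step_size toward the value that minimizes quantization error.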
Code example #23
File: main_dc.py Project: elliothe/tern_repo
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Init the tensorboard path and writer
    tb_path = os.path.join(args.save_path, 'tb_log')
    # logger = Logger(tb_path)
    writer = SummaryWriter(tb_path)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'mnist':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'imagenet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    if args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])  # here is actually the validation dataset
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])

    if args.dataset == 'mnist':
        train_data = dset.MNIST(args.data_path,
                                train=True,
                                transform=train_transform,
                                download=True)
        test_data = dset.MNIST(args.data_path,
                               train=False,
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        train_dir = os.path.join(args.data_path, 'train')
        test_dir = os.path.join(args.data_path, 'val')
        train_data = dset.ImageFolder(train_dir, transform=train_transform)
        test_data = dset.ImageFolder(test_dir, transform=test_transform)
        num_classes = 1000
    else:
        assert False, 'Unsupported dataset: {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    if args.use_cuda:
        if args.ngpu > 1:
            net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
        else:
            net = torch.nn.DataParallel(net, device_ids=[0])

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    # params without threshold
    all_param = [
        param for name, param in net.named_parameters()
        if not 'delta_th' in name
    ]

    th_param = [
        param for name, param in net.named_parameters() if 'delta_th' in name
    ]

    if args.optimizer == "SGD":
        print("using SGD as optimizer")
        optimizer = torch.optim.SGD(all_param,
                                    lr=state['learning_rate'],
                                    momentum=state['momentum'],
                                    weight_decay=state['decay'],
                                    nesterov=True)
        optimizer_th = torch.optim.SGD(th_param,
                                       lr=state['learning_rate'],
                                       momentum=state['momentum'],
                                       weight_decay=state['decay'],
                                       nesterov=True)

    elif args.optimizer == "Adam":
        print("using Adam as optimizer")
        optimizer = torch.optim.Adam(all_param,
                                     lr=state['learning_rate'],
                                     weight_decay=state['decay'])

        optimizer_th = torch.optim.SGD(th_param,
                                       lr=state['learning_rate'],
                                       momentum=state['momentum'],
                                       weight_decay=0,
                                       nesterov=True)

    elif args.optimizer == "YF":
        print("using YellowFin as optimizer")
        optimizer = YFOptimizer(filter(lambda param: param.requires_grad,
                                       net.parameters()),
                                lr=state['learning_rate'],
                                mu=state['momentum'],
                                weight_decay=state['decay'])
    # optimizer = YFOptimizer( filter(lambda param: param.requires_grad, net.parameters()) )
    elif args.optimizer == "RMSprop":
        print("using RMSprop as optimizer")
        optimizer = torch.optim.RMSprop(filter(
            lambda param: param.requires_grad, net.parameters()),
                                        lr=state['learning_rate'],
                                        alpha=0.99,
                                        eps=1e-08,
                                        weight_decay=0,
                                        momentum=0)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)  # count number of epoches

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            if not (args.fine_tune):
                args.start_epoch = checkpoint['epoch']
                recorder = checkpoint['recorder']
                optimizer.load_state_dict(checkpoint['optimizer'])

            state_tmp = net.state_dict()
            if 'state_dict' in checkpoint.keys():
                state_tmp.update(checkpoint['state_dict'])
            else:
                state_tmp.update(checkpoint)

            net.load_state_dict(state_tmp)

            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, args.start_epoch), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    # Right after the pretrained model is loaded:
    '''
    When the model is loaded from a pre-trained checkpoint, the originally
    initialized thresholds are not correct anymore and might be clipped
    by the hard-tanh function.
    '''
    for name, module in net.named_modules():
        name = name.replace('.', '/')
        class_name = str(module.__class__).split('.')[-1].split("'")[0]
        if "quanConv2d" in class_name or "quanLinear" in class_name:
            module.delta_th.data = module.weight.abs().max(
            ) * module.init_factor.cuda()

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # set the gradient register hook to modify the gradient (gradient clipping)
    for name, param in net.named_parameters():
        if "delta_th" in name:
            # if "delta_th" in name and 'classifier' in name:
            # based on previous experiments, clamping the gradient to [-0.001, 0.001] works best
            param.register_hook(lambda grad: grad.clamp(min=-0.001, max=0.001))

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer, epoch, args.gammas, args.schedule)
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer_th, epoch, args.gammas, args.schedule)

        # Display simulation time
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f}][M={:1.2f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate,
                                                                                   current_momentum) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # ============ TensorBoard logging ============#
        # log the model parameter initialization to give an intuition when fine-tuning

        for name, param in net.named_parameters():
            name = name.replace('.', '/')
            if "delta_th" not in name:
                writer.add_histogram(name, param.cpu().detach().numpy(), epoch)

        for name, module in net.named_modules():
            name = name.replace('.', '/')
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            if "quanConv2d" in class_name or "quanLinear" in class_name:
                sparsity = Sparsity_check(module)
                writer.add_scalar(name + '/sparsity/', sparsity, epoch)
                # writer.add_histogram(name + '/ternweight/', tern_weight.detach().numpy(), epoch + 1)

        # ============ TensorBoard logging ============#

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     optimizer_th, epoch, log)

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, criterion, log)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)
        is_best = val_acc >= recorder.max_accuracy(False)

        if args.model_only:
            checkpoint_state = {'state_dict': net.state_dict()}
        else:
            checkpoint_state = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }

        save_checkpoint(checkpoint_state, is_best, args.save_path,
                        'checkpoint.pth.tar', log)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

        # save addition accuracy log for plotting
        accuracy_logger(base_dir=args.save_path,
                        epoch=epoch,
                        train_accuracy=train_acc,
                        test_accuracy=val_acc)

        # ============ TensorBoard logging ============#

        for name, param in net.named_parameters():
            name = name.replace('.', '/')
            writer.add_histogram(name + '/grad',
                                 param.grad.cpu().detach().numpy(), epoch + 1)

        # for name, module in net.named_modules():
        #     name = name.replace('.', '/')
        #     class_name = str(module.__class__).split('.')[-1].split("'")[0]
        #     if "quanConv2d" in class_name or "quanLinear" in class_name:
        #         sparsity = Sparsity_check(module)
        #         writer.add_scalar(name + '/sparsity/', sparsity, epoch + 1)
        #         # writer.add_histogram(name + '/ternweight/', tern_weight.detach().numpy(), epoch + 1)

        for name, module in net.named_modules():
            name = name.replace('.', '/')
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            if "quanConv2d" in class_name or "quanLinear" in class_name:
                if module.delta_th.data is not None:
                    if module.delta_th.dim() == 0:  # zero-dimension tensor (scalar) is not iterable
                        writer.add_scalar(name + '/delta/',
                                          module.delta_th.detach(), epoch + 1)
                    else:
                        for idx, delta in enumerate(module.delta_th.detach()):
                            writer.add_scalar(
                                name + '/delta/' + '{}'.format(idx), delta,
                                epoch + 1)

        writer.add_scalar('loss/train_loss', train_los, epoch + 1)
        writer.add_scalar('loss/test_loss', val_los, epoch + 1)
        writer.add_scalar('accuracy/train_accuracy', train_acc, epoch + 1)
        writer.add_scalar('accuracy/test_accuracy', val_acc, epoch + 1)
    # ============ TensorBoard logging ============#

    log.close()
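One detail worth calling out from the script above is the per-parameter gradient clipping on the delta_th thresholds, done with Tensor.register_hook rather than a global clip. A minimal standalone illustration of that hook pattern, using toy tensors only and not the script's model:

import torch

# clamp a parameter's gradient to [-0.001, 0.001] with a backward hook
p = torch.nn.Parameter(torch.randn(4))
p.register_hook(lambda grad: grad.clamp(min=-0.001, max=0.001))

loss = (10.0 * p).sum()   # every element would otherwise receive a gradient of 10
loss.backward()
print(p.grad)             # all entries are clipped to 0.001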
def main():
  # Init logger
  
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Init dataset
  
  if not os.path.exists(args.data_path):
    os.makedirs(args.data_path)

  if args.dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif args.dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  else:
    assert False, "Unknow dataset : {}".format(args.dataset)

  train_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
     transforms.Normalize(mean, std)])
  test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)])

  if args.dataset == 'cifar10':
    train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'cifar100':
    train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 100
  elif args.dataset == 'svhn':
    train_data = dset.SVHN(args.data_path, split='train', transform=train_transform, download=True)
    test_data = dset.SVHN(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'stl10':
    train_data = dset.STL10(args.data_path, split='train', transform=train_transform, download=True)
    test_data = dset.STL10(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'imagenet':
    assert False, 'ImageNet code is not finished'
  else:
    assert False, 'Unsupported dataset: {}'.format(args.dataset)

  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                         num_workers=args.workers, pin_memory=True)
  test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, pin_memory=True)

  print_log("=> creating model '{}'".format(args.arch), log)
  # Init model, criterion, and optimizer
  net = models.__dict__[args.arch](num_classes)
  print_log("=> network :\n {}".format(net), log)

  net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

  # define loss function (criterion) and optimizer
  criterion = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=False)

  if args.use_cuda:
    net.cuda()
    criterion.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint
  if args.resume:
    if os.path.isfile(args.resume):
      print_log("=> loading checkpoint '{}'".format(args.resume), log)
      checkpoint = torch.load(args.resume)
      recorder = checkpoint['recorder']
      args.start_epoch = checkpoint['epoch']
      net.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
    else:
      raise ValueError("=> no checkpoint found at '{}'".format(args.resume))
  else:
    print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

  if args.evaluate:
    validate(test_loader, net, criterion, log)
    return

  # Main loop
  start_time = time.time()
  epoch_time = AverageMeter()
  for epoch in range(args.start_epoch, args.epochs):
    current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

    need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
    need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

    # train for one epoch
    train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log)

    # evaluate on validation set
    #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
    val_acc,   val_los   = validate(test_loader, net, criterion, log)
    is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

    save_checkpoint({
      'epoch': epoch + 1,
      'arch': args.arch,
      'state_dict': net.state_dict(),
      'recorder': recorder,
      'optimizer' : optimizer.state_dict(),
      'args'      : copy.deepcopy(args),
    }, is_best, args.save_path, 'hb16_10check.pth.tar')

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    recorder.plot_curve( os.path.join(args.save_path, 'hb16_10.png') )

  log.close()
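All of these training loops call an adjust_learning_rate helper that is imported from the repository's utilities and never shown here. Assuming it implements the usual step decay driven by args.gammas and args.schedule (some variants above also return a momentum value), a hedged sketch of that policy might look like:

import torch

# Sketch (assumption) of the step-decay policy behind adjust_learning_rate as
# it is called in these scripts: each time `epoch` passes an entry of
# `schedule`, the learning rate is multiplied by the matching entry of `gammas`.
def adjust_learning_rate_sketch(optimizer, epoch, base_lr, gammas, schedule):
    lr = base_lr
    for gamma, step in zip(gammas, schedule):
        if epoch >= step:
            lr *= gamma
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr

net = torch.nn.Linear(8, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
for epoch in (0, 80, 120):
    print(epoch, adjust_learning_rate_sketch(opt, epoch, 0.1, gammas=[0.1, 0.1], schedule=[80, 120]))

With gammas=[0.1, 0.1] and schedule=[80, 120], the learning rate drops from 0.1 to 0.01 at epoch 80 and to 0.001 at epoch 120.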
def main():

    ### transfer data from source to current node #####
    print("Copying the dataset to the current node's dir...")

    tmp = args.temp_dir
    home = args.home_dir


    dataset=args.dataset
    data_source_dir = os.path.join(home,'data',dataset)
    if not os.path.exists(data_source_dir):
        os.makedirs(data_source_dir)
    data_target_dir = os.path.join(tmp,'data',dataset)
    copy_tree(data_source_dir, data_target_dir)

    ### set up the experiment directories########
    exp_name=experiment_name(arch=args.arch,
                    epochs=args.epochs,
                    dropout=args.dropout,
                    batch_size=args.batch_size,
                    lr=args.learning_rate,
                    momentum=args.momentum,
                    alpha = args.alpha,
                    decay= args.decay,
                    data_aug=args.data_aug,
                    dualcutout=args.dualcutout,
                    singlecutout = args.singlecutout,
                    cutsize = args.cutsize,
                    manualSeed=args.manualSeed,
                    job_id=args.job_id,
                    add_name=args.add_name)
    temp_model_dir = os.path.join(tmp,'experiments/DualCutout/'+dataset+'/model/'+ exp_name)
    temp_result_dir = os.path.join(tmp, 'experiments/DualCutout/'+dataset+'/results/'+ exp_name)
    model_dir = os.path.join(home, 'experiments/DualCutout/'+dataset+'/model/'+ exp_name)
    result_dir = os.path.join(home, 'experiments/DualCutout/'+dataset+'/results/'+ exp_name)


    if not os.path.exists(temp_model_dir):
        os.makedirs(temp_model_dir)

    if not os.path.exists(temp_result_dir):
        os.makedirs(temp_result_dir)

    copy_script_to_folder(os.path.abspath(__file__), temp_result_dir)

    result_png_path = os.path.join(temp_result_dir, 'results.png')


    global best_acc

    log = open(os.path.join(temp_result_dir, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(temp_result_dir), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)


    train_loader, test_loader,num_classes=load_data(args.data_aug, args.batch_size,args.workers,args.dataset, data_target_dir)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer

    net = models.__dict__[args.arch](num_classes,args.dropout)
    print_log("=> network :\n {}".format(net), log)

    #net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=True)


    cutout = Cutout(1, args.cutsize)
    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log("=> loaded checkpoint '{}' accuracy={} (epoch {})" .format(args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    # Main loop
    train_loss = []
    train_acc=[]
    test_loss=[]
    test_acc=[]
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        tr_acc, tr_los = train(train_loader, net, criterion, cutout, optimizer, epoch, log)

        # evaluate on validation set
        val_acc,   val_los   = validate(test_loader, net, criterion, log)
        train_loss.append(tr_los)
        train_acc.append(tr_acc)
        test_loss.append(val_los)
        test_acc.append(val_acc)
        dummy = recorder.update(epoch, tr_los, tr_acc, val_los, val_acc)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        save_checkpoint({
          'epoch': epoch + 1,
          'arch': args.arch,
          'state_dict': net.state_dict(),
          'recorder': recorder,
          'optimizer' : optimizer.state_dict(),
        }, is_best, temp_model_dir, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(result_png_path)

    train_log = OrderedDict()
    train_log['train_loss'] = train_loss
    train_log['train_acc']=train_acc
    train_log['test_loss']=test_loss
    train_log['test_acc']=test_acc

    pickle.dump(train_log, open( os.path.join(temp_result_dir,'log.pkl'), 'wb'))
    plotting(temp_result_dir)

    copy_tree(temp_model_dir, model_dir)
    copy_tree(temp_result_dir, result_dir)

    rmtree(temp_model_dir)
    rmtree(temp_result_dir)

    log.close()
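Cutout(1, args.cutsize) comes from the DualCutout project and is not defined in this listing. A minimal sketch of the standard cutout augmentation it presumably applies, zeroing a random square patch of a CHW image tensor (CutoutSketch is a hypothetical name, and the real class may sample patches differently):

import torch

class CutoutSketch:
    """Zero out n_holes random square patches of side `length` in a CHW tensor."""
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        _, h, w = img.shape
        for _ in range(self.n_holes):
            y = torch.randint(h, (1,)).item()
            x = torch.randint(w, (1,)).item()
            y1, y2 = max(0, y - self.length // 2), min(h, y + self.length // 2)
            x1, x2 = max(0, x - self.length // 2), min(w, x + self.length // 2)
            img[:, y1:y2, x1:x2] = 0.0
        return img

img = torch.rand(3, 32, 32)
img = CutoutSketch(n_holes=1, length=8)(img)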
Code example #26
File: main.py Project: nivertech/applied-dl-2018
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    print('Dataset: {}'.format(args.dataset.upper()))

    if args.dataset == "seedlings" or args.dataset == "bone":
        classes, class_to_idx, num_to_class, df = GenericDataset.find_classes(
            args.data_path)
    if args.dataset == "ISIC2017":
        classes, class_to_idx, num_to_class, df = GenericDataset.find_classes_melanoma(
            args.data_path)

    df.head(3)

    args.num_classes = len(classes)
    # Init model, criterion, and optimizer
    # net = models.__dict__[args.arch](num_classes)
    # net= kmodels.simpleXX_generic(num_classes=args.num_classes, imgDim=args.imgDim)
    # net= kmodels.vggnetXX_generic(num_classes=args.num_classes,  imgDim=args.imgDim)
    # net= kmodels.vggnetXX_generic(num_classes=args.num_classes,  imgDim=args.imgDim)
    net = kmodels.dpn92(num_classes=args.num_classes)
    # print_log("=> network :\n {}".format(net), log)

    real_model_name = (type(net).__name__)
    print("=> Creating model '{}'".format(real_model_name))
    import datetime

    exp_name = datetime.datetime.now().strftime(real_model_name + '_' +
                                                args.dataset +
                                                '_%Y-%m-%d_%H-%M-%S')
    print('Training ' + real_model_name +
          ' on {} dataset:'.format(args.dataset.upper()))

    mPath = args.save_path + '/' + args.dataset + '/' + real_model_name + '/'
    args.save_path_model = mPath
    if not os.path.isdir(args.save_path_model):
        os.makedirs(args.save_path_model)

    log = open(os.path.join(mPath, 'seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print("Random Seed: {}".format(args.manualSeed))
    print("python version : {}".format(sys.version.replace('\n', ' ')))
    print("torch  version : {}".format(torch.__version__))
    print("cudnn  version : {}".format(torch.backends.cudnn.version()))

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)
    normalize_img = torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_trans = transforms.Compose([
        transforms.RandomSizedCrop(args.img_scale),
        PowerPIL(),
        transforms.ToTensor(),
        # normalize_img,
        RandomErasing()
    ])

    ## Normalization only for validation and test
    valid_trans = transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(args.img_scale),
        transforms.ToTensor(),
        # normalize_img
    ])

    test_trans = valid_trans

    train_data = df.sample(frac=args.validationRatio)
    valid_data = df[~df['file'].isin(train_data['file'])]

    train_set = GenericDataset(train_data,
                               args.data_path,
                               transform=train_trans)
    valid_set = GenericDataset(valid_data,
                               args.data_path,
                               transform=valid_trans)

    t_loader = DataLoader(train_set,
                          batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=0)
    v_loader = DataLoader(valid_set,
                          batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=0)
    # test_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)

    dataset_sizes = {
        'train': len(t_loader.dataset),
        'valid': len(v_loader.dataset)
    }
    print(dataset_sizes)
    # net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    criterion = torch.nn.CrossEntropyLoss()

    # optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
    #               weight_decay=state['decay'], nesterov=True)

    # optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)
    # optimizer = torch.optim.Adam(net.parameters(), lr=state['learning_rate'])

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.evaluate:
        validate(v_loader, net, criterion, log)
        return
    if args.tensorboard:
        configure("./logs/runs/%s" % (exp_name))

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in net.parameters()) / 1000000.0))

    # Main loop
    start_training_time = time.time()
    training_time = time.time()
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in tqdm(range(args.start_epoch, args.epochs)):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        tqdm.write(
            '\n==>>Epoch=[{:03d}/{:03d}], {:s}, LR=[{}], Batch=[{}]'.format(
                epoch, args.epochs, time_string(), state['learning_rate'],
                args.batch_size) + ' [Model={}]'.format(
                    (type(net).__name__), ), log)

        # train for one epoch
        train_acc, train_los = train(t_loader, net, criterion, optimizer,
                                     epoch, log)
        val_acc, val_los = validate(v_loader, net, criterion, epoch, log)
        is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                  val_acc)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        training_time = time.time() - start_training_time
        recorder.plot_curve(
            os.path.join(mPath, real_model_name + '_' + exp_name + '.png'),
            training_time, net, real_model_name, dataset_sizes,
            args.batch_size, args.learning_rate, args.dataset, args.manualSeed,
            args.num_classes)

        if float(val_acc) > float(95.0):
            print("*** EARLY STOP ***")
            df_pred = testSeedlingsModel(args.test_data_path, net,
                                         num_to_class, test_trans)
            model_save_path = os.path.join(
                mPath, real_model_name + '_' + str(val_acc) + '_' +
                str(val_los) + "_" + str(epoch))

            df_pred.to_csv(model_save_path + "_sub.csv",
                           columns=('file', 'species'),
                           index=None)
            torch.save(net.state_dict(), model_save_path + '_.pth')

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    # 'arch': args.arch,
                    'state_dict': net.state_dict(),
                    'recorder': recorder,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                mPath,
                str(val_acc) + '_' + str(val_los) + "_" + str(epoch) +
                '_checkpoint.pth.tar')

    log.close()
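RecorderMeter is another shared utility that only appears through its calls (update, max_accuracy, plot_curve) in these examples. A rough stand-in showing the bookkeeping those calls rely on, with hypothetical internals and without the curve plotting:

# Minimal sketch (assumption) of the RecorderMeter bookkeeping used above.
class RecorderMeterSketch:
    def __init__(self, total_epochs):
        self.total_epochs = total_epochs
        self.stats = []  # (train_loss, train_acc, val_loss, val_acc) per epoch

    def update(self, epoch, train_loss, train_acc, val_loss, val_acc):
        self.stats.append((train_loss, train_acc, val_loss, val_acc))
        # report whether this epoch matches or beats the best validation accuracy
        return val_acc >= self.max_accuracy(istrain=False)

    def max_accuracy(self, istrain):
        if not self.stats:
            return 0.0
        idx = 1 if istrain else 3
        return max(s[idx] for s in self.stats)

recorder = RecorderMeterSketch(10)
print(recorder.update(0, 1.2, 55.0, 1.4, 52.0))  # True: first epoch is the best so far
print(recorder.max_accuracy(False))              # 52.0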
Code example #27
File: train.py Project: ZLKong/Pruning
start_time = time.time()
epoch_time = AverageMeter()

netG.eval()
for epoch in range(0, args.epochs):
    current_learning_rate = args.learning_rate
    need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg *
                                                        (args.epochs - epoch))
    need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins,
                                                      need_secs)

    print('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.
          format(time_string(), epoch, args.epochs, need_time,
                 current_learning_rate) +
          ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
              recorder.max_accuracy(False), 100 -
              recorder.max_accuracy(False)))

    train_acc, train_los = train_with_kf(train_loader_kf,
                                         net,
                                         criterion,
                                         optimizer,
                                         epoch,
                                         log,
                                         kfclass=netG)

    for ikf in range(len(kfconv_list)):
        kfscale_list[ikf].append(kfconv_list[ikf].kfscale.data.clone().cpu())

    epoch_time.update(time.time() - start_time)
    start_time = time.time()
Code example #28
File: main.py Project: AJSVB/GPBT
class TrainCIFAR():
    def __init__(self, config):
        self.args = vars(copy.deepcopy(args))
        for key, value in config.items():
            self.args[key] = value
        # Init logger
        if not os.path.isdir(args.save_path):
            os.makedirs(args.save_path)
        # used for file names, etc

        self.time_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        log = open(
            os.path.join(
                args.save_path,
                'log_seed_{0}_{1}.txt'.format(args.manualSeed,
                                              self.time_stamp)), 'w')

        print_log('save path : {}'.format(args.save_path), log)
        state = {k: v for k, v in self.args.items()}
        print_log(state, log)
        print_log("Random Seed: {}".format(args.manualSeed), log)
        print_log("python version : {}".format(sys.version.replace('\n', ' ')),
                  log)
        print_log("torch  version : {}".format(torch.__version__), log)
        print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
                  log)
        # Init dataset
        if not os.path.isdir(args.data_path):
            os.makedirs(args.data_path)

        if args.dataset == 'cifar10':
            mean = [x / 255 for x in [125.3, 123.0, 113.9]]
            std = [x / 255 for x in [63.0, 62.1, 66.7]]
        elif args.dataset == 'cifar100':
            mean = [x / 255 for x in [129.3, 124.1, 112.4]]
            std = [x / 255 for x in [68.2, 65.4, 70.4]]
        else:
            assert False, "Unknow dataset : {}".format(args.dataset)

#   writer = SummaryWriter()

#   # Data transforms
# mean = [0.5071, 0.4867, 0.4408]
# std = [0.2675, 0.2565, 0.2761]

        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        #[transforms.CenterCrop(32), transforms.ToTensor(),
        # transforms.Normalize(mean, std)])
        #)
        test_transform = transforms.Compose([
            transforms.CenterCrop(32),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        if args.dataset == 'cifar10':
            train_data = dset.CIFAR10(args.data_path,
                                      train=True,
                                      transform=train_transform,
                                      download=True)
            test_data = dset.CIFAR10(args.data_path,
                                     train=False,
                                     transform=test_transform,
                                     download=True)
            num_classes = 10
        elif args.dataset == 'cifar100':
            train_data = dset.CIFAR100(args.data_path,
                                       train=True,
                                       transform=train_transform,
                                       download=True)
            test_data = dset.CIFAR100(args.data_path,
                                      train=False,
                                      transform=test_transform,
                                      download=True)
            num_classes = 100
        elif args.dataset == 'imagenet':
            assert False, 'Did not finish imagenet code'
        else:
            assert False, 'Does not support dataset : {}'.format(args.dataset)

        from sklearn.model_selection import train_test_split
        self.train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True)
        # test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
        #                       num_workers=args.workers, pin_memory=True)

        nb_test = len(test_data) // 2
        test_dataset, val_dataset = torch.utils.data.dataset.random_split(
            test_data, [nb_test, len(test_data) - nb_test])
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True)
        self.val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True)
        print_log("=> creating model '{}'".format(args.arch), log)
        # Init model, criterion, and optimizer
        self.net = models.__dict__[args.arch](
            num_classes, drp=self.args['drp']
        )  #,eps=self.args['eps_arch'],momentum=self.args['momentum_arch'])
        #torch.save(net, 'net.pth')
        #init_net = torch.load('net.pth')
        #net.load_my_state_dict(init_net.state_dict())
        #  print_log("=> network :\n {}".format(self.net),log)

        # self.net = torch.nn.DataParallel(self.net, device_ids=list(range(args.ngpu)))

        # define loss function (criterion) and optimizer
        self.criterion = torch.nn.CrossEntropyLoss()
        if args.use_cuda:
            self.net.cuda()
            self.criterion.cuda()
        #optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.005, nesterov=False)
        self.optimizer = torch.optim.Adadelta(
            self.net.parameters(),
            lr=self.args['lr'],
            rho=self.args['momentum'],
            eps=self.args['eps'],  # momentum=state['momentum'],
            weight_decay=self.args['weight_decay'])

        print_log("=> Seed '{}'".format(args.manualSeed), log)
        print_log(
            "=> dataset mean and std '{} - {}'".format(str(mean), str(std)),
            log)

        states_settings = {'optimizer': self.optimizer.state_dict()}

        #  print_log("=> optimizer '{}'".format(states_settings),log)
        # 50k,95k,153k,195k,220k

        self.recorder = RecorderMeter(args.epochs)
        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print_log("=> loading checkpoint '{}'".format(args.resume),
                          log)
                checkpoint = torch.load(args.resume)
                self.recorder = checkpoint['recorder']
                args.start_epoch = checkpoint['epoch']
                self.net.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                print_log(
                    "=> loaded checkpoint '{}' (epoch {})".format(
                        args.resume, checkpoint['epoch']), log)
            else:
                print_log("=> no checkpoint found at '{}'".format(args.resume),
                          log)
        else:
            print_log(
                "=> did not use any checkpoint for {} model".format(args.arch),
                log)

        if args.evaluate:
            print("if we go in there we are wasting time (args.evaluate=true)")
            validate(self.test_loader, self.net, self.criterion, log)
            return
        self.i = 0
        log.close()

    def adapt(self, config):
        # print(self.net)
        self_copy = copy.deepcopy(self)
        for key, value in config.items():
            self_copy.args[key] = value

        self_copy.net.cpu()

        self_copy.net.adapt(
            self_copy.args['drp']
        )  #,self_copy.args['eps_arch'],self_copy.args['momentum_arch'])
        if args.use_cuda:
            self_copy.net.cuda()

        self_copy.optimizer = torch.optim.Adadelta(
            self_copy.net.parameters(),
            lr=self_copy.args['lr'],
            rho=self_copy.args['momentum'],
            eps=self_copy.args['eps'],  # momentum=state['momentum'],
            weight_decay=self_copy.args['weight_decay'])

        #for param_group in self_copy.optimizer.param_groups:
        #    param_group['lr'] = self_copy.args['lr']
        #    param_group['rho'] = self_copy.args['momentum']
        #    param_group['eps'] = self_copy.args['eps']
        #    param_group['weight_decay'] = self_copy.args['weight_decay']

        return self_copy

    def train1(self):
        log = open(
            os.path.join(
                args.save_path,
                'log_seed_{0}_{1}.txt'.format(args.manualSeed,
                                              self.time_stamp)), 'a')

        # train for one epoch
        train_acc, train_los = train(self.train_loader, self.net,
                                     self.criterion, self.optimizer, self.i,
                                     log)
        self.i += 1
        log.close()
        return train_acc, train_los

    def val1(self):
        log = open(
            os.path.join(
                args.save_path,
                'log_seed_{0}_{1}.txt'.format(args.manualSeed,
                                              self.time_stamp)), 'a')

        val_acc, val_los = validate(self.val_loader, self.net, self.criterion,
                                    log)
        log.close()
        return val_acc, val_los

    def test1(self):
        log = open(
            os.path.join(
                args.save_path,
                'log_seed_{0}_{1}.txt'.format(args.manualSeed,
                                              self.time_stamp)), 'a')

        val_acc, val_los = validate(self.test_loader, self.net, self.criterion,
                                    log)
        log.close()
        return val_acc, val_los

    def step(self):
        log = open(
            os.path.join(
                args.save_path,
                'log_seed_{0}_{1}.txt'.format(args.manualSeed,
                                              self.time_stamp)), 'a')

        start_time = time.time()
        epoch_time = AverageMeter()
        #current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)
        #current_learning_rate = float(self.scheduler.get_last_lr()[-1])
        #print('lr:',current_learning_rate)

        #self.scheduler.step()

        #adjust_learning_rate(optimizer, epoch)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - self.i))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:.6f}]'.format(time_string(), self.i, args.epochs, need_time, self.args['lr']) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(self.recorder.max_accuracy(False), 100-self.recorder.max_accuracy(False)),log)

        train_acc, train_los = self.train1()
        val_acc, val_los = self.val1()
        is_best = self.recorder.update(self.i - 1, train_los, train_acc,
                                       val_los, val_acc)

        #  save_checkpoint({
        #    'epoch': self.i,
        #    'arch': args.arch,
        #    'state_dict': self.net.state_dict(),
        #    'recorder': self.recorder,
        #    'optimizer' : self.optimizer.state_dict(),
        #  }, is_best, args.save_path, 'checkpoint_{0}.pth.tar'.format(self.time_stamp), self.time_stamp)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        #self.recorder.plot_curve( os.path.join(args.save_path, 'training_plot_{0}_{1}.png'.format(args.manualSeed, self.time_stamp)) )
        log.close()
        return train_acc, train_los, val_acc, val_los
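The adapt() method above deep-copies the whole trainer and rebuilds its Adadelta optimizer from a new hyperparameter configuration, which is how a population-based search can branch a trial without losing the learned weights. A toy, assumption-laden sketch of that pattern (TinyTrainer is hypothetical, not the GPBT API):

import copy
import torch

class TinyTrainer:
    """Toy trainable mirroring the adapt() pattern in TrainCIFAR above."""
    def __init__(self, lr):
        self.net = torch.nn.Linear(4, 2)
        self.lr = lr
        self.optimizer = torch.optim.Adadelta(self.net.parameters(), lr=lr)

    def adapt(self, new_lr):
        # deep-copy the trainer, then rebuild the optimizer with the new config
        clone = copy.deepcopy(self)
        clone.lr = new_lr
        clone.optimizer = torch.optim.Adadelta(clone.net.parameters(), lr=new_lr)
        return clone

base = TinyTrainer(lr=1.0)
variant = base.adapt(new_lr=0.1)
print(base.lr, variant.lr)  # 1.0 0.1; the copy keeps the learned weights but explores a new LR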
Code example #29
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("CS Pruning Rate: {}".format(args.prune_rate_cs), log)
    print_log("GM Pruning Rate: {}".format(args.prune_rate_gm), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("use pretrain: {}".format(args.use_pretrain), log)
    print_log("Pretrain path: {}".format(args.pretrain_path), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    n_mels = 32
    if args.input == 'mel40':
        n_mels = 40

    data_aug_transform = transforms.Compose([
        speech_transforms.ChangeAmplitude(),
        speech_transforms.ChangeSpeedAndPitchAudio(),
        speech_transforms.FixAudioLength(),
        speech_transforms.ToSTFT(),
        speech_transforms.StretchAudioOnSTFT(),
        speech_transforms.TimeshiftAudioOnSTFT(),
        speech_transforms.FixSTFTDimension()])

    parser.add_argument("--background_noise", type=str, default='datasets/gsc/train/_background_noise_',
                        help='path of background noise')

    backgroundNoisePname = os.path.join(args.data_path, 'train', '_background_noise_')
    bg_dataset = gsc_utils.BackgroundNoiseDataset(backgroundNoisePname, data_aug_transform)
    add_bg_noise = speech_transforms.AddBackgroundNoiseOnSTFT(bg_dataset)

    train_feature_transform = transforms.Compose([
        speech_transforms.ToMelSpectrogramFromSTFT(n_mels=n_mels),
        speech_transforms.DeleteSTFT(),
        speech_transforms.ToTensor('mel_spectrogram', 'input')])

    train_dataset = gsc_utils.SpeechCommandsDataset(
        os.path.join(args.data_path, 'train'),
        transforms.Compose([speech_transforms.LoadAudio(),
                            data_aug_transform,
                            add_bg_noise,
                            train_feature_transform]))

    valid_feature_transform = transforms.Compose([
        speech_transforms.ToMelSpectrogram(n_mels=n_mels),
        speech_transforms.ToTensor('mel_spectrogram', 'input')])
    valid_dataset = gsc_utils.SpeechCommandsDataset(
        os.path.join(args.data_path, 'valid'),
        transforms.Compose([
            speech_transforms.LoadAudio(),
            speech_transforms.FixAudioLength(),
            valid_feature_transform]))
    test_dataset = gsc_utils.SpeechCommandsDataset(
        os.path.join(args.data_path, 'test'),
        transforms.Compose([
            speech_transforms.LoadAudio(),
            speech_transforms.FixAudioLength(),
            valid_feature_transform]),
        silence_percentage=0)


    weights = train_dataset.make_weights_for_balanced_classes()
    sampler = WeightedRandomSampler(weights, len(weights))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=sampler,
                                               num_workers=args.workers, pin_memory=True)
    train_prune_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_prune_size,
                                                     num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.workers, pin_memory=True)

    num_classes = len(gsc_utils.CLASSES)



    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes=num_classes, in_channels=1, fctMultLinLyr=64)
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    if args.use_pretrain:
        pretrain = torch.load(args.pretrain_path)
        if args.use_state_dict:
            # checkpoint stores a state_dict
            net.load_state_dict(pretrain['state_dict'])
        else:
            # checkpoint stores the full serialized model
            net = pretrain['state_dict']

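    # RecorderMeter tracks per-epoch train/validation loss and accuracy; it is used below
    # to report the best accuracy and to plot the training curve.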
    recorder = RecorderMeter(args.epochs)

    mdlIdx2ConvIdx = [] # module index to conv filter index
    for index1, layr in enumerate(net.modules()):
        if isinstance(layr, torch.nn.Conv2d):
            mdlIdx2ConvIdx.append(index1)

    prmIdx2ConvIdx = [] # parameter index to conv filter index
    for index2, item in enumerate(net.parameters()):
        if len(item.size()) == 4:
            prmIdx2ConvIdx.append(index2)

    # set index of last layer depending on the known architecture
    if args.arch == 'resnet20':
        args.layer_end = 54
    elif args.arch == 'resnet56':
        args.layer_end = 162
    elif args.arch == 'resnet110':
        args.layer_end = 324
    else:
        pass # unknown architecture, use the input value

    # asymptotic pruning schedule: per-epoch total compression rates, scaling factors,
    # and the separate CS / FPGM compression rates
    total_pruning_rate = args.prune_rate_gm + args.prune_rate_cs
    compress_rates_total, scalling_factors, compress_rates_cs, compress_rates_fpgm, e2 =\
        cmpAsymptoticSchedule(theta3=total_pruning_rate, e3=args.epochs-1, tau=args.tau, theta_cs_final=args.prune_rate_cs, scaling_attn=args.scaling_attenuation) # tau=8.
    keep_rate_cs = 1. - compress_rates_cs

    if args.use_zero_scaling:
        scalling_factors = np.zeros(scalling_factors.shape)

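    # Mask computes and applies the filter-pruning masks over the conv layers, driven by
    # the per-epoch CS keep rates, FPGM compression rates, and scaling factors.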
    m = Mask(net, train_prune_loader, mdlIdx2ConvIdx, prmIdx2ConvIdx, scalling_factors, keep_rate_cs, compress_rates_fpgm, args.max_iter_cs)
    m.set_curr_epoch(0)
    m.set_epoch_cs(args.epoch_apply_cs)
    m.init_selected_filts()
    m.init_length()

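    # Accuracy of the unpruned network (val_acc_1), then after the initial mask is applied (val_acc_2).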
    val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

    m.model = net
    m.init_mask(keep_rate_cs[0], compress_rates_fpgm[0], scalling_factors[0])
    #    m.if_zero()
    m.do_mask()
    m.do_similar_mask()
    net = m.model
    #    m.if_zero()
    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):

        current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, m)
        
        # update and re-apply the pruning mask every epoch_prune epochs (and on the last epoch)
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = net
            m.set_curr_epoch(epoch)
            # m.if_zero()
            m.init_mask(keep_rate_cs[epoch], compress_rates_fpgm[epoch], scalling_factors[epoch])
            m.do_mask()
            m.do_similar_mask()
            # m.if_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()
            if epoch == args.epochs - 1:
                m.if_zero()

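        # evaluate the (pruned) network on the test split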
        val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2, val_acc_2)

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': net,
            'recorder': recorder,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    log.close()