Example #1
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)

    # Get timestamp
    time_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    # Random seed used
    print("Random Seed: {0}".format(args.manualSeed))
    # Python Version used
    print("python version : {}".format(sys.version.replace('\n', ' ')))
    # Torch Version used
    print("torch  version : {}".format(torch.__version__))
    # Cudnn Version used
    print("cudnn  version : {0}".format(torch.backends.cudnn.version()))

    # Path for the dataset. If not present, it is downloaded
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    # Preparing dataset
    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknown dataset : {}".format(args.dataset)

    # Additional dataset transforms like padding, crop and flipping of images
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # Loading dataset from the path to data
    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   target_transform=None,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  target_transform=None,
                                  download=True)
        num_classes = 100
    else:
        assert False, 'Unsupported dataset : {}'.format(args.dataset)

    # Splitting the training dataset into train and val sets
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(10000)

    np.random.seed(args.manualSeed)
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # Loading the data into loaders
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               sampler=train_sampler,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(train_data,
                                             batch_size=args.batch_size,
                                             sampler=valid_sampler,
                                             num_workers=args.workers,
                                             pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    # Print the architecture of the model being used
    print("=> creating model '{}'".format(args.arch))

    # Initializing the model for the selected number of classes
    net = simplenet(classes=num_classes)

    # Using GPUs for the Neural Network
    if args.ngpu > 0:
        net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # Define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss()

    # The list of weights that are to be learnt while training

    # This is the list of all the parameters. Used when all the layers are to be trained
    # (Comment for training only the last layer and uncomment the following command)
    pars = list(net.parameters())

    # This is the list of parameters in the last layer of the network.
    # (Uncomment for training only the last layer, with the rest of the network frozen)
    # pars = list(net.module.classifier.parameters())

    # Define the optimizer used in the learning algorithm (optimizer).
    # It is created after 'pars' so that the parameter list already exists.
    optimizer = torch.optim.Adadelta(pars,
                                     lr=0.1,
                                     rho=0.9,
                                     eps=1e-3,
                                     weight_decay=0)

    # Keep the optimizer settings (e.g. for checkpointing)
    states_settings = {'optimizer': optimizer.state_dict()}

    # Epochs after which when the learning rate is to be changed
    milestones = [50, 70, 100]
    # Defines how the learning rate is changed at the above-mentioned epochs.
    # 'gamma' is the factor by which the learning rate is reduced
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1)

    # Put the network and the other data on the CUDA device
    if args.use_cuda:
        net.cuda()
        criterion.cuda()
        print('__Number CUDA Devices:', torch.cuda.device_count())

    # A structure to record different results while training and evaluating
    recorder = RecorderMeter(args.epochs)

    # If this argument is given, the network loads weights from the checkpoint file given in the arguments
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            # If no GPU is being used, uncomment the following line
            # checkpoint = torch.load(args.resume,map_location='cpu')

            # When loading from a checkpoint with only features saved, uncomment the following line else comment it
            # net.module.features.load_state_dict(checkpoint['state_dict'])

            # When loading from a checkpoint with all the parameters saved, uncomment the following line else comment it.
            net.load_state_dict(checkpoint['state_dict'])

            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

    else:
        print("=> did not use any checkpoint for {} model".format(args.arch))

    # Opening a log, None when not using it
    log = None

    # Evaluating a model on training, testing and validation sets
    if args.evaluate:
        print("Train data : ")
        acc, loss, tloss = validate(train_loader, net, criterion, log)
        print(loss, acc, tloss)
        print("Test data : ")
        acc, loss, tloss = validate(test_loader, net, criterion, log)
        print(loss, acc, tloss)
        print("Validation data : ")
        acc, loss, tloss = validate(val_loader, net, criterion, log)
        print(loss, acc, tloss)
        return

    # Main loop

    # Start timer
    start_time = time.time()
    # Structure to record time for each epoch
    epoch_time = AverageMeter()

    # Starts training from epoch 0
    args.start_epoch = 0

    best_loss = 1000000
    best_acc = 100000
    log = None

    # Loop for training for each epoch
    for epoch in range(args.start_epoch, args.epochs):

        # Get the learning rate for this epoch. Note that the value returned by
        # adjust_learning_rate is immediately overwritten by the scheduler's rate.
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)
        current_learning_rate = float(scheduler.get_lr()[-1])

        scheduler.step()

        # Calculate the estimated time for the remaining number of epochs
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        # Train the network on the training data and return accuracy, training loss and the t-loss on the training dataset
        train_acc, train_los, tloss_train = train(train_loader, net, criterion,
                                                  optimizer, epoch, log)

        # After each epoch, print the remaining time, the epoch number, the current learning rate and the best validation accuracy recorded so far
        print('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:.6f}]'.
              format(time_string(), epoch, args.epochs, need_time,
                     current_learning_rate) +
              ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
                  recorder.max_accuracy(False), 100 -
                  recorder.max_accuracy(False)))

        # Evaluate on the validation data and update the recorded values
        val_acc, val_los, tloss_val = validate(val_loader, net, criterion, log)
        is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                  val_acc)

        # Save checkpoint after every 30 epochs
        if epoch % 30 == 29:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    # Save the whole network with the following line uncommented. If only features are to be saved, comment it.
                    # 'state_dict': net.state_dict(),
                    # Save only the feature layers of the network with the following line uncommented. If the whole network is to be saved, comment it.
                    'state_dict': net.module.features.state_dict(),
                    'recorder': recorder,
                    'optimizer': optimizer.state_dict(),
                },
                False,
                args.save_path,
                # Name of the checkpoint file to be saved to
                'crossEntropy_full_features_{}.ckpt'.format(epoch),
                time_stamp)

        # measure elapsed time for one epoch
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(
            os.path.join(
                args.save_path,
                'training_plot_crossEntropy_{0}.png'.format(args.manualSeed)))

    # writer.close()
    # End loop

    # Evaluate and print the results on training, testing and validation data

    test_acc, test_los, tl = validate(train_loader, net, criterion, log)
    print("Train accuracy : ")
    print(test_acc)
    print(test_los)
    print(tl)

    test_acc, test_los, tl = validate(val_loader, net, criterion, log)
    print("Val accuracy : ")
    print(test_acc)
    print(test_los)
    print(tl)

    test_acc, test_los, tl = validate(test_loader, net, criterion, log)
    print("Test accuracy : ")
    print(test_acc)
    print(test_los)
    print(tl)
Example #2
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log("Norm Pruning Rate: {}".format(args.rate_norm), log)
    print_log("Distance Pruning Rate: {}".format(args.rate_dist), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("use pretrain: {}".format(args.use_pretrain), log)
    print_log("Pretrain path: {}".format(args.pretrain_path), log)
    print_log("Dist type: {}".format(args.dist_type), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknown dataset : {}".format(args.dataset)

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        assert False, 'ImageNet support is not implemented'
    else:
        assert False, 'Unsupported dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    if args.use_pretrain:
        if os.path.isfile(args.pretrain_path):
            print_log(
                "=> loading pretrain model '{}'".format(args.pretrain_path),
                log)
        else:
            dir = '/data/yahe/cifar10_base/'
            # dir = '/data/uts521/yang/progress/cifar10_base/'
            whole_path = dir + 'cifar10_' + args.arch + '_base'
            args.pretrain_path = whole_path + '/checkpoint.pth.tar'
            print_log("Pretrain path: {}".format(args.pretrain_path), log)
        pretrain = torch.load(args.pretrain_path)
        if args.use_state_dict:
            net.load_state_dict(pretrain['state_dict'])
        else:
            net = pretrain['state_dict']

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            if args.use_state_dict:
                net.load_state_dict(checkpoint['state_dict'])
            else:
                net = checkpoint['state_dict']

            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        time1 = time.time()
        validate(test_loader, net, criterion, log)
        time2 = time.time()
        print('function took %0.3f ms' % ((time2 - time1) * 1000.0))
        return

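    # Set up the pruning helper: Mask keeps per-layer masks built from the
    # norm-based and distance-based pruning criteria configured in args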
    m = Mask(net, args)
    m.init_length()
    print("-" * 10 + "one epoch begin" + "-" * 10)
    print("remaining ratio of pruning : Norm is %f" % args.rate_norm)
    print("reducing ratio of pruning : Distance is %f" % args.rate_dist)
    print("total remaining ratio is %f" % (args.rate_norm - args.rate_dist))

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

    print(" accu before is: %.3f %%" % val_acc_1)

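    # Prune once before training: rebuild the masks, zero out the selected
    # filters, and move the masked model back onto the GPU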
    m.model = net

    m.init_mask(args.rate_norm, args.rate_dist, args.dist_type)
    #    m.if_zero()
    m.do_mask()
    m.do_similar_mask()
    net = m.model
    #    m.if_zero()
    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)
    print(" accu after is: %s %%" % val_acc_2)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    small_filter_index = []
    large_filter_index = []

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log, m)

        # evaluate on validation set
        val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:  # prune every args.epoch_prune
            m.model = net
            m.if_zero()
            m.init_mask(args.rate_norm, args.rate_dist, args.dist_type)
            m.do_mask()
            m.do_similar_mask()
            m.if_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()

        val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2,
                                  val_acc_2)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net,
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    log.close()
Example #3
def main():
    global best_acc

    if not os.path.isdir(args.save_path): os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("PyTorch  version : {}".format(torch.__version__), log)
    print_log("CuDNN  version : {}".format(torch.backends.cudnn.version()),
              log)

    if not os.path.isdir(args.data_path): os.makedirs(args.data_path)

    num_classes, train_loader, test_loader = load_dataset()
    net = load_model(num_classes, log)

    criterion = torch.nn.CrossEntropyLoss().cuda()

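    # Build optimizer parameter groups; entries whose names match the skip list
    # ('coefficients') presumably receive no weight decay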
    params = group_weight_decay(net, state['decay'], ['coefficients'])
    optimizer = torch.optim.SGD(params,
                                state['learning_rate'],
                                momentum=state['momentum'],
                                nesterov=(state['momentum'] > 0.0))

    recorder = RecorderMeter(args.epochs)
    if args.resume:
        if args.resume == 'auto':
            args.resume = os.path.join(args.save_path, 'checkpoint.pth.tar')
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            recorder.refresh(args.epochs)
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log(
                "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                    args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_los = -1

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule, train_los)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log)

        val_acc, val_los = validate(test_loader, net, criterion, log)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_path, 'checkpoint.pth.tar')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(result_png_path)  # result_png_path is assumed to be defined at module scope
    log.close()
Example #4
def main_worker(gpu, ngpus_per_node, args):
    global best_acc
    args.gpu = gpu
    assert args.gpu is not None
    print("Use GPU: {} for training".format(args.gpu))

    log = open(os.path.join(args.save_path, 'log{}{}.txt'.format('_seed'+
                   str(args.manualSeed), '_eval' if args.evaluate else '')), 'w')
    log = (log, args.gpu)

    net = models.__dict__[args.arch](args, args.depth, args.wide, args.num_classes)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("PyTorch  version : {}".format(torch.__version__), log)
    print_log("CuDNN  version : {}".format(torch.backends.cudnn.version()), log)
    print_log("Number of parameters: {}".format(sum([p.numel() for p in net.parameters()])), log)
    print_log(str(args), log)

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url+":"+args.dist_port,
                                world_size=args.world_size, rank=args.rank)
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])
    else:
        torch.cuda.set_device(args.gpu)
        net = net.cuda(args.gpu)

    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

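    # Split parameters into means and log-variances: 'log_sigma' parameters get
    # their own variational optimizer below, the rest are trained with plain SGD;
    # both learning rates are scaled linearly with the global batch size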
    mus, vars = [], []
    for name, param in net.named_parameters():
        if 'log_sigma' in name: vars.append(param)
        else: assert(param.requires_grad); mus.append(param)
    optimizer = torch.optim.SGD([{'params': mus, 'weight_decay': args.decay}],
                args.batch_size * args.world_size / 128 * args.learning_rate,
                momentum=args.momentum, nesterov=(args.momentum > 0.0))
    if args.bayes:
        assert(len(mus) == len(vars))
        var_optimizer = VarSGD([{'params': vars, 'weight_decay': args.decay}],
                    args.batch_size * args.world_size / 128 * args.log_sigma_lr,
                    momentum=args.momentum, nesterov=(args.momentum > 0.0),
                    num_data = args.num_data)
    else:
        assert(len(vars) == 0)
        var_optimizer = NoneOptimizer()


    recorder = RecorderMeter(args.epochs)
    if args.resume:
        if args.resume == 'auto': args.resume = os.path.join(args.save_path, 'checkpoint.pth.tar')
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume, map_location='cuda:{}'.format(args.gpu))
            recorder = checkpoint['recorder']
            recorder.refresh(args.epochs)
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'] if args.distributed else {k.replace('module.', ''): v for k,v in checkpoint['state_dict'].items()})
            optimizer.load_state_dict(checkpoint['optimizer'])
            var_optimizer.load_state_dict(checkpoint['var_optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log("=> loaded checkpoint '{}' accuracy={} (epoch {})" .format(args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

    cudnn.benchmark = True

    train_loader, test_loader = load_dataset(args)

    if args.evaluate:
        validate(test_loader, net, criterion, args, log)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_los = -1

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
        cur_lr, cur_slr = adjust_learning_rate(optimizer, var_optimizer, epoch, args)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f} {:6.4f}]'.format(
                                    time_string(), epoch, args.epochs, need_time, cur_lr, cur_slr) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        train_acc, train_los = train(train_loader, net, criterion, optimizer, var_optimizer, epoch, args, log)
        val_acc, val_los   = validate(test_loader, net, criterion, args, log)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

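        # Only the process driving GPU 0 saves checkpoints, so multiple ranks do not write the same file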
        if args.gpu == 0:
            save_checkpoint({
              'epoch': epoch + 1,
              'arch': args.arch,
              'state_dict': net.state_dict(),
              'recorder': recorder,
              'optimizer' : optimizer.state_dict(),
              'var_optimizer' : var_optimizer.state_dict(),
            }, is_best, args.save_path, 'checkpoint.pth.tar')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'log.png'))

    log[0].close()
Example #5
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    print('Dataset: {}'.format(args.dataset.upper()))

    if args.dataset == "seedlings" or args.dataset == "bone":
        classes, class_to_idx, num_to_class, df = GenericDataset.find_classes(
            args.data_path)
    if args.dataset == "ISIC2017":
        classes, class_to_idx, num_to_class, df = GenericDataset.find_classes_melanoma(
            args.data_path)

    df.head(3)

    args.num_classes = len(classes)
    # Init model, criterion, and optimizer
    # net = models.__dict__[args.arch](num_classes)
    # net= kmodels.simpleXX_generic(num_classes=args.num_classes, imgDim=args.imgDim)
    # net= kmodels.vggnetXX_generic(num_classes=args.num_classes,  imgDim=args.imgDim)
    # net= kmodels.vggnetXX_generic(num_classes=args.num_classes,  imgDim=args.imgDim)
    net = kmodels.dpn92(num_classes=args.num_classes)
    # print_log("=> network :\n {}".format(net), log)

    real_model_name = (type(net).__name__)
    print("=> Creating model '{}'".format(real_model_name))
    import datetime

    exp_name = datetime.datetime.now().strftime(real_model_name + '_' +
                                                args.dataset +
                                                '_%Y-%m-%d_%H-%M-%S')
    print('Training ' + real_model_name +
          ' on {} dataset:'.format(args.dataset.upper()))

    mPath = args.save_path + '/' + args.dataset + '/' + real_model_name + '/'
    args.save_path_model = mPath
    if not os.path.isdir(args.save_path_model):
        os.makedirs(args.save_path_model)

    log = open(os.path.join(mPath, 'seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print("Random Seed: {}".format(args.manualSeed))
    print("python version : {}".format(sys.version.replace('\n', ' ')))
    print("torch  version : {}".format(torch.__version__))
    print("cudnn  version : {}".format(torch.backends.cudnn.version()))

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)
    normalize_img = torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_trans = transforms.Compose([
        transforms.RandomSizedCrop(args.img_scale),
        PowerPIL(),
        transforms.ToTensor(),
        # normalize_img,
        RandomErasing()
    ])

    ## Normalization only for validation and test
    valid_trans = transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(args.img_scale),
        transforms.ToTensor(),
        # normalize_img
    ])

    test_trans = valid_trans

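    # Split the dataframe into two disjoint sets: a sampled fraction
    # (args.validationRatio) and the remaining rows, matched on the 'file' column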
    train_data = df.sample(frac=args.validationRatio)
    valid_data = df[~df['file'].isin(train_data['file'])]

    train_set = GenericDataset(train_data,
                               args.data_path,
                               transform=train_trans)
    valid_set = GenericDataset(valid_data,
                               args.data_path,
                               transform=valid_trans)

    t_loader = DataLoader(train_set,
                          batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=0)
    v_loader = DataLoader(valid_set,
                          batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=0)
    # test_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)

    dataset_sizes = {
        'train': len(t_loader.dataset),
        'valid': len(v_loader.dataset)
    }
    print(dataset_sizes)
    # net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    criterion = torch.nn.CrossEntropyLoss()

    # optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
    #               weight_decay=state['decay'], nesterov=True)

    # optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)
    # optimizer = torch.optim.Adam(net.parameters(), lr=state['learning_rate'])

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.evaluate:
        validate(v_loader, net, criterion, log)
        return
    if args.tensorboard:
        configure("./logs/runs/%s" % (exp_name))

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in net.parameters()) / 1000000.0))

    # Main loop
    start_training_time = time.time()
    training_time = time.time()
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in tqdm(range(args.start_epoch, args.epochs)):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        tqdm.write(
            '\n==>>Epoch=[{:03d}/{:03d}]], {:s}, LR=[{}], Batch=[{}]'.format(
                epoch, args.epochs, time_string(), state['learning_rate'],
                args.batch_size) + ' [Model={}]'.format(
                    (type(net).__name__), ), log)

        # train for one epoch
        train_acc, train_los = train(t_loader, net, criterion, optimizer,
                                     epoch, log)
        val_acc, val_los = validate(v_loader, net, criterion, epoch, log)
        is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                  val_acc)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        training_time = time.time() - start_training_time
        recorder.plot_curve(
            os.path.join(mPath, real_model_name + '_' + exp_name + '.png'),
            training_time, net, real_model_name, dataset_sizes,
            args.batch_size, args.learning_rate, args.dataset, args.manualSeed,
            args.num_classes)

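        # Early stop once validation accuracy passes 95%: predict on the test set,
        # write a submission CSV, and save the weights plus a full checkpoint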
        if float(val_acc) > float(95.0):
            print("*** EARLY STOP ***")
            df_pred = testSeedlingsModel(args.test_data_path, net,
                                         num_to_class, test_trans)
            model_save_path = os.path.join(
                mPath, real_model_name + '_' + str(val_acc) + '_' +
                str(val_los) + "_" + str(epoch))

            df_pred.to_csv(model_save_path + "_sub.csv",
                           columns=('file', 'species'),
                           index=None)
            torch.save(net.state_dict(), model_save_path + '_.pth')

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    # 'arch': args.arch,
                    'state_dict': net.state_dict(),
                    'recorder': recorder,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                mPath,
                str(val_acc) + '_' + str(val_los) + "_" + str(epoch) +
                '_checkpoint.pth.tar')

    log.close()
Example #6
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("CS Pruning Rate: {}".format(args.prune_rate_cs), log)
    print_log("GM Pruning Rate: {}".format(args.prune_rate_gm), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("use pretrain: {}".format(args.use_pretrain), log)
    print_log("Pretrain path: {}".format(args.pretrain_path), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    else:
        assert False, "Unknown dataset : {}".format(args.dataset)

    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
        num_classes = 10
    else:
        assert False, 'Unsupported dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)
    train_prune_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_prune_size, shuffle=False,
                                               num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.workers, pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    if args.use_pretrain:
        pretrain = torch.load(args.pretrain_path)
        if args.use_state_dict:
            net.load_state_dict(pretrain['state_dict'])
        else:
            net = pretrain['state_dict']

    recorder = RecorderMeter(args.epochs)

    mdlIdx2ConvIdx = [] # module index to conv filter index
    for index1, layr in enumerate(net.modules()):
        if isinstance(layr, torch.nn.Conv2d):
            mdlIdx2ConvIdx.append(index1)

    prmIdx2ConvIdx = [] # parameter index to conv filter index
    for index2, item in enumerate(net.parameters()):
        if len(item.size()) == 4:
            prmIdx2ConvIdx.append(index2)

    # set index of last layer depending on the known architecture
    if args.arch == 'resnet20':
        args.layer_end = 54
    elif args.arch == 'resnet56':
        args.layer_end = 162
    elif args.arch == 'resnet110':
        args.layer_end = 324
    else:
        pass # unknown architecture, use the input value

    # asymptotic schedule
    total_pruning_rate = args.prune_rate_gm + args.prune_rate_cs
    compress_rates_total, scalling_factors, compress_rates_cs, compress_rates_fpgm, e2 =\
        cmpAsymptoticSchedule(theta3=total_pruning_rate, e3=args.epochs-1, tau=args.tau, theta_cs_final=args.prune_rate_cs, scaling_attn=args.scaling_attenuation) # tau=8.
    keep_rate_cs = 1. - compress_rates_cs

    if args.use_zero_scaling:
        scalling_factors = np.zeros(scalling_factors.shape)

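    # Pruning helper: combines the conv-layer index maps with the per-epoch
    # schedules for the CS keep rates, FPGM compression rates and scaling factors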
    m = Mask(net, train_prune_loader, mdlIdx2ConvIdx, prmIdx2ConvIdx, scalling_factors, keep_rate_cs, compress_rates_fpgm, args.max_iter_cs)
    m.set_curr_epoch(0)
    m.set_epoch_cs(args.epoch_apply_cs)
    m.init_selected_filts()
    m.init_length()

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

    m.model = net
    m.init_mask(keep_rate_cs[0], compress_rates_fpgm[0], scalling_factors[0])
    #    m.if_zero()
    m.do_mask()
    m.do_similar_mask()
    net = m.model
    #    m.if_zero()
    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):

        current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, m)
        
        # evaluate on validation set
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = net
            m.set_curr_epoch(epoch)
            # m.if_zero()
            m.init_mask(keep_rate_cs[epoch], compress_rates_fpgm[epoch], scalling_factors[epoch])
            m.do_mask()
            m.do_similar_mask()
            # m.if_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()
            if epoch == args.epochs - 1:
                m.if_zero()

        val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2, val_acc_2)

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': net,
            'recorder': recorder,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    log.close()
Example #7
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log("Norm Pruning Rate: {}".format(args.rate_norm), log)
    print_log("Distance Pruning Rate: {}".format(args.rate_dist), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("use pretrain: {}".format(args.use_pretrain), log)
    print_log("Pretrain path: {}".format(args.pretrain_path), log)
    print_log("Dist type: {}".format(args.dist_type), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknown dataset : {}".format(args.dataset)

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    # Init model, criterion, and optimizer
    net = mobilenetv2.MobileNetV2()
    model_loaded = torch.load('ckpt.pth')
    mobile = model_loaded['net'].copy()
    print(type(mobile))
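    # Strip the 'module.' prefix that DataParallel adds so the keys match the bare MobileNetV2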
    net.load_state_dict(
        {k.replace('module.', ''): v
         for k, v in mobile.items()})

    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(),
                                state['lr'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)

    if args.evaluate:
        time1 = time.time()
        validate(test_loader, net, criterion, log)
        time2 = time.time()
        print('function took %0.3f ms' % ((time2 - time1) * 1000.0))
        return

    m = Mask(net)
    m.init_length()
    print("-" * 10 + "one epoch begin" + "-" * 10)
    print("remaining ratio of pruning : Norm is %f" % args.rate_norm)
    print("reducing ratio of pruning : Distance is %f" % args.rate_dist)
    print("total remaining ratio is %f" % (args.rate_norm - args.rate_dist))

    validation_accurate_1, validation_loss_1 = validate(
        test_loader, net, criterion, log)

    print(" accu before is: %.3f %%" % validation_accurate_1)

    m.model = net

    m.init_mask(args.rate_norm, args.rate_dist, args.dist_type)
    #    m.if_zero()
    m.do_mask()
    m.do_similar_mask()
    net = m.model
    #    m.if_zero()
    if args.use_cuda:
        net = net.cuda()
    validation_accurate_2, validation_loss_2 = validate(
        test_loader, net, criterion, log)
    print(" accu after is: %s %%" % validation_accurate_2)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    small_filter_idx = []
    large_filter_idx = []

    for epoch in range(args.start_epoch, args.epochs):
        current_lr = adjust_lr(optimizer, epoch, args.gammas, args.schedule)

        required_hour, required_mins, required_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        required_time = '[required: {:02d}:{:02d}:{:02d}]'.format(
            required_hour, required_mins, required_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [lr={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                   required_time, current_lr) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_accurate, train_loss = train(train_loader, net, criterion,
                                           optimizer, epoch, log, m)

        # evaluate on validation set
        #         validation_accurate_1, validation_loss_1 = validate(test_loader, net, criterion, log)
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = net
            m.if_zero()
            m.init_mask(args.rate_norm, args.rate_dist, args.dist_type)
            m.do_mask()
            m.do_similar_mask()
            m.if_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()

        validation_accurate_2, validation_loss_2 = validate(
            test_loader, net, criterion, log)

        best = recorder.update(epoch, train_loss, train_accurate,
                               validation_loss_2, validation_accurate_2)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': net,
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    start = time.time()
    validation_accurate_2, validation_loss_2 = validate(
        test_loader, net, criterion, log)
    print("Acc=%.4f\n" % (validation_accurate_2))
    print(f"Run time: {(time.time() - start):.3f} s")
    torch.save(
        net,
        'geometric_median-%s%.2f.pth' % (args.prune_method, args.rate_dist))
    log.close()
Example #8
def main_worker(gpu, ngpus_per_node, args):
    global best_acc
    args.gpu = gpu
    assert args.gpu is not None
    print("Use GPU: {} for training".format(args.gpu))

    log = open(
        os.path.join(
            args.save_path,
            'log_seed{}{}.txt'.format(args.manualSeed,
                                      '_eval' if args.evaluate else '')), 'w')
    log = (log, args.gpu)

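    # Load a pretrained backbone, disable its dropout, convert it to the Bayesian
    # counterpart, and unfreeze its parameters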
    net = models.__dict__[args.arch](pretrained=True)
    disable_dropout(net)
    net = to_bayesian(net, args.psi_init_range)
    net.apply(unfreeze)

    print_log("Python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("PyTorch  version : {}".format(torch.__version__), log)
    print_log("CuDNN  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log(
        "Number of parameters: {}".format(
            sum([p.numel() for p in net.parameters()])), log)
    print_log(str(args), log)

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url + ":" +
                                args.dist_port,
                                world_size=args.world_size,
                                rank=args.rank)
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        net = torch.nn.parallel.DistributedDataParallel(net,
                                                        device_ids=[args.gpu])
    else:
        torch.cuda.set_device(args.gpu)
        net = net.cuda(args.gpu)

    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

    mus, psis = [], []
    for name, param in net.named_parameters():
        if 'psi' in name: psis.append(param)
        else: mus.append(param)
    mu_optimizer = SGD(mus,
                       args.learning_rate,
                       args.momentum,
                       weight_decay=args.decay,
                       nesterov=(args.momentum > 0.0))

    psi_optimizer = PsiSGD(psis,
                           args.learning_rate,
                           args.momentum,
                           weight_decay=args.decay,
                           nesterov=(args.momentum > 0.0))

    recorder = RecorderMeter(args.epochs)
    if args.resume:
        if args.resume == 'auto':
            args.resume = os.path.join(args.save_path, 'checkpoint.pth.tar')
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume,
                                    map_location='cuda:{}'.format(args.gpu))
            recorder = checkpoint['recorder']
            recorder.refresh(args.epochs)
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(
                checkpoint['state_dict'] if args.distributed else {
                    k.replace('module.', ''): v
                    for k, v in checkpoint['state_dict'].items()
                })
            mu_optimizer.load_state_dict(checkpoint['mu_optimizer'])
            psi_optimizer.load_state_dict(checkpoint['psi_optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log(
                "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                    args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log("=> do not use any checkpoint for the model", log)

    cudnn.benchmark = True

    train_loader, ood_train_loader, test_loader, adv_loader, \
        fake_loader, adv_loader2 = load_dataset_ft(args)
    psi_optimizer.num_data = len(train_loader.dataset)

    if args.evaluate:
        evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net,
                 criterion, args, log, 20, 100)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_los = -1

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
            ood_train_loader.sampler.set_epoch(epoch)
        cur_lr, cur_slr = adjust_learning_rate(mu_optimizer, psi_optimizer,
                                               epoch, args)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f} {:6.4f}]'.format(
                                    time_string(), epoch, args.epochs, need_time, cur_lr, cur_slr) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        train_acc, train_los = train(train_loader, ood_train_loader, net,
                                     criterion, mu_optimizer, psi_optimizer,
                                     epoch, args, log)
        val_acc, val_los = 0, 0
        recorder.update(epoch, train_los, train_acc, val_acc, val_los)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        if args.gpu == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'recorder': recorder,
                    'mu_optimizer': mu_optimizer.state_dict(),
                    'psi_optimizer': psi_optimizer.state_dict(),
                }, False, args.save_path, 'checkpoint.pth.tar')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'log.png'))

    evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net, criterion,
             args, log, 20, 100)

    log[0].close()
Exemplo n.º 9
0
def main():
  if not os.path.isdir(args.save_path): os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Init dataset
  if not os.path.isdir(args.data_path):
    os.makedirs(args.data_path)

  if args.dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif args.dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  else:
    assert False, "Unknow dataset : {}".format(args.dataset)

  train_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
     transforms.Normalize(mean, std)])
  test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)])

  if args.dataset == 'cifar10':
    train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'cifar100':
    train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 100
  elif args.dataset == 'svhn':
    train_data = dset.SVHN(args.data_path, split='train', transform=train_transform, download=True)
    test_data = dset.SVHN(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'stl10':
    train_data = dset.STL10(args.data_path, split='train', transform=train_transform, download=True)
    test_data = dset.STL10(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'imagenet':
    assert False, 'imagenet support is not implemented yet'
  else:
    assert False, 'Unsupported dataset : {}'.format(args.dataset)

  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                         num_workers=args.workers, pin_memory=True)
  test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, pin_memory=True)

  # Init model, criterion, and optimizer
  #net = models.__dict__[args.arch](num_classes).cuda()
  net = SENet34()

  # define loss function (criterion) and optimizer
  criterion = F.nll_loss
  optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=True)

  if args.use_cuda: net.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint
  if args.resume:
    if os.path.isfile(args.resume):
      print_log("=> loading checkpoint '{}'".format(args.resume), log)
      checkpoint = torch.load(args.resume)
      recorder = checkpoint['recorder']
      args.start_epoch = checkpoint['epoch']
      net.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
    else:
      print_log("=> no checkpoint found at '{}'".format(args.resume), log)
  else:
    print_log("=> do not use any checkpoint for model", log)

  if args.evaluate:
    validate(test_loader, net, criterion, log)
    return

  # Main loop
  start_time = time.time()
  epoch_time = AverageMeter()
  for epoch in range(args.start_epoch, args.epochs):
    current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

    need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
    need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

    # train for one epoch
    train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log)

    # evaluate on validation set
    val_acc,   val_los   = validate(test_loader, net, criterion, log)
    is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

    save_checkpoint({
      'epoch': epoch + 1,
      'state_dict': net.state_dict(),
      'recorder': recorder,
      'optimizer' : optimizer.state_dict(),
    }, is_best, args.save_path, 'checkpoint.pth.tar')

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    recorder.plot_curve( os.path.join(args.save_path, 'curve.png') )

  log.close()
Exemplo n.º 10
0
def main():
  # Init logger
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)

  # used for file names, etc 
  time_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')   
  log = open(os.path.join(args.save_path, '{0}_{1}_{2}.txt'.format(args.arch, args.dataset, time_stamp)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Init dataset
  if not os.path.isdir(args.data_path):
    os.makedirs(args.data_path)


  if args.dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif args.dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  else:
    assert False, "Unknow dataset : {}".format(args.dataset)



  writer = SummaryWriter()


  train_transform = transforms.Compose(
    [ transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(), transforms.ToTensor(),
     transforms.Normalize(mean, std)])

  test_transform = transforms.Compose(
    [transforms.CenterCrop(32), transforms.ToTensor(), transforms.Normalize(mean, std)])


  if args.dataset == 'cifar10':
    train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'cifar100':
    train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 100
  elif args.dataset == 'svhn':
    train_data = dset.SVHN(args.data_path, split='train', transform=train_transform, download=True)
    train_data_ext = dset.SVHN(args.data_path, split='extra', transform=train_transform, download=True)
    test_data = dset.SVHN(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'stl10':
    train_data = dset.STL10(args.data_path, split='train', transform=train_transform, download=True)
    test_data = dset.STL10(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'imagenet':
    assert False, 'imagenet support is not implemented yet'
  else:
    assert False, 'Unsupported dataset : {}'.format(args.dataset)

  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                         num_workers=args.workers, pin_memory=True)
  test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, pin_memory=True)

  print_log("=> creating model '{}'".format(args.arch), log)
  # Init model, criterion, and optimizer
  net = models.__dict__[args.arch](num_classes)
  print_log("=> network :\n {}".format(net), log)


  net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

  # define loss function (criterion) and optimizer
  criterion = torch.nn.CrossEntropyLoss()

  #optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.005, nesterov=False)
  optimizer = torch.optim.Adadelta(net.parameters(), lr=0.1, rho=0.9, eps=1e-3, weight_decay=0.001)


  print_log("=> Seed '{}'".format(args.manualSeed), log)
  print_log("=> dataset mean and std '{} - {}'".format(str(mean), str(std)), log)
  
  states_settings = {
                     'optimizer': optimizer.state_dict()
                    }


  print_log("=> optimizer '{}'".format(states_settings), log)
  # milestone epochs (the 50k, 95k, 153k, 195k, 220k figures are the corresponding iteration counts)
  milestones = [100, 190, 306, 390, 440, 500]
  scheduler = lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1)

  if args.use_cuda:
    net.cuda()
    criterion.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint
  if args.resume:
    if os.path.isfile(args.resume):
      print_log("=> loading checkpoint '{}'".format(args.resume), log)
      checkpoint = torch.load(args.resume)
      recorder = checkpoint['recorder']
      args.start_epoch = checkpoint['epoch']
      net.load_state_dict(checkpoint['state_dict'])
      scheduler.load_state_dict(checkpoint['scheduler'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
    else:
      print_log("=> no checkpoint found at '{}'".format(args.resume), log)
  else:
    print_log("=> did not use any checkpoint for {} model".format(args.arch), log)

  if args.evaluate:
    validate(test_loader, net, criterion, log)
    return

  # Main loop
  start_time = time.time()
  epoch_time = AverageMeter()

  for epoch in range(args.start_epoch, args.epochs):
    #current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)
    current_learning_rate = float(scheduler.get_lr()[-1])

    scheduler.step()

    #adjust_learning_rate(optimizer, epoch)

    need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
    need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:.6f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

    # train for one epoch
    train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log)

    # evaluate on validation set
    #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
    val_acc,   val_los   = validate(test_loader, net, criterion, log)
    is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)


    writer.add_scalar('training/loss', train_los, epoch)
    writer.add_scalar('training/acc', train_acc, epoch)

    writer.add_scalar('validation/loss', val_los, epoch)
    writer.add_scalar('validation/acc', val_acc, epoch)

   

    if epoch == 180:
        save_checkpoint({
          'epoch': epoch ,
          'arch': args.arch,
          'state_dict': net.state_dict(),
          'recorder': recorder,
          'optimizer' : optimizer.state_dict(),
          'scheduler' : scheduler.state_dict(),
        }, False, args.save_path, 'chkpt_{0}_{1}_{2}_{3}.pth.tar'.format(args.arch, args.dataset, epoch, time_stamp))

    save_checkpoint({
      'epoch': epoch + 1,
      'arch': args.arch,
      'state_dict': net.state_dict(),
      'recorder': recorder,
      'optimizer' : optimizer.state_dict(),
      'scheduler' : scheduler.state_dict(),
    }, is_best, args.save_path, 'chkpt_{0}_{1}_{2}.pth.tar'.format(args.arch, args.dataset, time_stamp))


    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    recorder.plot_curve( os.path.join(args.save_path, 'plot_{0}_{1}_{2}.png'.format(args.arch, args.dataset, time_stamp)) )

  writer.close()
  log.close()
                                  val_acc, args.epochs)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
                'args': copy.deepcopy(args),
            }, is_best, args.save_path, 'checkpoint4.pth')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))
    log.close()


#Define pruning method
def prune(m, alpha):
    global CUDA
    weight_mask = torch.tensor([0.0])
    num_weights = 0  #Number of weights in layer
    num_pruned = 0  #Number of weights pruned
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        num_weights = torch.numel(m.weight.data)
        # use a byteTensor to represent the mask and convert it to a floatTensor for multiplication
        if CUDA:
            weight_mask = torch.ge(
                m.weight.data.abs(), (alpha * m.weight.data.std()).type(
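
The prune() helper above is cut off at this point in the listing. As a rough, self-contained sketch of the same idea (keep only weights whose magnitude exceeds alpha times the layer's weight standard deviation and zero the rest), something like the following would do; it is an illustration, not the missing original code.

import torch
import torch.nn as nn


def std_threshold_prune(m: nn.Module, alpha: float) -> int:
    # Zero out weights with |w| < alpha * std(w); return how many were pruned.
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return 0
    weight = m.weight.data
    threshold = alpha * weight.std()
    mask = weight.abs().ge(threshold).to(weight.dtype)  # 1 keeps a weight, 0 prunes it
    weight.mul_(mask)                                    # apply the mask in place
    return int(weight.numel() - int(mask.sum().item()))


# usage over a whole model (hypothetical):
# pruned = sum(std_threshold_prune(m, alpha=0.5) for m in model.modules())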
Exemplo n.º 12
0
def main():
  # Init logger
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Data loading code
  # Any other preprocessings? http://pytorch.org/audio/transforms.html
  sample_length = 10000
  scale = transforms.Scale()
  padtrim = transforms.PadTrim(sample_length)
  downmix = transforms.DownmixMono()
  transforms_audio = transforms.Compose([
    scale, padtrim, downmix
  ])

  if not os.path.isdir(args.data_path):
    os.makedirs(args.data_path)
  train_dir = os.path.join(args.data_path, 'train')
  val_dir = os.path.join(args.data_path, 'val')

  #Choose dataset to use
  if args.dataset == 'arctic':
    # TODO No ImageFolder equivalent for audio. Need to create a Dataset manually
    train_dataset = Arctic(train_dir, transform=transforms_audio, download=True)
    val_dataset = Arctic(val_dir, transform=transforms_audio, download=True)
    num_classes = 4
  elif args.dataset == 'vctk':
    train_dataset = dset.VCTK(train_dir, transform=transforms_audio, download=True)
    val_dataset = dset.VCTK(val_dir, transform=transforms_audio, download=True)
    num_classes = 10
  elif args.dataset == 'yesno':
    train_dataset = dset.YESNO(train_dir, transform=transforms_audio, download=True)
    val_dataset = dset.YESNO(val_dir, transform=transforms_audio, download=True)
    num_classes = 2
  else:
    assert False, 'Dataset is incorrect'

  train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    # pin_memory=True, # page-locked host memory for faster CPU-to-GPU copies
    # sampler=None     # custom sampling strategy; mutually exclusive with shuffle=True
  )
  val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)


  #Feed in respective model file to pass into model (alexnet.py)
  print_log("=> creating model '{}'".format(args.arch), log)
  # Init model, criterion, and optimizer
  # net = models.__dict__[args.arch](num_classes)
  net = AlexNet(num_classes)
  #
  print_log("=> network :\n {}".format(net), log)

  # net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

  # define loss function (criterion) and optimizer
  criterion = torch.nn.CrossEntropyLoss()

  # Define stochastic gradient descent as optimizer (run backprop on random small batch)
  optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=True)

  #Sets use for GPU if available
  if args.use_cuda:
    net.cuda()
    criterion.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint
  # Needs the same Python version that the checkpoint was saved with
  if args.resume:
    if os.path.isfile(args.resume):
      print_log("=> loading checkpoint '{}'".format(args.resume), log)
      if args.ngpu == 0:
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
      else:
        checkpoint = torch.load(args.resume)

      recorder = checkpoint['recorder']
      args.start_epoch = checkpoint['epoch']
      net.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
    else:
      print_log("=> no checkpoint found at '{}'".format(args.resume), log)
  else:
    print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

  if args.evaluate:
    validate(val_loader, net, criterion, 0, log, val_dataset)
    return

  # Main loop
  start_time = time.time()
  epoch_time = AverageMeter()

  # Training occurs here
  for epoch in range(args.start_epoch, args.epochs):
    current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

    need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
    need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

    print("One epoch")
    # train for one epoch
    # Call to train (note that our previous net is passed into the model argument)
    train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, train_dataset)

    # evaluate on validation set
    #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
    val_acc,   val_los   = validate(val_loader, net, criterion, epoch, log, val_dataset)
    is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

    save_checkpoint({
      'epoch': epoch + 1,
      'arch': args.arch,
      'state_dict': net.state_dict(),
      'recorder': recorder,
      'optimizer' : optimizer.state_dict(),
    }, is_best, args.save_path, 'checkpoint.pth.tar')

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    recorder.plot_curve( os.path.join(args.save_path, 'curve.png') )

  log.close()
Exemplo n.º 13
0
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("CS Pruning Rate: {}".format(args.prune_rate_cs), log)
    print_log("GM Pruning Rate: {}".format(args.prune_rate_gm), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("use pretrain: {}".format(args.use_pretrain), log)
    print_log("Pretrain path: {}".format(args.pretrain_path), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    print_log("=> creating model '{}'".format(args.arch), log)

    num_classes = 1000 # number of imagenet classes

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    if args.use_pretrain:
        pretrain = torch.load(args.pretrain_path)
        if args.use_state_dict:
            net.load_state_dict(pretrain['state_dict'])
        else:
            net = pretrain['state_dict']

    recorder = RecorderMeter(args.epochs)

    traindir = os.path.join(args.data_path, 'train')
    valdir = os.path.join(args.data_path, 'val')

    # create loaders for testing and pruning
    img_size = 32
    batch_range = list(range(1, 11))
    random.shuffle(batch_range)
    imnetdata = imgnet_utils.load_databatch(traindir, batch_range[0], img_size=img_size)
    X_train_prune = imnetdata['X_train']
    Y_train_prune = imnetdata['Y_train']
    mean_image = imnetdata['mean']
    num_batches_pruner = 0
    for ib in batch_range[1:1 + num_batches_pruner]:
        print('train batch for prune loader: {}'.format(ib))
        imnetdata = imgnet_utils.load_databatch(traindir, ib, img_size=img_size)
        X_train_prune = np.concatenate((X_train_prune, imnetdata['X_train']), axis=0)
        Y_train_prune = np.concatenate((Y_train_prune, imnetdata['Y_train']), axis=0)

    del imnetdata

    train_data_prune_loader = torch.utils.data.TensorDataset(
        torch.cat([torch.FloatTensor(X_train_prune), torch.FloatTensor(X_train_prune[:,:,:,::-1].copy())], dim=0),
        torch.cat([torch.LongTensor(Y_train_prune), torch.LongTensor(Y_train_prune)], dim=0))
    train_prune_loader = torch.utils.data.DataLoader(train_data_prune_loader, batch_size=args.batch_prune_size,
                                                     shuffle=True, num_workers=args.workers, pin_memory=True)

    del X_train_prune, Y_train_prune

    # create test loader
    imnetdata = imgnet_utils.load_validation_data(valdir, mean_image=mean_image, img_size=img_size)
    X_test = imnetdata['X_test']
    Y_test = imnetdata['Y_test']

    del imnetdata

    test_data = torch.utils.data.TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(Y_test))
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.workers, pin_memory=True)

    del X_test, Y_test

    mdlIdx2ConvIdx = [] # module index to conv filter index
    for index1, layr in enumerate(net.modules()):
        if isinstance(layr, torch.nn.Conv2d):
            mdlIdx2ConvIdx.append(index1)

    prmIdx2ConvIdx = [] # parameter index to conv filter index
    for index2, item in enumerate(net.parameters()):
        if len(item.size()) == 4:
            prmIdx2ConvIdx.append(index2)

    # set index of last layer depending on the known architecture
    if args.arch == 'resnet20':
        args.layer_end = 54
    elif args.arch == 'resnet56':
        args.layer_end = 162
    elif args.arch == 'resnet110':
        args.layer_end = 324
    else:
        pass  # unknown architecture, keep the user-supplied layer_end

    # asymptotic schedule
    total_pruning_rate = args.prune_rate_gm + args.prune_rate_cs
    compress_rates_total, scalling_factors, compress_rates_cs, compress_rates_fpgm, e2 = \
        cmpAsymptoticSchedule(theta3=total_pruning_rate,
                              e3=args.epochs - 1,
                              tau=args.tau,  # tau=8.
                              theta_cs_final=args.prune_rate_cs,
                              scaling_attn=args.scaling_attenuation)
    keep_rate_cs = 1. - compress_rates_cs

    if args.use_zero_scaling:
        scalling_factors = np.zeros(scalling_factors.shape)

    m = Mask(net, train_prune_loader, mdlIdx2ConvIdx, prmIdx2ConvIdx, scalling_factors, keep_rate_cs, compress_rates_fpgm, args.max_iter_cs)
    m.set_curr_epoch(0)
    m.set_epoch_cs(args.epoch_apply_cs)
    m.init_selected_filts()
    m.init_length()

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

    m.model = net
    m.init_mask(keep_rate_cs[0], compress_rates_fpgm[0], scalling_factors[0])
    #    m.if_zero()
    m.do_mask()
    m.do_similar_mask()
    net = m.model
    #    m.if_zero()
    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):

        current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        # one accumulator slot per training batch file so the mean is not diluted by unused zeros
        train_acc_i = torch.zeros([len(batch_range)], dtype=torch.float32)
        train_los_i = torch.zeros([len(batch_range)], dtype=torch.float32)
        idx = 0
        for ib in batch_range:
            print('train batch: {}'.format(ib))

            imnetdata = imgnet_utils.load_databatch(traindir, ib, img_size=img_size)
            X_train = imnetdata['X_train']
            Y_train = imnetdata['Y_train']

            del imnetdata

            train_data = torch.utils.data.TensorDataset(
                torch.cat([torch.FloatTensor(X_train), torch.FloatTensor(X_train[:, :, :, ::-1].copy())], dim=0),
                torch.cat([torch.LongTensor(Y_train), torch.LongTensor(Y_train)], dim=0))
            train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                                       num_workers=args.workers, pin_memory=True)

            del X_train, Y_train

            train_acc_i[idx], train_los_i[idx] = train(train_loader, net, criterion, optimizer, epoch, log, m)

            idx += 1
        train_acc = torch.mean(train_acc_i)
        train_los = torch.mean(train_los_i)

        # evaluate on validation set
        val_acc_1,   val_los_1   = validate(test_loader, net, criterion, log)
        print('Before: val_acc_1: {}, val_los_1: {}'.format(val_acc_1, val_los_1))

        #train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, m)
        
        # evaluate on validation set
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = net
            m.set_curr_epoch(epoch)
            # m.if_zero()
            m.init_mask(keep_rate_cs[epoch], compress_rates_fpgm[epoch], scalling_factors[epoch])
            m.do_mask()
            m.do_similar_mask()
            # m.if_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()
            if epoch == args.epochs - 1:
                m.if_zero()

        val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)
        print('After: val_acc_2: {}, val_los_2: {}'.format(val_acc_2, val_los_2))

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2, val_acc_2)

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': net,
            'recorder': recorder,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    log.close()
Exemplo n.º 14
0
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)

    # used for file names, etc
    time_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log = open(
        os.path.join(
            args.save_path,
            'log_seed_{0}_{1}.txt'.format(args.manualSeed, time_stamp)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    writer = SummaryWriter()

    #   # Data transforms
    # mean = [0.5071, 0.4867, 0.4408]
    # std = [0.2675, 0.2565, 0.2761]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    #[transforms.CenterCrop(32), transforms.ToTensor(),
    # transforms.Normalize(mean, std)])
    #)
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'imagenet':
        assert False, 'Did not finish imagenet code'
    else:
        assert False, 'Does not support dataset : {}'.format(args.dataset)

    #step_sizes = 2500
    step_sizes = args.alinit
    indices = list(range(0, 50000))

    # indices already added to the training pool (one entry stored per AL step)
    annot_indices = []
    # indices not yet added to the training pool
    unannot_indices = [indices]

    selections = random.sample(range(0, len(unannot_indices[-1])), step_sizes)
    temp = list(np.asarray(unannot_indices[-1])[selections])
    annot_indices.append(temp)

    unannot_indices.append(
        list(set(unannot_indices[-1]) - set(annot_indices[-1])))

    labelled_dset = torch.utils.data.Subset(train_data, annot_indices[-1])
    unlabelled_dset = torch.utils.data.Subset(train_data, unannot_indices[-1])

    #train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
    #                       num_workers=args.workers, pin_memory=True)
    labelled_loader = torch.utils.data.DataLoader(labelled_dset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=args.workers,
                                                  pin_memory=True)

    #unlabelled_loader = torch.utils.data.DataLoader(unlabelled_dset, batch_size=args.batch_size, shuffle=True,
    #num_workers=args.workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    #torch.save(net, 'net.pth')
    #init_net = torch.load('net.pth')
    #net.load_my_state_dict(init_net.state_dict())
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    #optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.005, nesterov=False)
    optimizer = torch.optim.Adadelta(
        net.parameters(),
        lr=0.1,
        rho=0.9,
        eps=1e-3,  # momentum=state['momentum'],
        weight_decay=0.001)

    print_log("=> Seed '{}'".format(args.manualSeed), log)
    print_log("=> dataset mean and std '{} - {}'".format(str(mean), str(std)),
              log)

    states_settings = {'optimizer': optimizer.state_dict()}

    print_log("=> optimizer '{}'".format(states_settings), log)
    # milestone epochs (the 50k, 95k, 153k, 195k, 220k figures are the corresponding iteration counts)
    milestones = [100, 190, 306, 390, 440, 540]
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> did not use any checkpoint for {} model".format(args.arch),
            log)

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    al_steps = int(50000 / args.alinit)

    curr_al_step = 0
    dump_data = []

    for (al_step, epoch) in [(a, b) for a in range(al_steps)
                             for b in range(args.start_epoch, args.epochs)]:
        print(" Current AL_step and epoch " + str((al_step, epoch)))
        if (al_step != curr_al_step):

            # Score every point in the unlabelled pool, then map the selected
            # positions back to the original 0-49999 indexing so they can be
            # moved into the labelled training pool.
            curr_al_step = al_step
            #Resetting the learning rate scheduler
            scheduler = lr_scheduler.MultiStepLR(optimizer,
                                                 milestones,
                                                 gamma=0.1)

            scores_unlabelled = score(unlabelled_dset, net, criterion)
            indices_sorted = np.argsort(scores_unlabelled)

            #Greedy Sampling
            temp_selections = indices_sorted[-1 * args.alinit:]
            selections = np.asarray(list(
                unlabelled_dset.indices))[temp_selections].tolist()

            annot_indices.append(selections)

            unannot_indices.append(
                list(set(unannot_indices[-1]) - set(annot_indices[-1])))

            labelled_dset = torch.utils.data.Subset(train_data,
                                                    annot_indices[-1])
            labelled_loader = torch.utils.data.DataLoader(
                labelled_dset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True)
            unlabelled_dset = torch.utils.data.Subset(train_data,
                                                      unannot_indices[-1])

            indices_data = [annot_indices, unannot_indices]
            filehandler = open("indices.pickle", "wb")
            pickle.dump(indices_data, filehandler)
            filehandler.close()

        #current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)
        current_learning_rate = float(scheduler.get_lr()[-1])
        #print('lr:',current_learning_rate)

        scheduler.step()

        #adjust_learning_rate(optimizer, epoch)

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:.6f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        #train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log)
        train_acc, train_los = train(labelled_loader, net, criterion,
                                     optimizer, epoch, log)

        # evaluate on validation set
        #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
        val_acc, val_los = validate(test_loader, net, criterion, log)
        is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                  val_acc)

        dump_data.append(([al_step, epoch], [train_acc,
                                             train_los], [val_acc, val_los]))
        if (epoch % 50 == 0):
            filehandler = open("accuracy.pickle", "wb")
            pickle.dump(dump_data, filehandler)
            filehandler.close()

        if epoch == 180:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': net.state_dict(),
                    'recorder': recorder,
                    'optimizer': optimizer.state_dict(),
                }, False, args.save_path,
                'checkpoint_{0}_{1}.pth.tar'.format(epoch,
                                                    time_stamp), time_stamp)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_path,
            'checkpoint_{0}.pth.tar'.format(time_stamp), time_stamp)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(
            os.path.join(
                args.save_path,
                'training_plot_{0}_{1}.png'.format(args.manualSeed,
                                                   time_stamp)))

    writer.close()
    log.close()
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'cub_200':
        mean = [0.4856077, 0.49941534, 0.43237692]
        std = [0.23222743, 0.2277201, 0.26586822]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    if args.dataset == 'cub_200':
        train_transform = transforms.Compose([
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10

    elif args.dataset == 'cub_200':
        train_data = MyImageFolder(args.data_path + 'train',
                                   transform=train_transform,
                                   data_cached=True)
        test_data = MyImageFolder(args.data_path + 'val',
                                  transform=test_transform,
                                  data_cached=True)
        num_classes = 200

    elif args.dataset == 'imagenet':
        assert False, 'imagenet support is not implemented yet'
    else:
        assert False, 'Unsupported dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    if args.arch == 'spatial_transform_resnext50':
        loc_params = net.loc.parameters()
        loc_params_id = list(map(id, loc_params))
        base_params = filter(lambda p: id(p) not in loc_params_id,
                             net.parameters())
        optimizer = torch.optim.SGD(
            [
                {
                    'params': base_params
                },
                {
                    'params': net.loc.parameters(),
                    'lr': 1e-4 * state['learning_rate']
                },
                #     {'params': net.stn.parameters()}
            ],
            state['learning_rate'],
            momentum=state['momentum'],
            weight_decay=state['decay'],
            nesterov=True)
        #optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
        #            weight_decay=state['decay'], nesterov=True)

    else:
        optimizer = torch.optim.SGD(net.parameters(),
                                    state['learning_rate'],
                                    momentum=state['momentum'],
                                    weight_decay=state['decay'],
                                    nesterov=True)

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            #      recorder = checkpoint['recorder']
            #      args.start_epoch = checkpoint['epoch']

            state_dict = checkpoint['state_dict']
            if args.arch == 'spatial_transform_resnext50':
                ckpt_weights = list(state_dict.values())[:-2]
                m_list = net.crop_descriptors
                loc = net.loc[0]
                m_dict_keys = list(loc.state_dict().keys())
                m_dict = dict(zip(m_dict_keys, ckpt_weights))
                loc.load_state_dict(m_dict)
                print('loaded loc net')
                for m in m_list:
                    m = m[0]
                    m_dict_keys = list(m.state_dict().keys())
                    m_dict = dict(zip(m_dict_keys, ckpt_weights))
                    m.load_state_dict(m_dict)
                    print('loaded one descriptor')
            else:

                model_dict = net.state_dict()
                from_ = list(state_dict.keys())
                to = list(model_dict.keys())

                for i, k in enumerate(from_):
                    if k not in ['module.fc.weight', 'module.fc.bias']:
                        model_dict[to[i]] = state_dict[k]

                net.load_state_dict(model_dict)
        # optimizer.load_state_dict(checkpoint['optimizer'])
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)
    loc_classifier_w = net.loc[1].fc_2.weight
    classifier_w = net.classifier.weight
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        # current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)
        current_learning_rate = 0
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log, loc_classifier_w,
                                     classifier_w)

        # evaluate on validation set
        #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
        val_acc, val_los = validate(test_loader, net, criterion, log)
        is_best = recorder.update(epoch, train_los, train_acc, val_los,
                                  val_acc)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    log.close()
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Init the tensorboard path and writer
    tb_path = os.path.join(args.save_path, 'tb_log',
                           'run_' + str(args.manualSeed))
    # logger = Logger(tb_path)
    writer = SummaryWriter(tb_path)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'mnist':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'imagenet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    if args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])  # note: this is actually the ImageNet validation split, used here as the test set
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])

    if args.dataset == 'mnist':
        train_data = dset.MNIST(args.data_path,
                                train=True,
                                transform=train_transform,
                                download=True)
        test_data = dset.MNIST(args.data_path,
                               train=False,
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        train_dir = os.path.join(args.data_path, 'train')
        test_dir = os.path.join(args.data_path, 'val')
        train_data = dset.ImageFolder(train_dir, transform=train_transform)
        test_data = dset.ImageFolder(test_dir, transform=test_transform)
        num_classes = 1000
    else:
        assert False, 'Unsupported dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.attack_sample_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    if args.use_cuda:
        if args.ngpu > 1:
            net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    # separate the parameters so that the groups can be updated by different optimizers
    all_param = [
        param for name, param in net.named_parameters()
        if not 'step_size' in name
    ]

    step_param = [
        param for name, param in net.named_parameters() if 'step_size' in name
    ]
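    # A single optimizer with per-group hyperparameters would be an equivalent, more
    # compact alternative (sketch only, not what this script does), e.g.:
    #   optimizer = torch.optim.SGD(
    #       [{'params': all_param},
    #        {'params': step_param, 'lr': 0.01}],
    #       lr=state['learning_rate'], momentum=state['momentum'])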

    if args.optimizer == "SGD":
        print("using SGD as optimizer")
        optimizer = torch.optim.SGD(all_param,
                                    lr=state['learning_rate'],
                                    momentum=state['momentum'],
                                    weight_decay=state['decay'],
                                    nesterov=True)

    elif args.optimizer == "Adam":
        print("using Adam as optimizer")
        optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad,
                                            net.parameters()),
                                     lr=state['learning_rate'],
                                     weight_decay=state['decay'])

    elif args.optimizer == "RMSprop":
        print("using RMSprop as optimizer")
        optimizer = torch.optim.RMSprop(filter(
            lambda param: param.requires_grad, net.parameters()),
                                        lr=state['learning_rate'],
                                        alpha=0.99,
                                        eps=1e-08,
                                        weight_decay=0,
                                        momentum=0)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)  # count number of epochs

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            if not (args.fine_tune):
                args.start_epoch = checkpoint['epoch']
                recorder = checkpoint['recorder']
                optimizer.load_state_dict(checkpoint['optimizer'])

            state_tmp = net.state_dict()
            if 'state_dict' in checkpoint.keys():
                state_tmp.update(checkpoint['state_dict'])
            else:
                state_tmp.update(checkpoint)

            net.load_state_dict(state_tmp)

            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, args.start_epoch), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    # update the step_size once the model is loaded. This is used for quantization.
    for m in net.modules():
        if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
            # simple step size update based on the pretrained model or weight init
            m.__reset_stepsize__()
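    # __reset_stepsize__ is defined in the quantized layers (not shown here); presumably it
    # re-derives step_size from the current weight range, e.g. roughly
    # step_size ~ weight.abs().max() / half_lvls for a uniform quantizer.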

    # block for quantizer optimization
    if args.optimize_step:
        optimizer_quan = torch.optim.SGD(step_param,
                                         lr=0.01,
                                         momentum=0.9,
                                         weight_decay=0,
                                         nesterov=True)

        for m in net.modules():
            if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
                for i in range(300):  # run 300 iterations to reduce the quantization error
                    optimizer_quan.zero_grad()
                    weight_quan = quantize(m.weight, m.step_size,
                                           m.half_lvls) * m.step_size
                    loss_quan = F.mse_loss(weight_quan,
                                           m.weight,
                                           reduction='mean')
                    loss_quan.backward()
                    optimizer_quan.step()

        for m in net.modules():
            if isinstance(m, quan_Conv2d):
                print(m.step_size.data.item(),
                      (m.step_size.detach() * m.half_lvls).item(),
                      m.weight.max().item())
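        # The loops above fit each layer's step_size via SGD so that the de-quantized
        # weights, quantize(w, step_size, half_lvls) * step_size, minimize the MSE against
        # the full-precision weights; the print shows the learned step_size, the resulting
        # clipping range (step_size * half_lvls) and the weight maximum.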

    # block for weight reset
    if args.reset_weight:
        for m in net.modules():
            if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
                m.__reset_weight__()
                # print(m.weight)

    attacker = BFA(criterion, args.k_top)
    net_clean = copy.deepcopy(net)
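    # BFA is the attacker object (presumably a Bit-Flip Attack, given perform_attack and
    # args.n_iter below); net_clean keeps an unperturbed copy of the model for comparison.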
    # weight_conversion(net)

    if args.enable_bfa:
        perform_attack(attacker, net, net_clean, train_loader, test_loader,
                       args.n_iter, log, writer)
        return

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer, epoch, args.gammas, args.schedule)
        # Display simulation time
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f}][M={:1.2f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate,
                                                                                   current_momentum) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     epoch, log)

        # evaluate on validation set
        val_acc, _, val_los = validate(test_loader, net, criterion, log)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)
        is_best = val_acc >= recorder.max_accuracy(False)

        if args.model_only:
            checkpoint_state = {'state_dict': net.state_dict()}
        else:
            checkpoint_state = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }

        save_checkpoint(checkpoint_state, is_best, args.save_path,
                        'checkpoint.pth.tar', log)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

        # save additional accuracy log for plotting
        accuracy_logger(base_dir=args.save_path,
                        epoch=epoch,
                        train_accuracy=train_acc,
                        test_accuracy=val_acc)

        # ============ TensorBoard logging ============#

        ## Log the gradients distribution
        for name, param in net.named_parameters():
            name = name.replace('.', '/')
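            # NOTE: this assumes every parameter received a gradient in the last backward
            # pass; param.grad is None for frozen/unused parameters and .clone() would fail.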
            writer.add_histogram(name + '/grad',
                                 param.grad.clone().cpu().data.numpy(),
                                 epoch + 1,
                                 bins='tensorflow')

        ## Log the weight and bias distribution
        for name, module in net.named_modules():
            name = name.replace('.', '/')
            class_name = str(module.__class__).split('.')[-1].split("'")[0]

            if "Conv2d" in class_name or "Linear" in class_name:
                if module.weight is not None:
                    writer.add_histogram(
                        name + '/weight/',
                        module.weight.clone().cpu().data.numpy(),
                        epoch + 1,
                        bins='tensorflow')

        writer.add_scalar('loss/train_loss', train_los, epoch + 1)
        writer.add_scalar('loss/test_loss', val_los, epoch + 1)
        writer.add_scalar('accuracy/train_accuracy', train_acc, epoch + 1)
        writer.add_scalar('accuracy/test_accuracy', val_acc, epoch + 1)
    # ============ TensorBoard logging ============#

    log.close()
Exemplo n.º 17
0
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path,
                     'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # Init the tensorboard path and writer
    tb_path = os.path.join(args.save_path, 'tb_log')
    # logger = Logger(tb_path)
    writer = SummaryWriter(tb_path)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'mnist':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'imagenet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    if args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])  # note: this is actually the ImageNet validation split, used here as the test set
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])

    if args.dataset == 'mnist':
        train_data = dset.MNIST(args.data_path,
                                train=True,
                                transform=train_transform,
                                download=True)
        test_data = dset.MNIST(args.data_path,
                               train=False,
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(args.data_path,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(args.data_path,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path,
                               split='train',
                               transform=train_transform,
                               download=True)
        test_data = dset.SVHN(args.data_path,
                              split='test',
                              transform=test_transform,
                              download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path,
                                split='train',
                                transform=train_transform,
                                download=True)
        test_data = dset.STL10(args.data_path,
                               split='test',
                               transform=test_transform,
                               download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        train_dir = os.path.join(args.data_path, 'train')
        test_dir = os.path.join(args.data_path, 'val')
        train_data = dset.ImageFolder(train_dir, transform=train_transform)
        test_data = dset.ImageFolder(test_dir, transform=test_transform)
        num_classes = 1000
    else:
        assert False, 'Unsupported dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    if args.use_cuda:
        if args.ngpu > 1:
            net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
        else:
            net = torch.nn.DataParallel(net, device_ids=[0])

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    # params without threshold
    all_param = [
        param for name, param in net.named_parameters()
        if not 'delta_th' in name
    ]

    th_param = [
        param for name, param in net.named_parameters() if 'delta_th' in name
    ]

    if args.optimizer == "SGD":
        print("using SGD as optimizer")
        optimizer = torch.optim.SGD(all_param,
                                    lr=state['learning_rate'],
                                    momentum=state['momentum'],
                                    weight_decay=state['decay'],
                                    nesterov=True)
        optimizer_th = torch.optim.SGD(th_param,
                                       lr=state['learning_rate'],
                                       momentum=state['momentum'],
                                       weight_decay=state['decay'],
                                       nesterov=True)

    elif args.optimizer == "Adam":
        print("using Adam as optimizer")
        optimizer = torch.optim.Adam(all_param,
                                     lr=state['learning_rate'],
                                     weight_decay=state['decay'])

        optimizer_th = torch.optim.SGD(th_param,
                                       lr=state['learning_rate'],
                                       momentum=state['momentum'],
                                       weight_decay=0,
                                       nesterov=True)

    elif args.optimizer == "YF":
        print("using YellowFin as optimizer")
        optimizer = YFOptimizer(filter(lambda param: param.requires_grad,
                                       net.parameters()),
                                lr=state['learning_rate'],
                                mu=state['momentum'],
                                weight_decay=state['decay'])
    # optimizer = YFOptimizer( filter(lambda param: param.requires_grad, net.parameters()) )
    elif args.optimizer == "RMSprop":
        print("using RMSprop as optimizer")
        optimizer = torch.optim.RMSprop(filter(
            lambda param: param.requires_grad, net.parameters()),
                                        lr=state['learning_rate'],
                                        alpha=0.99,
                                        eps=1e-08,
                                        weight_decay=0,
                                        momentum=0)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)  # count number of epochs

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            if not (args.fine_tune):
                args.start_epoch = checkpoint['epoch']
                recorder = checkpoint['recorder']
                optimizer.load_state_dict(checkpoint['optimizer'])

            state_tmp = net.state_dict()
            if 'state_dict' in checkpoint.keys():
                state_tmp.update(checkpoint['state_dict'])
            else:
                state_tmp.update(checkpoint)

            net.load_state_dict(state_tmp)

            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, args.start_epoch), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    # Right after the pretrained model is loaded:
    '''
    Once the pre-trained weights are loaded, the originally initialized
    thresholds are no longer valid and might be clipped by the hard-tanh
    function, so they are recomputed from the loaded weights below.
    '''
    for name, module in net.named_modules():
        name = name.replace('.', '/')
        class_name = str(module.__class__).split('.')[-1].split("'")[0]
        if "quanConv2d" in class_name or "quanLinear" in class_name:
            module.delta_th.data = module.weight.abs().max(
            ) * module.init_factor.cuda()

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # register a gradient hook to modify the gradient (gradient clipping)
    for name, param in net.named_parameters():
        if "delta_th" in name:
            # if "delta_th" in name and 'classifier' in name:
            # based on earlier experiments, clamping the gradient to about +/-0.001 works best
            param.register_hook(lambda grad: grad.clamp(min=-0.001, max=0.001))
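            # register_hook runs on this tensor's gradient during backward; returning
            # grad.clamp(...) replaces the gradient with the clamped copy.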

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer, epoch, args.gammas, args.schedule)
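        # apply the same schedule to the threshold optimizer (the values it returns
        # simply overwrite the ones above and are reused for logging)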
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer_th, epoch, args.gammas, args.schedule)

        # Display simulation time
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f}][M={:1.2f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate,
                                                                                   current_momentum) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # ============ TensorBoard logging ============#
        # log the model parameter initialization to give an intuition when fine-tuning

        for name, param in net.named_parameters():
            name = name.replace('.', '/')
            if "delta_th" not in name:
                writer.add_histogram(name, param.cpu().detach().numpy(), epoch)

        for name, module in net.named_modules():
            name = name.replace('.', '/')
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            if "quanConv2d" in class_name or "quanLinear" in class_name:
                sparsity = Sparsity_check(module)
                writer.add_scalar(name + '/sparsity/', sparsity, epoch)
                # writer.add_histogram(name + '/ternweight/', tern_weight.detach().numpy(), epoch + 1)

        # ============ TensorBoard logging ============#

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     optimizer_th, epoch, log)

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, criterion, log)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)
        is_best = val_acc >= recorder.max_accuracy(False)

        if args.model_only:
            checkpoint_state = {'state_dict': net.state_dict()}
        else:
            checkpoint_state = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }

        save_checkpoint(checkpoint_state, is_best, args.save_path,
                        'checkpoint.pth.tar', log)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

        # save additional accuracy log for plotting
        accuracy_logger(base_dir=args.save_path,
                        epoch=epoch,
                        train_accuracy=train_acc,
                        test_accuracy=val_acc)

        # ============ TensorBoard logging ============#

        for name, param in net.named_parameters():
            name = name.replace('.', '/')
            writer.add_histogram(name + '/grad',
                                 param.grad.cpu().detach().numpy(), epoch + 1)

        # for name, module in net.named_modules():
        #     name = name.replace('.', '/')
        #     class_name = str(module.__class__).split('.')[-1].split("'")[0]
        #     if "quanConv2d" in class_name or "quanLinear" in class_name:
        #         sparsity = Sparsity_check(module)
        #         writer.add_scalar(name + '/sparsity/', sparsity, epoch + 1)
        #         # writer.add_histogram(name + '/ternweight/', tern_weight.detach().numpy(), epoch + 1)

        for name, module in net.named_modules():
            name = name.replace('.', '/')
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            if "quanConv2d" in class_name or "quanLinear" in class_name:
                if module.delta_th.data is not None:
                    if module.delta_th.dim(
                    ) == 0:  # zero-dimension tensor (scalar) not iterable
                        writer.add_scalar(name + '/delta/',
                                          module.delta_th.detach(), epoch + 1)
                    else:
                        for idx, delta in enumerate(module.delta_th.detach()):
                            writer.add_scalar(
                                name + '/delta/' + '{}'.format(idx), delta,
                                epoch + 1)

        writer.add_scalar('loss/train_loss', train_los, epoch + 1)
        writer.add_scalar('loss/test_loss', val_los, epoch + 1)
        writer.add_scalar('accuracy/train_accuracy', train_acc, epoch + 1)
        writer.add_scalar('accuracy/test_accuracy', val_acc, epoch + 1)
    # ============ TensorBoard logging ============#

    log.close()
def main():
  # Init logger
  
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Init dataset
  
  if not os.path.exists(args.data_path):
    os.makedirs(args.data_path)

  if args.dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif args.dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  else:
    assert False, "Unknow dataset : {}".format(args.dataset)

  train_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
     transforms.Normalize(mean, std)])
  test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)])

  if args.dataset == 'cifar10':
    train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'cifar100':
    train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 100
  elif args.dataset == 'svhn':
    train_data = dset.SVHN(args.data_path, split='train', transform=train_transform, download=True)
    test_data = dset.SVHN(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'stl10':
    train_data = dset.STL10(args.data_path, split='train', transform=train_transform, download=True)
    test_data = dset.STL10(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'imagenet':
    assert False, 'ImageNet support is not implemented'
  else:
    assert False, 'Unsupported dataset : {}'.format(args.dataset)

  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                         num_workers=args.workers, pin_memory=True)
  test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, pin_memory=True)

  print_log("=> creating model '{}'".format(args.arch), log)
  # Init model, criterion, and optimizer
  net = models.__dict__[args.arch](num_classes)
  print_log("=> network :\n {}".format(net), log)

  net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

  # define loss function (criterion) and optimizer
  criterion = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=False)

  if args.use_cuda:
    net.cuda()
    criterion.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint
  if args.resume:
    if os.path.isfile(args.resume):
      print_log("=> loading checkpoint '{}'".format(args.resume), log)
      checkpoint = torch.load(args.resume)
      recorder = checkpoint['recorder']
      args.start_epoch = checkpoint['epoch']
      net.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
    else:
      raise ValueError("=> no checkpoint found at '{}'".format(args.resume))
  else:
    print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

  if args.evaluate:
    validate(test_loader, net, criterion, log)
    return

  # Main loop
  start_time = time.time()
  epoch_time = AverageMeter()
  for epoch in range(args.start_epoch, args.epochs):
    current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

    need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
    need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

    # train for one epoch
    train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log)

    # evaluate on validation set
    #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
    val_acc,   val_los   = validate(test_loader, net, criterion, log)
    is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

    save_checkpoint({
      'epoch': epoch + 1,
      'arch': args.arch,
      'state_dict': net.state_dict(),
      'recorder': recorder,
      'optimizer' : optimizer.state_dict(),
      'args'      : copy.deepcopy(args),
    }, is_best, args.save_path, 'hb16_10check.pth.tar')

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    recorder.plot_curve( os.path.join(args.save_path, 'hb16_10.png') )

  log.close()
def main():

    ### transfer data from source to current node#####
    print ("Copying the dataset to the current node's  dir...")

    tmp = args.temp_dir
    home = args.home_dir


    dataset=args.dataset
    data_source_dir = os.path.join(home,'data',dataset)
    if not os.path.exists(data_source_dir):
        os.makedirs(data_source_dir)
    data_target_dir = os.path.join(tmp,'data',dataset)
    copy_tree(data_source_dir, data_target_dir)

    ### set up the experiment directories########
    exp_name=experiment_name(arch=args.arch,
                    epochs=args.epochs,
                    dropout=args.dropout,
                    batch_size=args.batch_size,
                    lr=args.learning_rate,
                    momentum=args.momentum,
                    alpha = args.alpha,
                    decay= args.decay,
                    data_aug=args.data_aug,
                    dualcutout=args.dualcutout,
                    singlecutout = args.singlecutout,
                    cutsize = args.cutsize,
                    manualSeed=args.manualSeed,
                    job_id=args.job_id,
                    add_name=args.add_name)
    temp_model_dir = os.path.join(tmp,'experiments/DualCutout/'+dataset+'/model/'+ exp_name)
    temp_result_dir = os.path.join(tmp, 'experiments/DualCutout/'+dataset+'/results/'+ exp_name)
    model_dir = os.path.join(home, 'experiments/DualCutout/'+dataset+'/model/'+ exp_name)
    result_dir = os.path.join(home, 'experiments/DualCutout/'+dataset+'/results/'+ exp_name)


    if not os.path.exists(temp_model_dir):
        os.makedirs(temp_model_dir)

    if not os.path.exists(temp_result_dir):
        os.makedirs(temp_result_dir)

    copy_script_to_folder(os.path.abspath(__file__), temp_result_dir)

    result_png_path = os.path.join(temp_result_dir, 'results.png')


    global best_acc

    log = open(os.path.join(temp_result_dir, 'log.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(temp_result_dir), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)


    train_loader, test_loader,num_classes=load_data(args.data_aug, args.batch_size,args.workers,args.dataset, data_target_dir)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer

    net = models.__dict__[args.arch](num_classes,args.dropout)
    print_log("=> network :\n {}".format(net), log)

    #net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=True)


    cutout = Cutout(1, args.cutsize)
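    # Cutout(1, args.cutsize): presumably one square hole of side args.cutsize per image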
    if args.use_cuda:
        net.cuda()
    criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log("=> loaded checkpoint '{}' accuracy={} (epoch {})" .format(args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        tr_acc, tr_los = train(train_loader, net, criterion, cutout, optimizer, epoch, log)

        # evaluate on validation set
        val_acc,   val_los   = validate(test_loader, net, criterion, log)
        train_loss.append(tr_los)
        train_acc.append(tr_acc)
        test_loss.append(val_los)
        test_acc.append(val_acc)
        dummy = recorder.update(epoch, tr_los, tr_acc, val_los, val_acc)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        save_checkpoint({
          'epoch': epoch + 1,
          'arch': args.arch,
          'state_dict': net.state_dict(),
          'recorder': recorder,
          'optimizer' : optimizer.state_dict(),
        }, is_best, temp_model_dir, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(result_png_path)

    train_log = OrderedDict()
    train_log['train_loss'] = train_loss
    train_log['train_acc']=train_acc
    train_log['test_loss']=test_loss
    train_log['test_acc']=test_acc

    with open(os.path.join(temp_result_dir, 'log.pkl'), 'wb') as f:
        pickle.dump(train_log, f)
    plotting(temp_result_dir)

    copy_tree(temp_model_dir, model_dir)
    copy_tree(temp_result_dir, result_dir)

    rmtree(temp_model_dir)
    rmtree(temp_result_dir)

    log.close()
Exemplo n.º 20
0
def main():
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(
        os.path.join(args.save_path, "log_seed_{}.txt".format(args.manualSeed)), "w"
    )
    print_log("save path : {}".format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace("\n", " ")), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == "cifar10":
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == "cifar100":
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    train_transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )

    if args.dataset == "cifar10":
        train_data = dset.CIFAR10(
            args.data_path, train=True, transform=train_transform, download=True
        )
        test_data = dset.CIFAR10(
            args.data_path, train=False, transform=test_transform, download=True
        )
        num_classes = 10
    elif args.dataset == "cifar100":
        train_data = dset.CIFAR100(
            args.data_path, train=True, transform=train_transform, download=True
        )
        test_data = dset.CIFAR100(
            args.data_path, train=False, transform=test_transform, download=True
        )
        num_classes = 100
    elif args.dataset == "svhn":
        train_data = dset.SVHN(
            args.data_path, split="train", transform=train_transform, download=True
        )
        test_data = dset.SVHN(
            args.data_path, split="test", transform=test_transform, download=True
        )
        num_classes = 10
    elif args.dataset == "stl10":
        train_data = dset.STL10(
            args.data_path, split="train", transform=train_transform, download=True
        )
        test_data = dset.STL10(
            args.data_path, split="test", transform=test_transform, download=True
        )
        num_classes = 10
    elif args.dataset == "imagenet":
        assert False, "Do not finish imagenet code"
    else:
        assert False, "Do not support dataset : {}".format(args.dataset)

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    # Init model, criterion, and optimizer
    # net = models.__dict__[args.arch](num_classes).cuda()
    net = SENet34()
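    # NOTE: args.arch is ignored here; an SENet34 is hard-coded and the generic
    # model construction is left commented out above.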

    # define loss function (criterion) and optimizer
    criterion = F.nll_loss
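    # F.nll_loss expects log-probabilities, so SENet34 is presumably expected to end
    # with a log_softmax layer rather than raw logits.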
    optimizer = torch.optim.SGD(
        net.parameters(),
        state["learning_rate"],
        momentum=state["momentum"],
        weight_decay=state["decay"],
        nesterov=True,
    )

    if args.use_cuda:
        net.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint["recorder"]
            args.start_epoch = checkpoint["epoch"]
            net.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print_log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                ),
                log,
            )
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> do not use any checkpoint for model", log)

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(
            optimizer, epoch, args.gammas, args.schedule
        )

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch)
        )
        need_time = "[Need: {:02d}:{:02d}:{:02d}]".format(
            need_hour, need_mins, need_secs
        )

        print_log(
            "\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]".format(
                time_string(), epoch, args.epochs, need_time, current_learning_rate
            )
            + " [Best : Accuracy={:.2f}, Error={:.2f}]".format(
                recorder.max_accuracy(False), 100 - recorder.max_accuracy(False)
            ),
            log,
        )

        # train for one epoch
        train_acc, train_los = train(
            train_loader, net, criterion, optimizer, epoch, log
        )

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, criterion, log)
        is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": net.state_dict(),
                "recorder": recorder,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            args.save_path,
            "checkpoint.pth.tar",
        )

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, "curve.png"))

    log.close()
Exemplo n.º 21
0
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("CS Pruning Rate: {}".format(args.prune_rate_cs), log)
    print_log("GM Pruning Rate: {}".format(args.prune_rate_gm), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    print_log("use pretrain: {}".format(args.use_pretrain), log)
    print_log("Pretrain path: {}".format(args.pretrain_path), log)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    n_mels = 32
    if args.input == 'mel40':
        n_mels = 40

    data_aug_transform = transforms.Compose([
        speech_transforms.ChangeAmplitude(),
        speech_transforms.ChangeSpeedAndPitchAudio(),
        speech_transforms.FixAudioLength(),
        speech_transforms.ToSTFT(),
        speech_transforms.StretchAudioOnSTFT(),
        speech_transforms.TimeshiftAudioOnSTFT(),
        speech_transforms.FixSTFTDimension()])

    parser.add_argument("--background_noise", type=str, default='datasets/gsc/train/_background_noise_',
                        help='path of background noise')

    backgroundNoisePname = os.path.join(args.data_path, 'train', '_background_noise_')
    bg_dataset = gsc_utils.BackgroundNoiseDataset(backgroundNoisePname, data_aug_transform)
    add_bg_noise = speech_transforms.AddBackgroundNoiseOnSTFT(bg_dataset)

    train_feature_transform = transforms.Compose([
        speech_transforms.ToMelSpectrogramFromSTFT(n_mels=n_mels),
        speech_transforms.DeleteSTFT(),
        speech_transforms.ToTensor('mel_spectrogram', 'input')])

    train_dataset = gsc_utils.SpeechCommandsDataset(
        os.path.join(args.data_path, 'train'),
        transforms.Compose([speech_transforms.LoadAudio(),
                            data_aug_transform,
                            add_bg_noise,
                            train_feature_transform]))

    valid_feature_transform = transforms.Compose([
        speech_transforms.ToMelSpectrogram(n_mels=n_mels),
        speech_transforms.ToTensor('mel_spectrogram', 'input')])
    valid_dataset = gsc_utils.SpeechCommandsDataset(
        os.path.join(args.data_path, 'valid'),
        transforms.Compose([
            speech_transforms.LoadAudio(),
            speech_transforms.FixAudioLength(),
            valid_feature_transform]))
    test_dataset = gsc_utils.SpeechCommandsDataset(
        os.path.join(args.data_path, 'test'),
        transforms.Compose([
            speech_transforms.LoadAudio(),
            speech_transforms.FixAudioLength(),
            valid_feature_transform]),
        silence_percentage=0)


    weights = train_dataset.make_weights_for_balanced_classes()
    sampler = WeightedRandomSampler(weights, len(weights))
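    # WeightedRandomSampler draws samples with replacement in proportion to the given
    # per-sample weights, which balances the classes seen within each epoch.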
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=sampler,
                                               num_workers=args.workers, pin_memory=True)
    train_prune_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_prune_size,
                                                     num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.workers, pin_memory=True)

    num_classes = len(gsc_utils.CLASSES)



    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes=num_classes, in_channels=1, fctMultLinLyr=64)
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    if args.use_pretrain:
        pretrain = torch.load(args.pretrain_path)
        if args.use_state_dict:
            net.load_state_dict(pretrain['state_dict'])
        else:
            net = pretrain['state_dict']
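            # when use_state_dict is False, the checkpoint is assumed to store the
            # whole model object under 'state_dict' rather than a parameter dict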

    recorder = RecorderMeter(args.epochs)

    mdlIdx2ConvIdx = [] # module index to conv filter index
    for index1, layr in enumerate(net.modules()):
        if isinstance(layr, torch.nn.Conv2d):
            mdlIdx2ConvIdx.append(index1)

    prmIdx2ConvIdx = [] # parameter index to conv filter index
    for index2, item in enumerate(net.parameters()):
        if len(item.size()) == 4:
            prmIdx2ConvIdx.append(index2)

    # set index of last layer depending on the known architecture
    if args.arch == 'resnet20':
        args.layer_end = 54
    elif args.arch == 'resnet56':
        args.layer_end = 162
    elif args.arch == 'resnet110':
        args.layer_end = 324
    else:
        pass  # unknown architecture, use the input value

    # asymptotic schedule
    total_pruning_rate = args.prune_rate_gm + args.prune_rate_cs
    compress_rates_total, scalling_factors, compress_rates_cs, compress_rates_fpgm, e2 =\
        cmpAsymptoticSchedule(theta3=total_pruning_rate, e3=args.epochs-1, tau=args.tau, theta_cs_final=args.prune_rate_cs, scaling_attn=args.scaling_attenuation) # tau=8.
    keep_rate_cs = 1. - compress_rates_cs
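    # cmpAsymptoticSchedule (custom, defined elsewhere) presumably ramps the total pruning
    # rate up to prune_rate_gm + prune_rate_cs by the final epoch, split between the CS and
    # FPGM criteria; keep_rate_cs is the per-epoch fraction of filters kept by the CS step.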

    if args.use_zero_scaling:
        scalling_factors = np.zeros(scalling_factors.shape)

    m = Mask(net, train_prune_loader, mdlIdx2ConvIdx, prmIdx2ConvIdx, scalling_factors, keep_rate_cs, compress_rates_fpgm, args.max_iter_cs)
    m.set_curr_epoch(0)
    m.set_epoch_cs(args.epoch_apply_cs)
    m.init_selected_filts()
    m.init_length()
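    # Mask (custom, defined elsewhere) presumably builds and applies the per-filter pruning
    # masks: init_mask selects filters for the current rates, and do_mask / do_similar_mask
    # zero out the pruned filters in the model weights.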

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

    m.model = net
    m.init_mask(keep_rate_cs[0], compress_rates_fpgm[0], scalling_factors[0])
    #    m.if_zero()
    m.do_mask()
    m.do_similar_mask()
    net = m.model
    #    m.if_zero()
    if args.use_cuda:
        net = net.cuda()
    val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):

        current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate) \
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, m)
        
        # apply / refresh the pruning mask at the scheduled epochs
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = net
            m.set_curr_epoch(epoch)
            # m.if_zero()
            m.init_mask(keep_rate_cs[epoch], compress_rates_fpgm[epoch], scalling_factors[epoch])
            m.do_mask()
            m.do_similar_mask()
            # m.if_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()
            if epoch == args.epochs - 1:
                m.if_zero()

        val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2, val_acc_2)

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': net,
            'recorder': recorder,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save_path, 'checkpoint.pth.tar')
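        # NOTE: unlike the earlier examples, the whole (masked) model object is stored
        # under 'state_dict' here, not net.state_dict().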

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

    log.close()