Code Example #1
File: main.py  Project: zhengyu-yang/sparse_learning
def main():
    global args, best_prec1

    args = parser.parse_args()
    setup_logger(args)

    if args.fp16:
        try:
            from apex.fp16_utils import FP16_Optimizer
        except ImportError:
            print_and_log(
                'WARNING: apex not installed, ignoring --fp16 option')
            args.fp16 = False

    kwargs = {'num_workers': 1, 'pin_memory': True}
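    # The target dataset is inferred from the first segment of the model name,
    # e.g. 'mnist_mlp' -> 'mnist', 'cifar10_WideResNet' -> 'cifar10'.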
    dataset = args.model.split('_')[0]
    if dataset == 'mnist':
        full_dataset = datasets.MNIST('./data',
                                      train=True,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))

        if not (args.validate_set):
            train_loader = torch.utils.data.DataLoader(
                full_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            val_loader = None
        else:
            train_dataset = split_dataset(full_dataset, split_end=50000)
            val_dataset = split_dataset(full_dataset, split_start=50000)
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=args.batch_size,
                shuffle=False,
                **kwargs)

        test_loader = torch.utils.data.DataLoader(datasets.MNIST(
            './data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  **kwargs)

    elif dataset == 'cifar10':
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        if args.augment:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), (4, 4, 4, 4),
                                                  mode='reflect').squeeze()),
                transforms.ToPILImage(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        full_dataset = datasets.CIFAR10('./data',
                                        train=True,
                                        download=True,
                                        transform=transform_train)

        if not (args.validate_set):
            train_loader = torch.utils.data.DataLoader(
                full_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            val_loader = None
        else:
            train_dataset = split_dataset(full_dataset, split_end=45000)
            val_dataset = split_dataset(full_dataset, split_start=45000)
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)
            val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)

        test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            './data', train=False, transform=transform_test),
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  **kwargs)

    elif dataset == 'imagenet':
        if not (args.data):
            raise Exception(
                'need to specify imagenet dataset location using the --data argument'
            )
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        full_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None

        if not (args.validate_set):
            train_loader = torch.utils.data.DataLoader(
                full_dataset,
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                num_workers=args.workers,
                pin_memory=True,
                sampler=train_sampler)
            val_loader = None

        else:
            train_dataset = split_dataset(full_dataset,
                                          split_end=len(full_dataset) - 10000)
            val_dataset = split_dataset(full_dataset,
                                        split_start=len(full_dataset) - 10000)
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                num_workers=args.workers,
                pin_memory=True,
                sampler=train_sampler)

            val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True)

        test_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=args.workers,
                                                  pin_memory=True)

    else:
        raise RuntimeError(
            'Unknown dataset {}. Dataset is first segment of network name'.
            format(dataset))

    print_and_log(args)
    with open(args.schedule_file, 'r') as stream:
        try:
            loaded_schedule = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print_and_log(exc)

    if args.model == 'mnist_mlp':
        model = mnist_mlp(initial_sparsity=args.initial_sparsity_fc,
                          sparse=not (args.tied),
                          no_batch_norm=args.no_batch_norm)
    elif args.model == 'cifar10_WideResNet':
        model = cifar10_WideResNet(
            args.layers,
            widen_factor=args.widen_factor,
            initial_sparsity_conv=args.initial_sparsity_conv,
            initial_sparsity_fc=args.initial_sparsity_fc,
            sub_kernel_granularity=args.sub_kernel_granularity,
            sparse=not (args.tied))

    elif args.model == 'imagenet_resnet50':
        model = imagenet_resnet50(
            initial_sparsity_conv=args.initial_sparsity_conv,
            initial_sparsity_fc=args.initial_sparsity_fc,
            sub_kernel_granularity=args.sub_kernel_granularity,
            widen_factor=args.widen_factor,
            vanilla_conv1=True,
            vanilla_conv3=True,
            vanilla_downsample=True,
            sparse=not args.sparse_momentum)
    else:
        raise RuntimeError('unrecognized model name ' + repr(args.model))

    model = model.cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    if args.fp16:
        print_and_log('FP16')
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=None,
                                   dynamic_loss_scale=True,
                                   dynamic_loss_args={'init_scale': 2**16})
        model = model.half()

    mask = None
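    # Dynamic sparse training: unless --dense is given, wrap the optimizer in a
    # Masking object that prunes weights by magnitude and regrows them according
    # to the chosen growth/redistribution modes, with a cosine-annealed prune rate.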
    if not args.dense:
        decay = CosineDecay(args.prune_rate, len(train_loader) * (args.epochs))
        mask = Masking(optimizer,
                       decay,
                       prune_rate=args.prune_rate,
                       prune_mode='magnitude',
                       growth_mode=args.growth,
                       redistribution_mode=args.redistribution,
                       verbose=True,
                       fp16=args.fp16)
        mask.add_module(model, density=args.density)
        #mask.remove_weight_partial_name('downsample', verbose=True)
        #mask.remove_weight('conv1.weight')

    if dataset == 'imagenet':
        print_and_log('setting up data parallel')
        model = torch.nn.DataParallel(model).cuda()
        base_model = model.module
    else:
        base_model = model

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_and_log("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
                print_and_log('OPTIM')
                mask.optimizer = optimizer
            print_and_log("=> loaded checkpoint '{}' ".format(args.resume))
        else:
            print_and_log("=> no checkpoint found at '{}'".format(args.resume))

    if args.copy_mask_from:
        if os.path.isfile(args.copy_mask_from):
            print_and_log("=> loading mask data '{}'".format(
                args.copy_mask_from))
            mask_data = torch.load(args.copy_mask_from)
            filtered_mask_data = collections.OrderedDict([
                (x, y) for (x, y) in mask_data['state_dict'].items()
                if 'mask' in x
            ])
            model.load_state_dict(filtered_mask_data, strict=False)
        else:
            print_and_log("=> no mask checkpoint found at '{}'".format(
                args.copy_mask_from))

    # get the number of model parameters
    model_size = base_model.get_model_size()

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    train_loss_l = []
    test_loss_l = []
    train_prec1_l = []
    test_prec1_l = []
    train_prec5_l = []
    test_prec5_l = []

    val_loss_l = []
    val_prec1_l = []
    val_prec5_l = []

    prune_mode = args.prune_mode
    print_and_log('PRUNE MODE ' + str(prune_mode))

    start_pruning_after_epoch_n = args.start_pruning_after_epoch
    prune_every_epoch_n = args.prune_epoch_frequency
    prune_iterations = args.prune_iterations
    post_prune_epochs = args.post_prune_epochs

    filename = args.model + '_' + repr(args.job_idx)
    n_prunes_done = 0

    if prune_mode:
        ## Special consideration so that pruning mnist_mlp does not use less than 100 parameters in the top layer after pruning
        if args.prune_target_sparsity_fc > 0.9 and args.model == 'mnist_mlp':
            total_available_weights = (1. - args.prune_target_sparsity_fc) * (
                784 * 300 + 300 * 100 + 100 * 10) - 100
            prune_target_sparsity_special = 0.9
            prune_target_sparsity_fc = 1. - total_available_weights / (
                784 * 300 + 300 * 100)
        else:
            prune_target_sparsity_fc = prune_target_sparsity_special = args.prune_target_sparsity_fc

        prune_fraction_fc = 1.0 - (1 - prune_target_sparsity_fc)**(
            1.0 / prune_iterations)
        prune_fraction_conv = 1.0 - (1 - args.prune_target_sparsity_conv)**(
            1.0 / prune_iterations)

        prune_fraction_fc_special = 1.0 - (
            1 - prune_target_sparsity_special)**(1.0 / prune_iterations)

        cubic_pruning_multipliers = (
            1 - np.arange(prune_iterations + 1) / prune_iterations)**3.0
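        # Cubic sparsity schedule: after k of N pruning steps the sparsity is
        # final_sparsity * (1 - (1 - k/N)**3). get_prune_fraction_cubic below turns
        # this into the fraction of currently remaining weights to prune at step k.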

        def get_prune_fraction_cubic(current_prune_iter, final_sparsity):
            return 1 - (1 - final_sparsity + final_sparsity *
                        cubic_pruning_multipliers[current_prune_iter + 1]) / (
                            1 - final_sparsity + final_sparsity *
                            cubic_pruning_multipliers[current_prune_iter])

        nEpochs_to_prune = int(start_pruning_after_epoch_n +
                               prune_every_epoch_n *
                               (prune_iterations - 1)) + post_prune_epochs
        print_and_log(
            'prune fraction fc : {} , prune_fraction conv : {} '.format(
                prune_fraction_fc, prune_fraction_conv))
        print_and_log('nepochs ' + repr(nEpochs_to_prune))

        filename += '_target_' + repr(
            args.prune_target_sparsity_fc) + ',' + repr(
                args.prune_target_sparsity_conv)
        validate(test_loader, model, criterion, 1, 'validate')

    save_checkpoint(
        {
            'model_size': base_model.get_model_size(),
            'model_name': args.model,
            'state_dict': model.state_dict(),
            'args': args
        },
        filename=filename + '_initial')

    current_iteration = 0
    lr_schedule = loaded_schedule['lr_schedule']
    rewire_schedule = loaded_schedule['rewire_period_schedule']
    DeepR_temperature_schedule = loaded_schedule['DeepR_temperature_schedule']
    threshold = 1.0e-3
    if args.resume:
        print_and_log("Validating...")
        validate(test_loader, model, criterion, 1, 'validate')
    for epoch in range(args.start_epoch,
                       nEpochs_to_prune if prune_mode else args.epochs):
        adjust_learning_rate(optimizer, epoch, lr_schedule)
        rewire_period = get_schedule_val(rewire_schedule, epoch)
        DeepR_temperature = get_schedule_val(DeepR_temperature_schedule, epoch)
        print_and_log('rewiring every {} iterations'.format(rewire_period))

        t1 = time.time()
        current_iteration, threshold = train(mask, train_loader, model,
                                             criterion, optimizer, epoch,
                                             current_iteration, rewire_period,
                                             DeepR_temperature, threshold)
        print_and_log('epoch time ' + repr(time.time() - t1))

        if prune_mode and epoch >= start_pruning_after_epoch_n and (
                epoch - start_pruning_after_epoch_n
        ) % prune_every_epoch_n == 0 and n_prunes_done < prune_iterations:
            if args.cubic_prune_schedule:
                base_model.prune(
                    get_prune_fraction_cubic(n_prunes_done,
                                             prune_target_sparsity_fc),
                    get_prune_fraction_cubic(n_prunes_done,
                                             args.prune_target_sparsity_conv),
                    get_prune_fraction_cubic(n_prunes_done,
                                             prune_target_sparsity_special))
            else:
                base_model.prune(prune_fraction_fc, prune_fraction_conv,
                                 prune_fraction_fc_special)
            n_prunes_done += 1
            print_and_log(base_model.get_model_size())

        if not (args.no_validate_train):
            prec1_train, prec5_train, loss_train = validate(
                train_loader, model, criterion, epoch, 'train')
        else:
            prec1_train, prec5_train, loss_train = 0.0, 0.0, 0.0

        if args.validate_set:
            prec1_val, prec5_val, loss_val = validate(val_loader, model,
                                                      criterion, epoch,
                                                      'validate')
        else:
            prec1_val, prec5_val, loss_val = 0.0, 0.0, 0.0

        prec1_test, prec5_test, loss_test = validate(test_loader, model,
                                                     criterion, epoch, 'test')

        test_loss_l.append(loss_test)
        train_loss_l.append(loss_train)
        val_loss_l.append(loss_val)

        test_prec1_l.append(prec1_test)
        train_prec1_l.append(prec1_train)
        val_prec1_l.append(prec1_val)

        test_prec5_l.append(prec5_test)
        train_prec5_l.append(prec5_train)
        val_prec5_l.append(prec5_val)

        # remember best prec@1 and save checkpoint
        filenames = [filename]
        if epoch == args.stop_rewire_epoch:
            filenames += [filename + '_StopRewiringPoint_' + repr(epoch)]
        for f in filenames:
            save_checkpoint(
                {
                    'model_size': base_model.get_model_size(),
                    'test_loss': test_loss_l,
                    'train_loss': train_loss_l,
                    'val_loss': val_loss_l,
                    'test_prec1': test_prec1_l,
                    'train_prec1': train_prec1_l,
                    'val_prec1': val_prec1_l,
                    'test_prec5': test_prec5_l,
                    'train_prec5': train_prec5_l,
                    'val_prec5': val_prec5_l,
                    'model_name': args.model,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch + 1,
                    'args': args
                },
                filename=f)

        if not args.dense and epoch < args.epochs:
            mask.at_end_of_epoch()

    print_and_log('Best accuracy: {}'.format(best_prec1))
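
Note: split_dataset is called in the example above but is not shown. Below is a minimal sketch of what such a helper could look like, assuming it simply wraps torch.utils.data.Subset over a contiguous index range (the project's actual helper may differ):

import torch.utils.data


def split_dataset(dataset, split_start=0, split_end=None):
    # Hypothetical helper: return a contiguous slice of `dataset` as a Subset,
    # e.g. indices [0, 50000) for training and [50000, len) for validation on MNIST.
    if split_end is None:
        split_end = len(dataset)
    return torch.utils.data.Subset(dataset, list(range(split_start, split_end)))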
Code Example #2
File: main.py  Project: scape1989/sparse_learning
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=100,
                        metavar='N',
                        help='input batch size for training (default: 100)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=100,
                        metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=17,
                        metavar='S',
                        help='random seed (default: 17)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        type=str,
                        default='./models/model.pt',
                        help='For Saving the current Model')
    parser.add_argument('--data', type=str, default='mnist')
    parser.add_argument('--augment', action='store_true')
    parser.add_argument('--decay_frequency', type=int, default=25000)
    parser.add_argument('--l1', type=float, default=0.0)
    parser.add_argument('--fp16',
                        action='store_true',
                        help='Run in fp16 mode.')
    parser.add_argument('--valid_split', type=float, default=0.1)
    parser.add_argument('--resume', type=str)
    parser.add_argument('--start-epoch', type=int, default=1)
    parser.add_argument('--model', type=str, default='')
    parser.add_argument('--l2', type=float, default=5.0e-4)
    parser.add_argument(
        '--iterations',
        type=int,
        default=1,
        help=
        'How many times the model should be run after each other. Default=1')
    parser.add_argument(
        '--save-features',
        action='store_true',
        help=
        'Resumes a saved model and saves its feature data to disk for plotting.'
    )
    parser.add_argument(
        '--bench',
        action='store_true',
        help='Enables the benchmarking of layers and estimates sparse speedups'
    )
    sparselearning.core.add_sparse_args(parser)

    args = parser.parse_args()

    if args.fp16:
        try:
            from apex.fp16_utils import FP16_Optimizer
        except ImportError:
            print('WARNING: apex not installed, ignoring --fp16 option')
            args.fp16 = False

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    print_and_log('\n\n')
    print_and_log('=' * 80)
    print_and_log('=' * 80)
    print_and_log(args)
    torch.manual_seed(args.seed)
    for i in range(args.iterations):
        print_and_log("\nIteration start: {0}/{1}\n".format(
            i + 1, args.iterations))

        if args.data == 'mnist':
            train_loader, valid_loader, test_loader = get_mnist_dataloaders(
                args, validation_split=args.valid_split)
        else:
            train_loader, valid_loader, test_loader = get_cifar10_dataloaders(
                args, args.valid_split)

        if args.model not in models:
            print(
                'You need to select an existing model via the --model argument. Available models include: '
            )
            for key in models:
                print('\t{0}'.format(key))
            raise Exception('You need to select a model')
        else:
            cls, cls_args = models[args.model]
            # Build the argument list without mutating the shared list in `models`,
            # so repeated --iterations runs do not accumulate duplicate arguments.
            model = cls(*(list(cls_args) + [args.save_features, args.bench])).to(device)
            print_and_log(model)
            print_and_log('=' * 60)
            print_and_log(args.model)
            print_and_log('=' * 60)

            print_and_log('=' * 60)
            print_and_log('Death mode: {0}'.format(args.death))
            print_and_log('Growth mode: {0}'.format(args.growth))
            print_and_log('Redistribution mode: {0}'.format(
                args.redistribution))
            print_and_log('=' * 60)

        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.l2,
                              nesterov=True)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                 args.decay_frequency,
                                                 gamma=0.1)

        if args.resume:
            if os.path.isfile(args.resume):
                print_and_log("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print_and_log("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
                print_and_log('Testing...')
                evaluate(args, model, device, test_loader)
                plot_class_feature_histograms(args, model, device,
                                              train_loader, optimizer)
            else:
                print_and_log("=> no checkpoint found at '{}'".format(
                    args.resume))

        if args.fp16:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=None,
                                       dynamic_loss_scale=True,
                                       dynamic_loss_args={'init_scale': 2**16})
            model = model.half()

        mask = None
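        # Sparse-momentum setup: the Masking object prunes and regrows weights over
        # training, with the death rate annealed by a cosine schedule.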
        if args.sparse:
            decay = CosineDecay(args.death_rate,
                                len(train_loader) * (args.epochs))
            mask = Masking(optimizer,
                           death_mode=args.death,
                           death_rate_decay=decay,
                           growth_mode=args.growth,
                           redistribution_mode=args.redistribution)
            mask.add_module(model, density=args.density)

        for epoch in range(1, args.epochs + 1):

            t0 = time.time()
            train(args, model, device, train_loader, optimizer, epoch,
                  lr_scheduler, mask)

            if args.valid_split > 0.0:
                val_acc = evaluate(args, model, device, valid_loader)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                is_best=False,
                filename=args.save_model)

            if args.sparse and epoch < args.epochs:
                mask.at_end_of_epoch()

            print_and_log(
                'Current learning rate: {0}. Time taken for epoch: {1}.\n'.
                format(optimizer.param_groups[0]['lr'],
                       time.time() - t0))

        evaluate(args, model, device, test_loader)
        print_and_log("\nIteration end: {0}/{1}\n".format(
            i + 1, args.iterations))
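
Note: the train() helper used in the loop above is not shown. Below is a minimal sketch of the usual pattern, assuming the Masking object's step() method wraps optimizer.step() and re-applies the sparsity mask (as in the sparselearning library); the project's actual helper may differ:

import torch.nn.functional as F


def train(args, model, device, train_loader, optimizer, epoch, lr_scheduler, mask=None):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)  # assumes the model outputs log-probabilities
        loss.backward()
        if mask is not None:
            mask.step()        # assumed to call optimizer.step() and re-apply the mask
        else:
            optimizer.step()
        lr_scheduler.step()
        if batch_idx % args.log_interval == 0:
            print('epoch {} batch {} loss {:.4f}'.format(epoch, batch_idx, loss.item()))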
Code Example #3
File: main.py  Project: scape1989/sparse_learning
def train_net(args, logger_cls):
    exp_start_time = time.time()
    global best_prec1
    best_prec1 = 0

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        dist.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    if args.seed is not None:
        print("Using seed = {}".format(args.seed))
        torch.manual_seed(args.seed + args.local_rank)
        torch.cuda.manual_seed(args.seed + args.local_rank)
        np.random.seed(seed=args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)

        def _worker_init_fn(id):
            np.random.seed(seed=args.seed + args.local_rank + id)
            random.seed(args.seed + args.local_rank + id)
    else:
        def _worker_init_fn(id):
            pass

    if args.fp16:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print("Warning:  if --fp16 is not used, static_loss_scale will be ignored.")

    pretrained_weights = None
    if args.pretrained_weights:
        if os.path.isfile(args.pretrained_weights):
            print("=> loading pretrained weights from '{}'".format(args.pretrained_weights))
            pretrained_weights = torch.load(args.pretrained_weights)
        else:
            print("=> no pretrained weights found at '{}'".format(args.resume))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model_state = checkpoint['state_dict']
            optimizer_state = checkpoint['optimizer']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            model_state = None
            optimizer_state = None
    else:
        model_state = None
        optimizer_state = None

    model_and_loss = ModelAndLoss(args,
            (args.arch, args.model_config),
            nn.CrossEntropyLoss if args.label_smoothing == 0.0 else (lambda: LabelSmoothing(args.label_smoothing)),
            pretrained_weights=pretrained_weights,
            state=model_state,
            cuda = True, fp16 = args.fp16, distributed = args.distributed)

    # Create data loaders and optimizers as needed

    if not (args.evaluate or args.inferbench):
        optimizer = get_optimizer(list(model_and_loss.model.named_parameters()),
                args.fp16,
                args.lr, args.momentum, args.weight_decay,
                bn_weight_decay = args.bn_weight_decay,
                state=optimizer_state,
                static_loss_scale = args.static_loss_scale,
                dynamic_loss_scale = args.dynamic_loss_scale)

        train_loader = get_train_loader(args.data, args.batch_size, workers=args.workers, _worker_init_fn=_worker_init_fn)
        train_loader_len = len(train_loader)
    else:
        train_loader_len = 0

    if not args.trainbench:
        val_loader = get_val_loader(args.data, args.batch_size, workers=args.workers, _worker_init_fn=_worker_init_fn)
        val_loader_len = len(val_loader)
    else:
        val_loader_len = 0

    decay = CosineDecay(args.death_rate, len(train_loader)*args.epochs)
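    # Note: optimizer and train_loader are only created above for training runs;
    # in --evaluate or --inferbench mode these two lines would raise a NameError.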
    mask = Masking(optimizer, death_mode=args.death, death_rate_decay=decay, growth_mode=args.growth, redistribution_mode=args.redistribution)
    model_and_loss.mask = mask
    if args.sparse:
        mask.add_module(model_and_loss.model, density=args.density)

    if args.evaluate:
        logger = logger_cls(train_loader_len, val_loader_len, args)
        validate(val_loader, model_and_loss, args.fp16, logger, 0)
        return

    if args.trainbench:
        model_and_loss.model.train()
        logger = logger_cls("Train", args.world_size * args.batch_size, args.bench_warmup)
        bench(get_train_step(model_and_loss, optimizer, args.fp16), train_loader,
              args.bench_warmup, args.bench_iterations, args.fp16, logger, epoch_warmup = True)
        return

    if args.inferbench:
        model_and_loss.model.eval()
        logger = logger_cls("Inference", args.world_size * args.batch_size, args.bench_warmup)
        bench(get_val_step(model_and_loss), val_loader,
              args.bench_warmup, args.bench_iterations, args.fp16, logger, epoch_warmup = False)
        return

    logger = logger_cls(train_loader_len, val_loader_len, args)
    train_loop(args, model_and_loss, optimizer, adjust_learning_rate(args), train_loader, val_loader, args.epochs,
            args.fp16, logger, should_backup_checkpoint(args),
            start_epoch = args.start_epoch, best_prec1 = best_prec1, prof=args.prof)

    exp_duration = time.time() - exp_start_time
    logger.experiment_timer(exp_duration)
    logger.end_callback()
    print("Experiment ended")