Example #1
def get_data(args):
    if args.dataset == 'mnist':
        resize = args.resize_input
        batch_size = args.batch_size
        return mnist(resize=resize, batch_size=batch_size)

    elif args.dataset == 'cifar10':
        return cifar10(args.batch_size, num_workers=0)
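
A minimal, hypothetical driver for get_data, shown here only to make the expected args namespace explicit. The flag names mirror the attributes the function reads (dataset, resize_input, batch_size); the mnist and cifar10 loaders are assumed to come from the project's own data module.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', choices=['mnist', 'cifar10'], default='mnist')
parser.add_argument('--resize_input', type=int, default=None)
parser.add_argument('--batch_size', type=int, default=128)

args = parser.parse_args()
train_loader = get_data(args)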
Example #2
def main():
    global args, best_prec1, writer, total_steps, exp_flops, exp_l0, param_num
    args = parser.parse_args()
    log_dir_net = args.name
    print('model:', args.name)
    if args.tensorboard:
        from tensorboardX import SummaryWriter
        directory = 'logs/{}/{}'.format(log_dir_net, args.name)
        # start each run with a fresh log directory
        if os.path.exists(directory):
            shutil.rmtree(directory)
        os.makedirs(directory)
        writer = SummaryWriter(directory)

    train_loader, val_loader, num_classes = mnist(args.batch_size, pm=False)

    model = CGESModelCNN(lamba=args.lamba)

    # set a depth-dependent mu for each CGES layer
    layers = model.layers if not args.multi_gpu else model.module.layers
    for k, layer in enumerate(layers):
        if isinstance(layer, (CGES_Dense, CGES_Conv2d)):
            layer.set_mu(k / len(layers))

    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), args.lr)
    param_num = sum(p.data.nelement() for p in model.parameters())
    print('Number of model parameters: {}'.format(param_num))

    print('Number of neurons: ', model.count_total_neuron())

    # default training state; overwritten below when resuming from a checkpoint
    best_prec1 = float('inf')
    total_steps, exp_flops, exp_l0 = 0, [], []

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            total_steps = checkpoint['total_steps']
            exp_flops = checkpoint['exp_flops']
            exp_l0 = checkpoint['exp_l0']
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True

    loglike = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        loglike = loglike.cuda()

    # loss: cross-entropy plus the model's regularization penalty
    def loss_function(output, target_var, model):
        loss = loglike(output, target_var)
        total_loss = loss + model.regularization()
        if torch.cuda.is_available():
            total_loss = total_loss.cuda()
        return total_loss

    lr_schedule = lr_scheduler.MultiStepLR(optimizer,
                                           milestones=args.epoch_drop,
                                           gamma=args.lr_decay_ratio)

    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, model, loss_function, optimizer, lr_schedule,
              epoch)
        # evaluate on validation set
        prec1 = validate(val_loader, model, loss_function, epoch)

        # prec1 is an error rate here, so lower is better; track the minimum
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'curr_prec1': prec1,
            'optimizer': optimizer.state_dict(),
            'total_steps': total_steps,
            'exp_flops': exp_flops,
            'exp_l0': exp_l0
        }
        save_checkpoint(state, is_best, args.name)
    print('Best error: ', best_prec1)
    if args.tensorboard:
        writer.close()
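
Both training loops above delegate persistence to a save_checkpoint helper that is not shown. A minimal sketch of what it might look like; the checkpoints/<name>/ layout and filenames are assumptions, not the original implementation.

import os
import shutil
import torch

def save_checkpoint(state, is_best, name, root='checkpoints'):
    # Persist the latest state and, when it improves on the best
    # validation error, copy it to a separate "best" file.
    directory = os.path.join(root, name)
    os.makedirs(directory, exist_ok=True)
    filename = os.path.join(directory, 'checkpoint.pth.tar')
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(directory, 'model_best.pth.tar'))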
Example #3

def main():
    global args, best_prec1, writer, total_steps, exp_flops, exp_l0, param_num
    args = parser.parse_args()
    log_dir_net = args.name
    print('model:', args.name)
    if args.tensorboard:
        # used for logging to TensorBoard
        from tensorboardX import SummaryWriter
        directory = 'logs/{}/{}'.format(log_dir_net, args.name)
        # start each run with a fresh log directory
        if os.path.exists(directory):
            shutil.rmtree(directory)
        os.makedirs(directory)
        writer = SummaryWriter(directory)
    
    # Data loading code
    print('[0, 1] normalization of input')
    train_loader, val_loader, num_classes = mnist(args.batch_size, pm=False)

    # create model
    model = group_lasso_LeNet5(num_classes, input_size=(1, 28, 28), conv_dims=(20, 50),
                               fc_dims=500, N=60000, weight_decay=args.weight_decay,
                               lambas=args.lambas, local_rep=args.local_rep,
                               temperature=args.temp)

    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    param_num = sum(p.data.nelement() for p in model.parameters())
    print('Number of model parameters: {}'.format(param_num))

    print('Number of neurons: ', model.count_total_neuron())

    if torch.cuda.is_available():
        model = model.cuda()

    # default training state; overwritten below when resuming from a checkpoint
    best_prec1 = float('inf')
    total_steps, exp_flops, exp_l0 = 0, [], []

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            total_steps = checkpoint['total_steps']
            exp_flops = checkpoint['exp_flops']
            exp_l0 = checkpoint['exp_l0']
            if checkpoint['beta_ema'] > 0:
                model.beta_ema = checkpoint['beta_ema']
                model.avg_param = checkpoint['avg_params']
                model.steps_ema = checkpoint['steps_ema']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True

    loglike = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        loglike = loglike.cuda()

    # loss: cross-entropy plus the model's regularization penalty
    def loss_function(output, target_var, model):
        loss = loglike(output, target_var)
        total_loss = loss + model.regularization()
        if torch.cuda.is_available():
            total_loss = total_loss.cuda()
        return total_loss

    lr_schedule = lr_scheduler.MultiStepLR(optimizer, milestones=args.epoch_drop, gamma=args.lr_decay_ratio)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, loss_function, optimizer, lr_schedule, epoch)
        # evaluate on validation set
        prec1 = validate(val_loader, model, loss_function, epoch)

        # prec1 is an error rate here, so lower is better; track the minimum
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'curr_prec1': prec1,
            'beta_ema': model.beta_ema,
            'optimizer': optimizer.state_dict(),
            'total_steps': total_steps,
            'exp_flops': exp_flops,
            'exp_l0': exp_l0
        }
        if model.beta_ema > 0:
            state['avg_params'] = model.avg_param
            state['steps_ema'] = model.steps_ema
        save_checkpoint(state, is_best, args.name)
    print('Best error: ', best_prec1)
    if args.tensorboard:
        writer.close()
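
Both main() variants read a module-level parser defined elsewhere. A sketch of the flags they actually consume follows; the defaults are illustrative assumptions, not values from the original repository.

import argparse

parser = argparse.ArgumentParser(description='Sparse CNN training on MNIST')
parser.add_argument('--name', default='run0', help='run name used for logs and checkpoints')
parser.add_argument('--batch_size', type=int, default=100)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--start_epoch', type=int, default=0)
parser.add_argument('--epoch_drop', nargs='*', type=int, default=[100, 150],
                    help='epochs at which MultiStepLR decays the learning rate')
parser.add_argument('--lr_decay_ratio', type=float, default=0.1)
parser.add_argument('--resume', default='', type=str)
parser.add_argument('--tensorboard', action='store_true')
parser.add_argument('--multi_gpu', action='store_true')
# regularization hyperparameters used by the two model variants
parser.add_argument('--lamba', type=float, default=0.001)
parser.add_argument('--lambas', nargs='*', type=float, default=[0.1, 0.1, 0.1, 0.1])
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--local_rep', action='store_true')
parser.add_argument('--temp', type=float, default=2.0 / 3.0)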