def automated_deep_compression(model, criterion, optimizer, loggers, args):
    """Run automated deep compression (ADC) on ``model``.

    Loads the dataset described by ``args``. Wraps the project's ``train`` and
    ``test`` routines and ``apputils.save_checkpoint`` into callbacks with their
    fixed arguments pre-bound. Hands them to ``adc.do_adc``. The result of
    ``adc.do_adc`` is discarded; this function returns ``None``.

    Args:
        model: The model to compress; passed straight through to ``adc.do_adc``.
        criterion: Loss function bound into the train and validation callbacks.
        optimizer: Not used by this function. ``adc.do_adc`` receives optimizer
            hyper-parameters taken from ``args`` instead (``optimizer_data``).
            It is kept so the signature stays compatible with callers.
        loggers: Loggers bound into the train and validation callbacks.
        args: Parsed command-line namespace. Read for the data-loading settings,
            ``arch``, ``lr``, ``momentum`` and ``weight_decay``. Note that this
            function sets ``args.display_confusion = True`` on the caller's
            object.
    """
    train_loader, val_loader, test_loader, _ = apputils.load_data(
        args.dataset, os.path.expanduser(args.data), args.batch_size,
        args.workers, args.validation_split, args.deterministic,
        args.effective_train_size, args.effective_valid_size,
        args.effective_test_size)

    # ``validate_fn`` is the signal ADC's search uses to compare candidates, so
    # it must score on the validation split. Scoring on the test split leaks
    # held-out test data into model selection. Fall back to the test loader
    # only when no validation loader was produced.
    eval_loader = val_loader if val_loader is not None else test_loader

    # Mutates the caller's namespace; presumably ``test`` reads this flag to
    # print a confusion matrix -- TODO confirm.
    args.display_confusion = True
    validate_fn = partial(test,
                          test_loader=eval_loader,
                          criterion=criterion,
                          loggers=loggers,
                          args=args,
                          activations_collectors=None)
    train_fn = partial(train,
                       train_loader=train_loader,
                       criterion=criterion,
                       loggers=loggers,
                       args=args)

    # Checkpoints are written into the logger's output directory.
    save_checkpoint_fn = partial(apputils.save_checkpoint,
                                 arch=args.arch,
                                 dir=msglogger.logdir)
    # ADC gets optimizer hyper-parameters from ``args``, not from ``optimizer``.
    optimizer_data = {
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }
    adc.do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn,
               train_fn)
# Example #2
def automated_deep_compression(model, criterion, optimizer, loggers, args):
    """Run automated deep compression (ADC) on ``model``.

    Loads data via ``load_data(args)``. Wraps the project's ``train`` and
    ``test`` routines and ``apputils.save_checkpoint`` into callbacks with their
    fixed arguments pre-bound. Hands them to ``adc.do_adc``. The result of
    ``adc.do_adc`` is discarded; this function returns ``None``.

    Args:
        model: The model to compress; passed straight through to ``adc.do_adc``.
        criterion: Loss function bound into the train and validation callbacks.
        optimizer: Not used by this function. ``adc.do_adc`` receives optimizer
            hyper-parameters taken from ``args`` instead (``optimizer_data``).
            It is kept so the signature stays compatible with callers.
        loggers: Loggers bound into the train and validation callbacks.
        args: Parsed command-line namespace. Passed to ``load_data`` and read for
            ``arch``, ``lr``, ``momentum`` and ``weight_decay``. Note that this
            function sets ``args.display_confusion = True`` on the caller's
            object.
    """
    train_loader, val_loader, test_loader, _ = load_data(args)

    # ``validate_fn`` is the signal ADC's search uses to compare candidates, so
    # it must score on the validation split. Scoring on the test split leaks
    # held-out test data into model selection. Fall back to the test loader
    # only when no validation loader was produced.
    eval_loader = val_loader if val_loader is not None else test_loader

    # Mutates the caller's namespace; presumably ``test`` reads this flag to
    # print a confusion matrix -- TODO confirm.
    args.display_confusion = True
    validate_fn = partial(test,
                          test_loader=eval_loader,
                          criterion=criterion,
                          loggers=loggers,
                          args=args,
                          activations_collectors=None)
    train_fn = partial(train,
                       train_loader=train_loader,
                       criterion=criterion,
                       loggers=loggers,
                       args=args)

    # Checkpoints are written into the logger's output directory.
    save_checkpoint_fn = partial(apputils.save_checkpoint,
                                 arch=args.arch,
                                 dir=msglogger.logdir)
    # ADC gets optimizer hyper-parameters from ``args``, not from ``optimizer``.
    optimizer_data = {
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }
    adc.do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn,
               train_fn)
# Example #3
def automated_deep_compression(model, criterion, optimizer, loggers, args, *,
                               data_dir=r'/home/tian/Desktop/image_yasuo'):
    """Run automated deep compression (ADC) on ``model``.

    Builds data loaders with ``get_data_loaders(datasets_fn, data_dir, ...)``.
    Wraps the project's ``train`` and ``test`` routines and
    ``apputils.save_checkpoint`` into callbacks with their fixed arguments
    pre-bound. Hands them to ``adc.do_adc``. The result of ``adc.do_adc`` is
    discarded; this function returns ``None``.

    Args:
        model: The model to compress; passed straight through to ``adc.do_adc``.
        criterion: Loss function bound into the train and validation callbacks.
        optimizer: Not used by this function. ``adc.do_adc`` receives optimizer
            hyper-parameters taken from ``args`` instead (``optimizer_data``).
            It is kept so the signature stays compatible with callers.
        loggers: Loggers bound into the train and validation callbacks.
        args: Parsed command-line namespace. Read for loader settings,
            ``arch``, ``lr``, ``momentum`` and ``weight_decay``. Note that this
            function sets ``args.display_confusion = True`` on the caller's
            object.
        data_dir: Dataset root handed to ``get_data_loaders``. It defaults to
            the path that used to be hard-coded here, so existing callers
            behave the same. Pass a different root to run on another machine
            or dataset.
    """
    train_loader, val_loader, test_loader, _ = get_data_loaders(
        datasets_fn, data_dir, args.batch_size,
        args.workers, args.validation_split, args.deterministic,
        args.effective_train_size, args.effective_valid_size,
        args.effective_test_size)

    # ``validate_fn`` is the signal ADC's search uses to compare candidates, so
    # it must score on the validation split. Scoring on the test split leaks
    # held-out test data into model selection. Fall back to the test loader
    # only when no validation loader was produced.
    eval_loader = val_loader if val_loader is not None else test_loader

    # Mutates the caller's namespace; presumably ``test`` reads this flag to
    # print a confusion matrix -- TODO confirm.
    args.display_confusion = True
    validate_fn = partial(test,
                          test_loader=eval_loader,
                          criterion=criterion,
                          loggers=loggers,
                          args=args,
                          activations_collectors=None)
    train_fn = partial(train,
                       train_loader=train_loader,
                       criterion=criterion,
                       loggers=loggers,
                       args=args)

    # Checkpoints are written into the logger's output directory.
    save_checkpoint_fn = partial(apputils.save_checkpoint,
                                 arch=args.arch,
                                 dir=msglogger.logdir)
    # ADC gets optimizer hyper-parameters from ``args``, not from ``optimizer``.
    optimizer_data = {
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }
    adc.do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn,
               train_fn)