コード例 #1
0
ファイル: main_prune.py プロジェクト: mhariat/DeepPruning
def init_network(config, num_classes):
    """Build the network described by ``config`` and load its best original checkpoint.

    Args:
        config: Experiment configuration. Reads ``network`` (``'vgg'``,
            ``'resnet'`` or ``'reskipnet'``), ``depth``, ``data_dir``,
            ``result_dir`` and, for VGG, the optional ``avg_pool2d`` /
            ``batch_norm`` entries.
        num_classes: Number of output classes passed to the model constructor.

    Returns:
        ``(net, exp_name, epochs)`` where ``exp_name`` and ``epochs`` are the
        values returned by ``get_best_checkpoint`` for the checkpoint directory
        ``<result_dir>/checkpoint/original/<dataset>/<network>/<depth>``.

    Raises:
        NotImplementedError: If ``config.network`` is not supported.
        AssertionError: If that checkpoint directory does not exist.
    """
    # Strip trailing slashes so '.../cifar10/' still yields 'cifar10'; a bare
    # split('/')[-1] would return '' and pick the wrong model variant and
    # checkpoint directory.
    dataset = config.data_dir.rstrip('/').split('/')[-1]
    kwargs = {'depth': config.depth, 'num_classes': num_classes}
    if config.network == 'vgg':
        # Optional VGG switches are forwarded only when present in the config.
        if 'avg_pool2d' in config:
            kwargs.update({'avg_pool2d': config.avg_pool2d})
        if 'batch_norm' in config:
            kwargs.update({'batch_norm': config.batch_norm})
        net = vgg(**kwargs)
    elif config.network == 'resnet':
        if 'cifar' in dataset.lower():
            kwargs.update({'dataset': 'cifar'})
        net = resnet(**kwargs)
    elif config.network == 'reskipnet':
        net = reskipnet(**kwargs)
    else:
        raise NotImplementedError
    path_to_add = '{}/{}/{}'.format(dataset, config.network, config.depth)
    checkpoint_dir = '{}/checkpoint/original/{}'.format(config.result_dir, path_to_add)
    # Validate the exact directory searched below, not only its
    # 'checkpoint/original' parent.
    assert os.path.exists(checkpoint_dir), 'No checkpoint directory for original model!'
    checkpoint_path, exp_name, epochs = get_best_checkpoint(checkpoint_dir)
    # Map tensors to CPU on load; the model is moved to GPU below if available.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint)
    if torch.cuda.is_available():
        net.cuda()
    return net, exp_name, epochs
コード例 #2
0
def init_network(config, num_classes):
    """Build a ReSkipNet and load its best checkpoint from ``checkpoint/original``.

    Args:
        config: Configuration; reads ``network`` (only ``'resnet'`` is
            supported), ``depth``, ``data_dir`` and ``result_dir``.
        num_classes: Number of output classes passed to ``reskipnet``.

    Returns:
        ``(net, exp_name, epochs)`` as returned by ``get_best_checkpoint`` for
        ``<result_dir>/checkpoint/original/<dataset>/<network>/<depth>``.

    Raises:
        NotImplementedError: If ``config.network`` is not ``'resnet'``.
        AssertionError: If the ``checkpoint/sp`` or ``checkpoint/original``
            directory for this dataset/network/depth does not exist.
    """
    if config.network == 'resnet':
        # 'resnet' configurations are instantiated as a ReSkipNet here.
        kwargs = {'depth': config.depth, 'num_classes': num_classes}
        net = reskipnet(**kwargs)
    else:
        raise NotImplementedError
    # Strip trailing slashes so '.../cifar10/' still yields 'cifar10'.
    dataset = config.data_dir.rstrip('/').split('/')[-1]
    path_to_add = '{}/{}/{}'.format(dataset, config.network, config.depth)
    # NOTE(review): the 'sp' directory is required to exist, but the weights
    # are loaded from 'original' below. This preserves the existing behavior;
    # confirm which checkpoint is actually intended.
    assert os.path.exists('{}/checkpoint/sp/{}'.format(config.result_dir, path_to_add)),\
        'No checkpoint directory for sp model!'
    original_dir = '{}/checkpoint/original/{}'.format(config.result_dir, path_to_add)
    # Validate the directory that is actually searched for the checkpoint.
    assert os.path.exists(original_dir), 'No checkpoint directory for original model!'
    checkpoint_path, exp_name, epochs = get_best_checkpoint(original_dir)
    # Map tensors to CPU on load; the model is moved to GPU below if available.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint)
    if torch.cuda.is_available():
        net.cuda()
    return net, exp_name, epochs
コード例 #3
0
def init_network(args, num_classes):
    """Build the requested network, optionally resuming from its best original checkpoint.

    Args:
        args: Parsed arguments. Reads ``network`` (``'vgg'``, ``'resnet'``,
            ``'reskipnet'`` or ``'resnext'``), ``depth``, ``data_dir``,
            ``result_dir``, ``resume`` and per-network options
            (``avg_pool2d`` for VGG; ``cardinality`` / ``base_width`` for ResNeXt).
        num_classes: Number of output classes passed to the model constructor.

    Returns:
        ``(net, previous_accuracy, previous_epochs, depth)``:
        ``previous_accuracy`` is a float parsed from the checkpoint path
        (``0.0`` when not resuming); ``previous_epochs`` comes from
        ``get_best_checkpoint`` (``0`` when not resuming); ``depth`` is the
        string used as the depth component of the checkpoint path (for
        ResNeXt: ``'<depth>_<cardinality>_<base_width>d'``).

    Raises:
        NotImplementedError: If ``args.network`` is not supported.
        AssertionError: If resuming and the checkpoint directory does not exist.
    """
    # Strip trailing slashes so '.../cifar10/' still yields 'cifar10'; a bare
    # split('/')[-1] would return '' and silently build the non-CIFAR variant.
    dataset = args.data_dir.rstrip('/').split('/')[-1]
    is_cifar = 'cifar' in dataset.lower()
    kwargs = {'depth': args.depth, 'num_classes': num_classes}
    depth = str(args.depth)
    if args.network == 'vgg':
        kwargs.update({'avg_pool2d': args.avg_pool2d, 'batch_norm': True})
        net = vgg(**kwargs)
    elif args.network == 'resnet':
        if is_cifar:
            kwargs['dataset'] = 'cifar'
        net = resnet(**kwargs)
    elif args.network == 'reskipnet':
        net = reskipnet(**kwargs)
    elif args.network == 'resnext':
        if is_cifar:
            kwargs['dataset'] = 'cifar'
        kwargs.update({'cardinality': args.cardinality, 'base_width': args.base_width})
        net = resnext(**kwargs)
        # The checkpoint path's depth component also encodes cardinality and width.
        depth = '{}_{}_{}d'.format(depth, args.cardinality, args.base_width)
    else:
        raise NotImplementedError
    if args.resume:
        path_to_add = '{}/{}/{}'.format(dataset, args.network, depth)
        checkpoint_dir = '{}/checkpoint/original/{}'.format(args.result_dir, path_to_add)
        assert os.path.exists(checkpoint_dir), 'No checkpoint directory for original model!'
        checkpoint_path, _, previous_epochs = get_best_checkpoint(checkpoint_dir)
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        net.load_state_dict(checkpoint)
        # Second-to-last '_'-separated field of the checkpoint path; presumably
        # the checkpoint file name encodes the accuracy there — confirm.
        previous_accuracy = checkpoint_path.split('_')[-2]
    else:
        previous_accuracy = 0
        previous_epochs = 0
    if torch.cuda.is_available():
        net.cuda()
    return net, float(previous_accuracy), previous_epochs, depth
コード例 #4
0
def init_network(args, num_classes):
    """Build a ReSkipNet in RL mode and load its best checkpoint from ``checkpoint/sp``.

    Args:
        args: Parsed arguments; reads ``network`` (only ``'resnet'`` is
            supported), ``depth``, ``data_dir`` and ``result_dir``.
        num_classes: Number of output classes passed to ``reskipnet``.

    Returns:
        ``(net, epochs, depth)`` where ``epochs`` comes from
        ``get_best_checkpoint`` and ``depth`` is ``str(args.depth)``.

    Raises:
        NotImplementedError: If ``args.network`` is not ``'resnet'``.
        AssertionError: If the ``checkpoint/sp`` directory for this
            dataset/network/depth does not exist.
    """
    if args.network == 'resnet':
        # 'resnet' arguments are instantiated as a ReSkipNet here.
        kwargs = {'depth': args.depth, 'num_classes': num_classes}
        depth = str(args.depth)
        net = reskipnet(**kwargs)
    else:
        raise NotImplementedError
    net.rl_mode = True
    # Strip trailing slashes so '.../cifar10/' still yields 'cifar10'.
    dataset = args.data_dir.rstrip('/').split('/')[-1]
    path_to_add = '{}/{}/{}'.format(dataset, args.network, depth)
    checkpoint_dir = '{}/checkpoint/sp/{}'.format(args.result_dir, path_to_add)
    assert os.path.exists(checkpoint_dir), 'No checkpoint directory for sp model!'
    checkpoint_path, _, epochs = get_best_checkpoint(checkpoint_dir)
    # Map tensors to CPU on load; the model is moved to GPU below if available.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    net.load_state_dict(checkpoint)
    if torch.cuda.is_available():
        net.cuda()
    # Kept after the optional GPU move, matching the original ordering.
    net.fill_flop_weights()
    return net, epochs, depth
コード例 #5
0
def init_network(config, num_classes):
    """Build a ReSkipNet in RL mode and restore a pruned checkpoint into it.

    Args:
        config: Configuration; reads ``network`` (only ``'resnet'`` is
            supported), ``depth``, ``data_dir``, ``result_dir``,
            ``checkpoint_file`` and ``exp_name``.
        num_classes: Number of output classes passed to ``reskipnet``.

    Returns:
        ``(net, exp_name, compression)`` where ``net`` is the value returned
        by ``load_checkpoint_pruning`` and ``compression`` is the float parsed
        from the last ``'_'``-separated field of ``checkpoint_file`` with the
        ``'.pth'`` suffix removed.

    Raises:
        NotImplementedError: If ``config.network`` is not ``'resnet'``.
        AssertionError: If the ``checkpoint/pruned`` directory for this
            dataset/network/depth does not exist.
    """
    if config.network == 'resnet':
        # 'resnet' configurations are instantiated as a ReSkipNet here.
        kwargs = {'depth': config.depth, 'num_classes': num_classes}
        depth = config.depth
        net = reskipnet(**kwargs)
    else:
        raise NotImplementedError
    net.rl_mode = True
    # Strip trailing slashes so '.../cifar10/' still yields 'cifar10'.
    dataset = config.data_dir.rstrip('/').split('/')[-1]
    path_to_add = '{}/{}/{}'.format(dataset, config.network, depth)
    assert os.path.exists('{}/checkpoint/pruned/{}'.format(config.result_dir, path_to_add)),\
        'No checkpoint directory for pruned model!'
    checkpoint_file = config.checkpoint_file
    # NOTE(review): this path is relative ('<dataset>/<network>/<depth>/<file>')
    # and omits the '<result_dir>/checkpoint/pruned' root validated above;
    # confirm that load_checkpoint_pruning resolves it against that root.
    checkpoint_path = '{}/{}'.format(path_to_add, checkpoint_file)
    exp_name = config.exp_name
    net = load_checkpoint_pruning(checkpoint_path, net, use_bias=True)
    if torch.cuda.is_available():
        net.cuda()
    net.fill_flop_weights()
    # e.g. 'name_0.5.pth' -> '0.5'
    compression = checkpoint_file.split('_')[-1].split('.pth')[0]
    return net, exp_name, float(compression)
コード例 #6
0
def init_network(args, num_classes):
    """Build a ReSkipNet, optionally resuming from its best original checkpoint.

    Args:
        args: Parsed arguments; reads ``network`` (only ``'resnet'`` is
            supported), ``depth``, ``resume`` and, when resuming,
            ``data_dir`` and ``result_dir``.
        num_classes: Number of output classes passed to ``reskipnet``.

    Returns:
        ``(net, previous_accuracy, previous_epochs, depth)``:
        ``previous_accuracy`` is a float parsed from the checkpoint path
        (``0.0`` when not resuming); ``previous_epochs`` comes from
        ``get_best_checkpoint`` (``0`` when not resuming); ``depth`` is
        ``str(args.depth)``.

    Raises:
        NotImplementedError: If ``args.network`` is not ``'resnet'``.
        AssertionError: If resuming and the checkpoint directory does not exist.
    """
    if args.network == 'resnet':
        # 'resnet' arguments are instantiated as a ReSkipNet here.
        kwargs = {'depth': args.depth, 'num_classes': num_classes}
        depth = str(args.depth)
        net = reskipnet(**kwargs)
    else:
        raise NotImplementedError
    if args.resume:
        # Strip trailing slashes so '.../cifar10/' still yields 'cifar10'.
        dataset = args.data_dir.rstrip('/').split('/')[-1]
        path_to_add = '{}/{}/{}'.format(dataset, args.network, depth)
        checkpoint_dir = '{}/checkpoint/original/{}'.format(args.result_dir, path_to_add)
        assert os.path.exists(checkpoint_dir), 'No checkpoint directory for original model!'
        checkpoint_path, _, previous_epochs = get_best_checkpoint(checkpoint_dir)
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        net.load_state_dict(checkpoint)
        # Second-to-last '_'-separated field of the checkpoint path; presumably
        # the checkpoint file name encodes the accuracy there — confirm.
        previous_accuracy = checkpoint_path.split('_')[-2]
    else:
        previous_accuracy = 0
        previous_epochs = 0
    if torch.cuda.is_available():
        net.cuda()
    return net, float(previous_accuracy), previous_epochs, depth