Пример #1
0
def convert_and_test(network_type, train_weights):
    """Run general_test on training weights, convert them, then test again.

    Steps, in order:
      1. Build an ``ACNetBuilder`` in non-deploy mode and call
         ``general_test`` with ``train_weights``.
      2. Derive the deploy file name by replacing ``'.hdf5'`` with
         ``'_deploy.hdf5'`` in ``train_weights`` and call
         ``convert_acnet_weights`` to produce it (``eps=1e-5``).
      3. Switch the same builder to deploy mode and call ``general_test``
         with the converted weights.

    Args:
        network_type: Forwarded unchanged to ``general_test``.
        train_weights: Path string of the training-time weights; it must
            support ``str.replace``.

    Returns:
        None. The results of both ``general_test`` calls are discarded.
    """
    acnet_builder = ACNetBuilder(deploy=False, base_config=None)
    # Arguments common to both test passes; only the weights file differs.
    shared_kwargs = dict(network_type=network_type, builder=acnet_builder)

    # Pass 1: the builder is still in its non-deploy form.
    general_test(weights=train_weights, **shared_kwargs)

    deploy_path = train_weights.replace('.hdf5', '_deploy.hdf5')
    convert_acnet_weights(train_weights, deploy_weights=deploy_path, eps=1e-5)

    # Pass 2: same builder object, switched to deploy, with converted weights.
    acnet_builder.switch_to_deploy()
    general_test(weights=deploy_path, **shared_kwargs)
Пример #2
0
def acnet_cfqkbnc():
    """Configure and launch a training run for the 'cfqkbnc' network on 'cifar10'.

    The experiment tag returned by ``start_exp()`` is inspected with
    substring checks (it is presumably a string -- confirm against
    ``start_exp``) and selects the following options:

      * ``'bias'``   -> biases get the same weight decay (1e-4) as the other
        weights; otherwise bias weight decay is 0.
      * ``'warmup'`` -> ``warmup_factor`` is 0; otherwise it is 1.
      * ``'normal'`` -> no conv builder is passed (``convbuilder=None``).
      * ``'acnet'``  -> an ``ACNetBuilder`` in non-deploy mode is passed.
        ``'normal'`` is checked first and wins if both are present.
      * ``'nest'``   -> ``use_nesterov=True`` is passed to ``ding_train``.

    The tag is also embedded in the log directory and saved-weights file
    names and is handed to ``parse_usual_lr_schedule`` for the LR schedule.

    Raises:
        AssertionError: If the tag contains neither ``'normal'`` nor
            ``'acnet'``.
    """
    try_arg = start_exp()

    network_type = 'cfqkbnc'
    dataset_name = 'cifar10'
    log_dir = 'acnet_exps/{}_{}_train'.format(network_type, try_arg)
    save_weights = 'acnet_exps/{}_{}_savedweights.pth'.format(
        network_type, try_arg)
    weight_decay_strength = 1e-4
    batch_size = 64

    lrs = parse_usual_lr_schedule(try_arg)

    # Tag-driven switches (see docstring).
    weight_decay_bias = weight_decay_strength if 'bias' in try_arg else 0
    warmup_factor = 0 if 'warmup' in try_arg else 1

    config = get_baseconfig_by_epoch(
        network_type=network_type,
        dataset_name=dataset_name,
        dataset_subset='train',
        global_batch_size=batch_size,
        num_node=1,
        weight_decay=weight_decay_strength,
        optimizer_type='sgd',
        momentum=0.9,
        max_epochs=lrs.max_epochs,
        base_lr=lrs.base_lr,
        lr_epoch_boundaries=lrs.lr_epoch_boundaries,
        lr_decay_factor=lrs.lr_decay_factor,
        warmup_epochs=5,
        warmup_method='linear',
        warmup_factor=warmup_factor,
        ckpt_iter_period=20000,
        tb_iter_period=100,
        output_dir=log_dir,
        tb_dir=log_dir,
        save_weights=save_weights,
        val_epoch_period=2,
        linear_final_lr=lrs.linear_final_lr,
        weight_decay_bias=weight_decay_bias)

    if 'normal' in try_arg:
        builder = None
    elif 'acnet' in try_arg:
        # Imported lazily so the 'normal' path does not need the acnet package.
        from acnet.acnet_builder import ACNetBuilder
        builder = ACNetBuilder(base_config=config, deploy=False)
    else:
        # Explicit raise instead of `assert False`: asserts are stripped under
        # `python -O`, which would leave `builder` unbound and fail later with
        # an unrelated error at the ding_train call.
        raise AssertionError(
            "experiment tag must contain 'normal' or 'acnet', got {!r}".format(
                try_arg))

    ding_train(config,
               show_variables=True,
               convbuilder=builder,
               use_nesterov='nest' in try_arg)
Пример #3
0
    log_dir = 'acnet_exps/{}_{}_train'.format(network_type, block_type)

    weight_decay_bias = weight_decay_strength
    config = get_baseconfig_by_epoch(network_type=network_type,
                                     dataset_name=get_dataset_name_by_model_name(network_type), dataset_subset='train',
                                     global_batch_size=batch_size, num_node=1,
                                     weight_decay=weight_decay_strength, optimizer_type='sgd', momentum=0.9,
                                     max_epochs=lrs.max_epochs, base_lr=lrs.base_lr, lr_epoch_boundaries=lrs.lr_epoch_boundaries, cosine_minimum=lrs.cosine_minimum,
                                     lr_decay_factor=lrs.lr_decay_factor,
                                     warmup_epochs=0, warmup_method='linear', warmup_factor=0,
                                     ckpt_iter_period=40000, tb_iter_period=100, output_dir=log_dir,
                                     tb_dir=log_dir, save_weights=None, val_epoch_period=5, linear_final_lr=lrs.linear_final_lr,
                                     weight_decay_bias=weight_decay_bias, deps=None)

    if block_type == 'acb':
        builder = ACNetBuilder(base_config=config, deploy=False, gamma_init=gamma_init)
    else:
        builder = ConvBuilder(base_config=config)

    target_weights = os.path.join(log_dir, 'finish.hdf5')
    if not os.path.exists(target_weights):
        train_main(local_rank=start_arg.local_rank, cfg=config, convbuilder=builder,
               show_variables=True, auto_continue=auto_continue)

    if block_type == 'acb' and start_arg.local_rank == 0:
        convert_acnet_weights(target_weights, target_weights.replace('.hdf5', '_deploy.hdf5'), eps=1e-5)
        deploy_builder = ACNetBuilder(base_config=config, deploy=True)
        general_test(network_type=network_type, weights=target_weights.replace('.hdf5', '_deploy.hdf5'),
                 builder=deploy_builder)
Пример #4
0
        deps = wrn_origin_deps_flattened(2, 8)

    if batch_size is None:
        batch_size = TEST_BATCH_SIZE
    test_config = get_baseconfig_for_test(network_type=network_type,
                                          dataset_subset='val',
                                          global_batch_size=batch_size,
                                          init_weights=init_weights,
                                          deps=deps,
                                          dataset_name=dataset_name)
    return ding_test(cfg=test_config,
                     net=net,
                     show_variables=True,
                     init_hdf5=init_hdf5,
                     convbuilder=builder,
                     weights_dict=weights_dict)


if __name__ == '__main__':
    # Script entry point: run general_test for a fixed network/dataset pair
    # using the weights file given as the first command-line argument.
    import sys

    weights = sys.argv[1]
    network_type, dataset_name = 'resnet50', 'imagenet_standard'

    from acnet.acnet_builder import ACNetBuilder

    # Non-deploy ACNet builder with gamma_init of one third and no base config.
    builder = ACNetBuilder(gamma_init=1 / 3, deploy=False, base_config=None)

    general_test(
        network_type=network_type,
        weights=weights,
        builder=builder,
        dataset_name=dataset_name,
    )
Пример #5
0
def ac_resnet50(pretrained=False, progress=True, **kwargs):
    """Build a ``ResNet`` of ``Bottleneck`` blocks using an ACNet conv builder.

    The network is constructed with the stage layout ``[3, 4, 6, 3]`` and
    ``num_classes=1000``, using an ``ACNetBuilder`` in non-deploy mode.

    Args:
        pretrained: Accepted but not used by this function.
        progress: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Returns:
        The constructed ``ResNet`` instance.
    """
    from acnet.acnet_builder import ACNetBuilder

    # base_config is not None when using SE
    conv_builder = ACNetBuilder(deploy=False, base_config=None)
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(conv_builder, Bottleneck, blocks_per_stage, num_classes=1000)