def convert_and_test(network_type, train_weights): builder = ACNetBuilder(base_config=None, deploy=False) general_test(network_type=network_type, weights=train_weights, builder=builder) deploy_weights = train_weights.replace('.hdf5', '_deploy.hdf5') convert_acnet_weights(train_weights, deploy_weights=deploy_weights, eps=1e-5) builder.switch_to_deploy() general_test(network_type=network_type, weights=deploy_weights, builder=builder)
def acnet_cfqkbnc(): try_arg = start_exp() network_type = 'cfqkbnc' dataset_name = 'cifar10' log_dir = 'acnet_exps/{}_{}_train'.format(network_type, try_arg) save_weights = 'acnet_exps/{}_{}_savedweights.pth'.format( network_type, try_arg) weight_decay_strength = 1e-4 batch_size = 64 lrs = parse_usual_lr_schedule(try_arg) if 'bias' in try_arg: weight_decay_bias = weight_decay_strength else: weight_decay_bias = 0 if 'warmup' in try_arg: warmup_factor = 0 else: warmup_factor = 1 config = get_baseconfig_by_epoch( network_type=network_type, dataset_name=dataset_name, dataset_subset='train', global_batch_size=batch_size, num_node=1, weight_decay=weight_decay_strength, optimizer_type='sgd', momentum=0.9, max_epochs=lrs.max_epochs, base_lr=lrs.base_lr, lr_epoch_boundaries=lrs.lr_epoch_boundaries, lr_decay_factor=lrs.lr_decay_factor, warmup_epochs=5, warmup_method='linear', warmup_factor=warmup_factor, ckpt_iter_period=20000, tb_iter_period=100, output_dir=log_dir, tb_dir=log_dir, save_weights=save_weights, val_epoch_period=2, linear_final_lr=lrs.linear_final_lr, weight_decay_bias=weight_decay_bias) if 'normal' in try_arg: builder = None elif 'acnet' in try_arg: from acnet.acnet_builder import ACNetBuilder builder = ACNetBuilder(base_config=config, deploy=False) else: assert False ding_train(config, show_variables=True, convbuilder=builder, use_nesterov='nest' in try_arg)
log_dir = 'acnet_exps/{}_{}_train'.format(network_type, block_type) weight_decay_bias = weight_decay_strength config = get_baseconfig_by_epoch(network_type=network_type, dataset_name=get_dataset_name_by_model_name(network_type), dataset_subset='train', global_batch_size=batch_size, num_node=1, weight_decay=weight_decay_strength, optimizer_type='sgd', momentum=0.9, max_epochs=lrs.max_epochs, base_lr=lrs.base_lr, lr_epoch_boundaries=lrs.lr_epoch_boundaries, cosine_minimum=lrs.cosine_minimum, lr_decay_factor=lrs.lr_decay_factor, warmup_epochs=0, warmup_method='linear', warmup_factor=0, ckpt_iter_period=40000, tb_iter_period=100, output_dir=log_dir, tb_dir=log_dir, save_weights=None, val_epoch_period=5, linear_final_lr=lrs.linear_final_lr, weight_decay_bias=weight_decay_bias, deps=None) if block_type == 'acb': builder = ACNetBuilder(base_config=config, deploy=False, gamma_init=gamma_init) else: builder = ConvBuilder(base_config=config) target_weights = os.path.join(log_dir, 'finish.hdf5') if not os.path.exists(target_weights): train_main(local_rank=start_arg.local_rank, cfg=config, convbuilder=builder, show_variables=True, auto_continue=auto_continue) if block_type == 'acb' and start_arg.local_rank == 0: convert_acnet_weights(target_weights, target_weights.replace('.hdf5', '_deploy.hdf5'), eps=1e-5) deploy_builder = ACNetBuilder(base_config=config, deploy=True) general_test(network_type=network_type, weights=target_weights.replace('.hdf5', '_deploy.hdf5'), builder=deploy_builder)
deps = wrn_origin_deps_flattened(2, 8) if batch_size is None: batch_size = TEST_BATCH_SIZE test_config = get_baseconfig_for_test(network_type=network_type, dataset_subset='val', global_batch_size=batch_size, init_weights=init_weights, deps=deps, dataset_name=dataset_name) return ding_test(cfg=test_config, net=net, show_variables=True, init_hdf5=init_hdf5, convbuilder=builder, weights_dict=weights_dict) if __name__ == '__main__': import sys network_type = 'resnet50' weights = sys.argv[1] dataset_name = 'imagenet_standard' from acnet.acnet_builder import ACNetBuilder builder = ACNetBuilder(base_config=None, deploy=False, gamma_init=1 / 3) general_test(network_type=network_type, weights=weights, builder=builder, dataset_name=dataset_name)
def ac_resnet50(pretrained=False, progress=True, **kwargs):
    """Build a ResNet with Bottleneck blocks in a [3, 4, 6, 3] layout using an ACNetBuilder.

    Args:
        pretrained: accepted but ignored; no weights are loaded here.
        progress: accepted but ignored.
        **kwargs: only ``num_classes`` is honoured (default 1000). Any other
            keys are ignored, as before.

    Returns:
        ``ResNet(builder, Bottleneck, [3, 4, 6, 3], num_classes=...)``, where
        ``builder`` is ``ACNetBuilder(base_config=None, deploy=False)``.
    """
    from acnet.acnet_builder import ACNetBuilder
    #base_config is not None when using SE
    builder = ACNetBuilder(base_config=None, deploy=False)
    # num_classes used to be hard-coded to 1000, silently dropping a caller's value.
    num_classes = kwargs.get('num_classes', 1000)
    return ResNet(builder, Bottleneck, [3, 4, 6, 3], num_classes=num_classes)