예제 #1
0
def csgd_prune_pipeline(local_rank, init_hdf5, base_train_config,
                        csgd_train_config, target_deps, centri_strength,
                        pacesetter_dict, succeeding_strategy):
    """Obtain starting weights, run C-SGD training + pruning, then test on rank 0.

    Starting weights:
        * ``init_hdf5`` when it is given;
        * otherwise ``<base_train_config.output_dir>/finish.hdf5``. If that
          file does not exist yet, ``train_main`` is first run on
          ``base_train_config`` with Nesterov momentum.

    ``csgd_train_main`` is then called with ``csgd_train_config`` and the
    remaining pruning arguments. Its ``pruned_weights`` target is
    ``<csgd_train_config.output_dir>/pruned.hdf5``. Finally, only the process
    with ``local_rank == 0`` evaluates that file with ``general_test``.

    Args:
        local_rank: process rank forwarded to the training calls; rank 0 also
            runs the final test.
        init_hdf5: optional path to initial weights; None means "use (or
            produce) the base run's finish.hdf5".
        base_train_config: config for the from-scratch base run; its
            ``output_dir`` locates ``finish.hdf5``.
        csgd_train_config: config for C-SGD training; its ``output_dir``
            locates ``pruned.hdf5`` and its ``network_type`` is used for testing.
        target_deps, centri_strength, pacesetter_dict, succeeding_strategy:
            forwarded unchanged to ``csgd_train_main``.
    """
    if init_hdf5 is not None:
        start_weights = init_hdf5
    else:
        # No weights supplied: reuse the base run's final file, training the
        # base model first when that file is not on disk yet.
        start_weights = os.path.join(base_train_config.output_dir,
                                     'finish.hdf5')
        if not os.path.exists(start_weights):
            train_main(local_rank=local_rank,
                       cfg=base_train_config,
                       use_nesterov=True)

    pruned_path = os.path.join(csgd_train_config.output_dir, 'pruned.hdf5')

    # TODO(review): the original author left an open "init?" question on this
    # call; confirm the intended initialisation before relying on it.
    csgd_train_main(local_rank=local_rank,
                    cfg=csgd_train_config,
                    target_deps=target_deps,
                    succeeding_strategy=succeeding_strategy,
                    pacesetter_dict=pacesetter_dict,
                    centri_strength=centri_strength,
                    pruned_weights=pruned_path,
                    init_hdf5=start_weights,
                    use_nesterov=True)

    is_master = local_rank == 0
    if is_master:
        general_test(csgd_train_config.network_type, weights=pruned_path)
예제 #2
0
def train_base_model(local_rank,
                     network_type,
                     lrs,
                     weight_decay_strength,
                     batch_size,
                     deps,
                     auto_continue,
                     init_hdf5=None,
                     net=None,
                     dataset_name=None):
    """Train a base model with plain SGD (no Nesterov) unless it is already done.

    Results and TensorBoard logs both go to ``'<network_type>_train'``.
    Training is skipped when ``finish.hdf5`` already exists in that directory.

    Args:
        local_rank: process rank, passed positionally to ``train_main``.
        network_type: model name; used for the output directory, the config
            and (when ``dataset_name`` is None) the dataset lookup.
        lrs: learning-rate schedule object. Its ``max_epochs``, ``base_lr``,
            ``lr_epoch_boundaries``, ``lr_decay_factor``, ``cosine_minimum``
            and ``linear_final_lr`` attributes are read.
        weight_decay_strength: value passed as ``weight_decay``
            (``weight_decay_bias`` is fixed at 0).
        batch_size: passed as ``global_batch_size``.
        deps: passed as ``deps`` to the config.
        auto_continue: forwarded to ``train_main``.
        init_hdf5: optional initial-weights path forwarded to ``train_main``.
        net: optional network object forwarded to ``train_main``.
        dataset_name: dataset override; when None it is derived from
            ``network_type``.
    """
    out_dir = '{}_train'.format(network_type)

    if dataset_name is None:
        dataset_name = get_dataset_name_by_model_name(network_type)

    # Learning-rate schedule fields copied straight from ``lrs``.
    schedule = dict(
        max_epochs=lrs.max_epochs,
        base_lr=lrs.base_lr,
        lr_epoch_boundaries=lrs.lr_epoch_boundaries,
        lr_decay_factor=lrs.lr_decay_factor,
        cosine_minimum=lrs.cosine_minimum,
        linear_final_lr=lrs.linear_final_lr,
    )

    config = get_baseconfig_by_epoch(
        network_type=network_type,
        dataset_name=dataset_name,
        dataset_subset='train',
        global_batch_size=batch_size,
        num_node=1,
        weight_decay=weight_decay_strength,
        weight_decay_bias=0,
        optimizer_type='sgd',
        momentum=0.9,
        warmup_epochs=0,
        warmup_method='linear',
        warmup_factor=0,
        ckpt_iter_period=40000,
        tb_iter_period=100,
        output_dir=out_dir,
        tb_dir=out_dir,
        save_weights=None,
        val_epoch_period=5,
        deps=deps,
        **schedule)

    # Skip training entirely when the output directory already holds finish.hdf5.
    if os.path.exists(os.path.join(out_dir, 'finish.hdf5')):
        return

    train_main(local_rank,
               config,
               show_variables=True,
               convbuilder=None,
               use_nesterov=False,
               auto_continue=auto_continue,
               init_hdf5=init_hdf5,
               net=net)
예제 #3
0
        network_type=network_type,
        dataset_name=get_dataset_name_by_model_name(network_type),
        dataset_subset='train',
        global_batch_size=batch_size,
        num_node=1,
        weight_decay=weight_decay_strength,
        optimizer_type='sgd',
        momentum=0.9,
        max_epochs=finetune_lrs.max_epochs,
        base_lr=finetune_lrs.base_lr,
        lr_epoch_boundaries=finetune_lrs.lr_epoch_boundaries,
        cosine_minimum=finetune_lrs.cosine_minimum,
        lr_decay_factor=finetune_lrs.lr_decay_factor,
        warmup_epochs=0,
        warmup_method='linear',
        warmup_factor=warmup_factor,
        ckpt_iter_period=40000,
        tb_iter_period=100,
        output_dir=log_dir,
        tb_dir=log_dir,
        save_weights=None,
        val_epoch_period=2,
        linear_final_lr=finetune_lrs.linear_final_lr,
        weight_decay_bias=weight_decay_bias,
        deps=pruned_deps)

    train_main(local_rank=start_arg.local_rank,
               cfg=finetune_config,
               show_variables=True,
               init_hdf5=pruned_path)
예제 #4
0
    log_dir = 'acnet_exps/{}_{}_train'.format(network_type, block_type)

    weight_decay_bias = weight_decay_strength
    config = get_baseconfig_by_epoch(network_type=network_type,
                                     dataset_name=get_dataset_name_by_model_name(network_type), dataset_subset='train',
                                     global_batch_size=batch_size, num_node=1,
                                     weight_decay=weight_decay_strength, optimizer_type='sgd', momentum=0.9,
                                     max_epochs=lrs.max_epochs, base_lr=lrs.base_lr, lr_epoch_boundaries=lrs.lr_epoch_boundaries, cosine_minimum=lrs.cosine_minimum,
                                     lr_decay_factor=lrs.lr_decay_factor,
                                     warmup_epochs=0, warmup_method='linear', warmup_factor=0,
                                     ckpt_iter_period=40000, tb_iter_period=100, output_dir=log_dir,
                                     tb_dir=log_dir, save_weights=None, val_epoch_period=5, linear_final_lr=lrs.linear_final_lr,
                                     weight_decay_bias=weight_decay_bias, deps=None)

    if block_type == 'acb':
        builder = ACNetBuilder(base_config=config, deploy=False, gamma_init=gamma_init)
    else:
        builder = ConvBuilder(base_config=config)

    target_weights = os.path.join(log_dir, 'finish.hdf5')
    if not os.path.exists(target_weights):
        train_main(local_rank=start_arg.local_rank, cfg=config, convbuilder=builder,
               show_variables=True, auto_continue=auto_continue)

    if block_type == 'acb' and start_arg.local_rank == 0:
        convert_acnet_weights(target_weights, target_weights.replace('.hdf5', '_deploy.hdf5'), eps=1e-5)
        deploy_builder = ACNetBuilder(base_config=config, deploy=True)
        general_test(network_type=network_type, weights=target_weights.replace('.hdf5', '_deploy.hdf5'),
                 builder=deploy_builder)
예제 #5
0
파일: do_rfnet.py 프로젝트: LeeJZh/MVConv
        raise ValueError('...')

    log_dir = 'rfnet_exps/{}_{}_train/scale_{}_alpha_{}_epochs_{}/'.format(network_type, block_type, scale, alpha, epochs)

    weight_decay_bias = weight_decay_strength
    config = get_baseconfig_by_epoch(network_type=network_type,
                                     dataset_name=get_dataset_name_by_model_name(network_type), dataset_subset='train',
                                     global_batch_size=batch_size, num_node=1,
                                     weight_decay=weight_decay_strength, optimizer_type='sgd', momentum=0.9,
                                     max_epochs=lrs.max_epochs, base_lr=lrs.base_lr, lr_epoch_boundaries=lrs.lr_epoch_boundaries, cosine_minimum=lrs.cosine_minimum,
                                     lr_decay_factor=lrs.lr_decay_factor,
                                     warmup_epochs=0, warmup_method='linear', warmup_factor=0,
                                     ckpt_iter_period=40000, tb_iter_period=100, output_dir=log_dir,
                                     tb_dir=log_dir, save_weights=None, val_epoch_period=5, linear_final_lr=lrs.linear_final_lr,
                                     weight_decay_bias=weight_decay_bias, deps=None)

    if block_type == 'rfb':
        builder = RFNetBuilder(base_config=config, deploy=False, gamma_init=gamma_init, alpha=alpha, scale=scale)
    else:
        builder = ConvBuilder(base_config=config)

    target_weights = os.path.join(log_dir, 'finish.hdf5')
    if not os.path.exists(target_weights):
        train_main(local_rank=start_arg.local_rank, cfg=config, convbuilder=builder,
               show_variables=True, auto_continue=auto_continue)#, tensorflow_style_init=True)

    if block_type == 'rfb':
        convert_rfnet_weights(target_weights, target_weights.replace('.hdf5', '_deploy.hdf5'), eps=1e-5)
        deploy_builder = RFNetBuilder(base_config=config, deploy=True)
        general_test(network_type=network_type, weights=target_weights.replace('.hdf5', '_deploy.hdf5'),
                 builder=deploy_builder)