Example #1
def train_auto_compressor(model, args, optimizer_data, validate_fn,
                          save_checkpoint_fn, train_fn):
    dataset = args.dataset
    arch = args.arch
    num_ft_epochs = args.amc_ft_epochs
    action_range = args.amc_action_range

    config_verbose(False)

    # Read the experiment configuration
    amc_cfg_fname = args.amc_cfg_file
    if not amc_cfg_fname:
        raise ValueError(
            "You must specify a valid configuration file path using --amc-cfg")

    with open(amc_cfg_fname, 'r') as cfg_file:
        compression_cfg = distiller.utils.yaml_ordered_load(cfg_file)
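    # The YAML file is expected to provide a "network" section: a dict of the
    # prunable modules, indexed by architecture name.  A hypothetical sketch of
    # such a file (the exact schema depends on the experiment configuration):
    #
    #   network:
    #     resnet20_cifar:
    #       - module.layer1.0.conv1
    #       - module.layer1.0.conv2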

    if not args.amc_rllib:
        raise ValueError("You must set --amc-rllib to a valid value")

    #rl_lib = compression_cfg["rl_lib"]["name"]
    #msglogger.info("Executing AMC: RL agent - %s   RL library - %s", args.amc_agent_algo, rl_lib)

    # Create a dictionary of parameters that Coach will hand over to the
    # DistillerWrapperEnvironment once it creates it.
    services = distiller.utils.MutableNamedTuple({
        'validate_fn': validate_fn,
        'save_checkpoint_fn': save_checkpoint_fn,
        'train_fn': train_fn
    })

    app_args = distiller.utils.MutableNamedTuple({
        'dataset': dataset,
        'arch': arch,
        'optimizer_data': optimizer_data,
        'seed': args.seed
    })

    ddpg_cfg = distiller.utils.MutableNamedTuple({
        'heatup_noise': 0.5,
        'initial_training_noise': 0.5,
        'training_noise_decay': 0.95,
        'num_heatup_episodes': args.amc_heatup_episodes,
        'num_training_episodes': args.amc_training_episodes,
        'actor_lr': 1e-4,
        'critic_lr': 1e-3
    })

    amc_cfg = distiller.utils.MutableNamedTuple({
        'modules_dict': compression_cfg["network"],  # dict of modules, indexed by arch name
        'save_chkpts': args.amc_save_chkpts,
        'protocol': args.amc_protocol,
        'agent_algo': args.amc_agent_algo,
        'num_ft_epochs': num_ft_epochs,
        'action_range': action_range,
        'reward_frequency': args.amc_reward_frequency,
        'ft_frequency': args.amc_ft_frequency,
        'pruning_pattern': args.amc_prune_pattern,
        'pruning_method': args.amc_prune_method,
        'group_size': args.amc_group_size,
        'n_points_per_fm': args.amc_fm_reconstruction_n_pts,
        'ddpg_cfg': ddpg_cfg,
        'ranking_noise': args.amc_ranking_noise
    })

    #net_wrapper = NetworkWrapper(model, app_args, services)
    #return sample_networks(net_wrapper, services)

    amc_cfg.target_density = args.amc_target_density
    # The chosen protocol determines both the reward function and the
    # action-constraint function used by the agent.
    amc_cfg.reward_fn, amc_cfg.action_constrain_fn = reward_factory(args.amc_protocol)

    def create_environment():
        env = DistillerWrapperEnvironment(model, app_args, amc_cfg, services)
        #env.amc_cfg.ddpg_cfg.replay_buffer_size = int(1.5 * amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode)
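        # Size the DDPG replay buffer to hold exactly the experience gathered
        # during heatup (heatup episodes * steps per episode).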
        env.amc_cfg.ddpg_cfg.replay_buffer_size = amc_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode
        return env

    env1 = create_environment()

    if args.amc_rllib == "spinningup":
        from rl_libs.spinningup import spinningup_if
        rl = spinningup_if.RlLibInterface()
        env2 = create_environment()
        steps_per_episode = env1.steps_per_episode
        rl.solve(env1, env2)
    elif args.amc_rllib == "hanlab":
        from rl_libs.hanlab import hanlab_if
        rl = hanlab_if.RlLibInterface()
        args.observation_len = len(Observation._fields)
        rl.solve(env1, args)
    elif args.amc_rllib == "coach":
        from rl_libs.coach import coach_if
        rl = coach_if.RlLibInterface()
        env_cfg = {
            'model': model,
            'app_args': app_args,
            'amc_cfg': amc_cfg,
            'services': services
        }
        steps_per_episode = env1.steps_per_episode
        rl.solve(**env_cfg, steps_per_episode=steps_per_episode)
    elif args.amc_rllib == "random":
        from rl_libs.random import random_if
        rl = random_if.RlLibInterface()
        return rl.solve(env1)
    else:
        raise ValueError("unsupported rl library: ", args.amc_rllib)
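
# A minimal usage sketch (hypothetical: the callback signatures and the set of
# command-line arguments are assumptions, inferred from how they are consumed
# above rather than taken from the original code base):
#
#   model = create_model(...)        # the network to compress
#   args = parse_args()              # must supply args.dataset, args.arch, args.seed
#                                    # and the various args.amc_* flags used above
#   train_auto_compressor(model, args, optimizer_data,
#                         validate_fn=validate_fn,
#                         save_checkpoint_fn=save_checkpoint_fn,
#                         train_fn=train_fn)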
Example #2
def train_auto_compressor(model, args, optimizer_data, validate_fn,
                          save_checkpoint_fn, train_fn):
    dataset = args.dataset
    arch = args.arch
    num_ft_epochs = args.cacp_ft_epochs
    action_range = args.cacp_action_range
    conditional = args.conditional

    config_verbose(False)

    # Read the experiment configuration
    cacp_cfg_fname = args.cacp_cfg_file
    if not cacp_cfg_fname:
        raise ValueError(
            "You must specify a valid configuration file path using --cacp-cfg"
        )

    with open(cacp_cfg_fname, 'r') as cfg_file:
        compression_cfg = utils.yaml_ordered_load(cfg_file)

    if not args.cacp_rllib:
        raise ValueError("You must set --cacp-rllib to a valid value")

    #rl_lib = compression_cfg["rl_lib"]["name"]
    #msglogger.info("Executing cacp: RL agent - %s   RL library - %s", args.cacp_agent_algo, rl_lib)

    # Create a dictionary of parameters that will be handed over to the
    # WrapperEnvironment once it is created.
    services = utils.MutableNamedTuple({
        'validate_fn': validate_fn,
        'save_checkpoint_fn': save_checkpoint_fn,
        'train_fn': train_fn
    })

    app_args = utils.MutableNamedTuple({
        'dataset': dataset,
        'arch': arch,
        'optimizer_data': optimizer_data,
        'seed': args.seed
    })

    ddpg_cfg = utils.MutableNamedTuple({
        'heatup_noise': 0.5,
        'initial_training_noise': 0.5,
        'training_noise_decay': 0.95,
        'num_heatup_episodes': args.cacp_heatup_episodes,
        'num_training_episodes': args.cacp_training_episodes,
        'actor_lr': 1e-4,
        'critic_lr': 1e-3,
        "conditional": conditional
    })

    cacp_cfg = utils.MutableNamedTuple({
        'modules_dict': compression_cfg["network"],  # dict of modules, indexed by arch name
        'save_chkpts': args.cacp_save_chkpts,
        'protocol': args.cacp_protocol,
        'agent_algo': args.cacp_agent_algo,
        'num_ft_epochs': num_ft_epochs,
        'action_range': action_range,
        'reward_frequency': args.cacp_reward_frequency,
        'ft_frequency': args.cacp_ft_frequency,
        'pruning_pattern': args.cacp_prune_pattern,
        'pruning_method': args.cacp_prune_method,
        'group_size': args.cacp_group_size,
        'n_points_per_fm': args.cacp_fm_reconstruction_n_pts,
        'ddpg_cfg': ddpg_cfg,
        'ranking_noise': args.cacp_ranking_noise,
        'conditional': conditional,
        'support_ratio': args.support_ratio
    })

    cacp_cfg.reward_fn, cacp_cfg.action_constrain_fn = reward_factory(
        args.cacp_protocol)

    def create_environment():
        env = CACPWrapperEnvironment(model, app_args, cacp_cfg, services)
        env.cacp_cfg.ddpg_cfg.replay_buffer_size = cacp_cfg.ddpg_cfg.num_heatup_episodes * env.steps_per_episode
        return env

    env1 = create_environment()

    from lib.hanlab import hanlab_if
    rl = hanlab_if.RlLibInterface()
    args.observation_len = len(Observation._fields)
    rl.solve(env1, args)
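
    # Note: unlike Example #1, this variant always drives the search with the
    # hanlab DDPG interface (the --cacp-rllib value is validated above but not
    # used to select a library); the observation length passed to the agent is
    # taken from the fields of the Observation namedtuple.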