from scheduler.inventory_random_policy import ProducerBaselinePolicy, BaselinePolicy
# from scheduler.inventory_random_policy import ConsumerBaselinePolicy
from scheduler.inventory_minmax_policy import ConsumerMinMaxPolicy as ConsumerBaselinePolicy
# from scheduler.inventory_eoq_policy import ConsumerEOQPolicy as ConsumerBaselinePolicy
from utility.tools import SimulationTracker
# from scheduler.inventory_tf_model import FacilityNet
from scheduler.inventory_torch_model import SKUStoreBatchNormModel as SKUStoreDNN
from scheduler.inventory_torch_model import SKUWarehouseBatchNormModel as SKUWarehouseDNN
from config.inventory_config import env_config
from explorer.stochastic_sampling import StochasticSampling
# NOTE: InventoryManageEnv (instantiated below) is imported elsewhere in the
# original script; the module path used here is an assumption.
from env.inventory_env import InventoryManageEnv


# Configuration ===============================================================================


env_config_for_rendering = env_config.copy()
episod_duration = env_config_for_rendering['episod_duration']
env = InventoryManageEnv(env_config_for_rendering)


ppo_policy_config_producer = {
    "model": {
        "fcnet_hiddens": [128, 128],
        "custom_model": "facility_net"
    }
}

ppo_policy_config_store_consumer = {
    "model": {
        "fcnet_hiddens": [16, 16],
        "custom_model": "sku_store_net",
Example #2
    else:
        policy_mode = f"base_stock_"+("static" if is_static else "dynamic")+\
            f"_gap{args.buyin_gap}_updt{args.update_interval}_start{args.start_step}"
    
    # ConsumerBaseStockPolicy static fields

    def policy_oracle_setup():
        ConsumerBaseStockPolicy.time_hrz_len = env_config['evaluation_len']
        ConsumerBaseStockPolicy.oracle = True
        ConsumerBaseStockPolicy.has_order_cost = False

    ConsumerBaseStockPolicy.buyin_gap = args.buyin_gap
    
    # setup
    if is_oracle:
        policy_oracle_setup()
    else:
        ConsumerBaseStockPolicy.time_hrz_len = env_config['sale_hist_len']
    
    # always set these fields
    ConsumerBaseStockPolicy.update_interval = args.update_interval
    ConsumerBaseStockPolicy.start_step = args.start_step

    policies = get_policies(is_static)
    # ray.init()

    # Simulation loop
    vis_env = InventoryManageEnv(env_config.copy())

    visualization(vis_env, policies, 0, policy_mode, basestock=True)
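# Example #2 reads its settings from a command-line namespace `args`. A hedged
# sketch of an argument parser that would supply those attributes (flag names
# are inferred from the accesses above; the defaults are illustrative
# assumptions, not the original script's values):
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--buyin_gap", type=int, default=0)
parser.add_argument("--update_interval", type=int, default=7)
parser.add_argument("--start_step", type=int, default=0)
parser.add_argument("--episod", type=int, default=60)
parser.add_argument("--stop_order", type=float, default=1.0)
args = parser.parse_args()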
Example #3
def get_policies(is_static):
    env_config_for_rendering = env_config.copy()
    
    episod_duration = args.episod
    env_config_for_rendering['episod_duration'] = episod_duration
    env = InventoryManageEnv(env_config_for_rendering)
    obss = env.reset()

    ConsumerBaseStockPolicy.env_config = env_config_for_rendering
    ConsumerBaseStockPolicy.facilities = env.world.facilities
    ConsumerBaseStockPolicy.stop_order_factor = args.stop_order

    starter_step = env.env_config['episod_duration'] + env.env_config['tail_timesteps']
    env.set_retailer_step(starter_step - episod_duration)
    env.set_iteration(1, 1)
    print(f"Environment: Producer action space {env.action_space_producer}, Consumer action space {env.action_space_consumer}, Observation space {env.observation_space}")

    if is_static: # base-stock levels are predefined before eval
        def load_base_policy(agent_id):
            if Utils.is_producer_agent(agent_id):
                return ProducerBaselinePolicy(env.observation_space, env.action_space_producer, BaselinePolicy.get_config_from_env(env))
            else:
                return ConsumerBaselinePolicy(env.observation_space, env.action_space_consumer, BaselinePolicy.get_config_from_env(env))

        base_policies = {}
        for agent_id in env.agent_ids():
            base_policies[agent_id] = load_base_policy(agent_id)

        _, infos = env.state_calculator.world_to_state(env.world)
        rnn_states = {}
        rewards = {}
        for agent_id in obss.keys():
            rnn_states[agent_id] = base_policies[agent_id].get_initial_state()
            rewards[agent_id] = 0

        # warm up with the baseline policies so that base-stock levels can be
        # derived from the collected history afterwards
        for epoch in tqdm(range(args.episod)):
            action_dict = {}
            for agent_id, obs in obss.items():
                policy = base_policies[agent_id]
                action, _, _ = policy.compute_single_action(
                    obs, state=rnn_states[agent_id], info=infos[agent_id], explore=True)
                action_dict[agent_id] = action
            obss, rewards, _, infos = env.step(action_dict)
        ConsumerBaseStockPolicy.update_base_stocks()

    def load_policy(agent_id):
        _facility = env.world.facilities[Utils.agentid_to_fid(agent_id)] 
        if Utils.is_producer_agent(agent_id):
            return ProducerBaselinePolicy(env.observation_space, env.action_space_producer, BaselinePolicy.get_config_from_env(env))
        # elif isinstance(_facility, SKUStoreUnit) or isinstance(_facility, SKUWarehouseUnit):
        elif isinstance(_facility, SKUStoreUnit):
            policy = ConsumerBaseStockPolicy(env.observation_space, env.action_space_consumer,
                        BaselinePolicy.get_config_from_env(env), is_static)
            return policy
        else:
            return ConsumerBaselinePolicy(env.observation_space, env.action_space_consumer, BaselinePolicy.get_config_from_env(env))

    policies = {}
    for agent_id in env.agent_ids():
        policies[agent_id] = load_policy(agent_id)
    return policies
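
# train_ppo() below references module-level `policies` and `policy_map_fn`,
# which are not part of this excerpt. A minimal sketch of what RLlib's
# "multiagent" block expects, assuming store consumers are trained by the
# 'ppo_store_consumer' policy and all other agents keep baseline policies
# (the original script's mapping may be more fine-grained):
def policy_map_fn(agent_id):
    # Producers keep the baseline policy; everything else is treated as a
    # trainable store consumer in this simplified sketch.
    if Utils.is_producer_agent(agent_id):
        return 'baseline_producer'
    return 'ppo_store_consumer'

policies = {
    # (policy_cls or None, obs_space, act_space, per-policy config)
    'baseline_producer': (ProducerBaselinePolicy, env.observation_space,
                          env.action_space_producer,
                          BaselinePolicy.get_config_from_env(env)),
    'ppo_store_consumer': (None, env.observation_space, env.action_space_consumer,
                           ppo_policy_config_store_consumer),
}
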
def train_ppo(args):
    env_config_for_rendering.update({'init': args.init})
    ext_conf = ppo.DEFAULT_CONFIG.copy()
    ext_conf.update({
        "env": InventoryManageEnv,
        "framework": "torch",
        "num_workers": 4,
        "vf_share_layers": True,
        "vf_loss_coeff": 1.00,
        # estimated max value of the value function, used for normalization
        "vf_clip_param": 100.0,
        "clip_param": 0.2,
        "use_critic": True,
        "use_gae": True,
        "lambda": 1.0,
        "gamma": 0.99,
        'env_config': env_config_for_rendering.copy(),
        # Number of steps after which the episode is forced to terminate. Defaults
        # to `env.spec.max_episode_steps` (if present) for Gym envs.
        "horizon": args.episod,
        # Calculate rewards but don't reset the environment when the horizon is
        # hit. This allows value estimation and RNN state to span across logical
        # episodes denoted by horizon. This only has an effect if horizon != inf.
        "soft_horizon": False,
        # Minimum env steps to optimize for per train call. This value does
        # not affect learning, only the length of train iterations.
        'timesteps_per_iteration': 1000,
        'batch_mode': 'complete_episodes',
        # Size of batches collected from each worker
        "rollout_fragment_length": args.rollout_fragment_length,
        # Number of timesteps collected for each SGD round. This defines the size
        # of each SGD epoch.
        "train_batch_size": args.rollout_fragment_length * args.batch_size,
        # Whether to shuffle sequences in the batch when training (recommended).
        "shuffle_sequences": True,
        # Total SGD batch size across all devices for SGD. This defines the
        # minibatch size within each epoch.
        "sgd_minibatch_size":
        args.rollout_fragment_length * args.min_batch_size,
        # Number of SGD iterations in each outer loop (i.e., number of epochs to
        # execute per train batch).
        "num_sgd_iter": 50,
        "lr": 1e-4,
        "_fake_gpus": True,
        "num_gpus": 0,
        "explore": True,
        "exploration_config": {
            "type": StochasticSampling,
            "random_timesteps":
            0,  # args.rollout_fragment_length*args.batch_size*args.stop_iters // 2,
        },
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_map_fn,
            "policies_to_train": ['ppo_store_consumer']
        }
    })

    print(
        f"Environment: action space producer {env.action_space_producer}, action space consumer {env.action_space_consumer}, observation space {env.observation_space}",
        flush=True)

    if args.is_pretrained:
        ext_conf.update({
            'num_workers': 0  #, 'episod_duration':args.episod
        })
        ppo_trainer = ppo.PPOTrainer(env=InventoryManageEnv, config=ext_conf)
        env.env_config.update({'episod_duration': args.episod})
        ppo_trainer.restore(args.premodel)
        visualization(InventoryManageEnv(env_config.copy()),
                      get_policy(env, ppo_trainer), 1, args.run_name)
        return ppo_trainer

    # ppo_trainer.restore('/root/ray_results/PPO_InventoryManageEnv_2020-11-02_18-25-55cle_glgg/checkpoint_20/checkpoint-20')

    # stop = {
    #     "training_iteration": args.stop_iters,
    #     "timesteps_total": args.stop_timesteps,
    #     "episode_reward_min": args.stop_reward,
    # }

    # analysis = tune.run(args.run, config=ext_conf, stop=stop, mode='max', checkpoint_freq=1, verbose=1)
    # checkpoints = analysis.get_trial_checkpoints_paths(
    #                         trial=analysis.get_best_trial("episode_reward_max"),
    #                         metric="episode_reward_max")
    # ppo_trainer.restore(checkpoints[0][0])

    ext_conf['env_config'].update({
        'gamma': ext_conf['gamma'],
        'training': True,
        'policies': None
    })

    ppo_trainer = ppo.PPOTrainer(env=InventoryManageEnv, config=ext_conf)
    max_mean_reward = -100

    ppo_trainer.workers.foreach_worker(lambda ev: ev.foreach_env(
        lambda env: env.set_policies(get_policy(env, ev))))

    for i in range(args.stop_iters):
        print("== Iteration", i, "==", flush=True)

        ppo_trainer.workers.foreach_worker(lambda ev: ev.foreach_env(
            lambda env: env.set_iteration(i, args.stop_iters)))
        result = ppo_trainer.train()
        print_training_results(result)
        now_mean_reward = result['policy_reward_mean']['ppo_store_consumer']

        if (i + 1) % args.visualization_frequency == 0 or now_mean_reward > max_mean_reward:
            max_mean_reward = max(max_mean_reward, now_mean_reward)
            checkpoint = ppo_trainer.save()
            print("checkpoint saved at", checkpoint, flush=True)
            visualization(InventoryManageEnv(env_config.copy()),
                          get_policy(env, ppo_trainer), i, args.run_name)
            # exit(0)

    return ppo_trainer
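
# A hedged usage sketch for train_ppo(); the flag names are inferred from the
# attribute accesses inside the function, and the defaults are illustrative
# assumptions rather than the original script's values:
if __name__ == "__main__":
    import argparse
    import ray

    parser = argparse.ArgumentParser()
    parser.add_argument("--init", default=None)
    parser.add_argument("--episod", type=int, default=60)
    parser.add_argument("--rollout_fragment_length", type=int, default=60)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--min_batch_size", type=int, default=4)
    parser.add_argument("--stop_iters", type=int, default=100)
    parser.add_argument("--visualization_frequency", type=int, default=10)
    parser.add_argument("--run_name", type=str, default="ppo_store_consumer")
    parser.add_argument("--is_pretrained", action="store_true")
    parser.add_argument("--premodel", type=str, default=None)
    args = parser.parse_args()

    ray.init()
    trainer = train_ppo(args)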