Example #1
def train_dqn(args):
    if not os.path.exists('train_log'):
        os.mkdir('train_log')
    writer = TensorBoard(f'train_log/{args.run_name}')

    dqn_config = dqn_config_default.copy()

    dqn_config.update({
        "batch_size": 1024,
        "min_replay_history": 10000,
        "training_steps": 15000,
        "lr": 0.0003,
        # ....
    })

    all_policies['dqn_store_consumer'] = ConsumerC51TorchPolicy(
        env.observation_space, env.action_space_consumer,
        BaselinePolicy.get_config_from_env(env), dqn_config)

    obss = env.reset()
    agent_ids = obss.keys()
    policies = {}
    policies_to_train = []

    for agent_id in agent_ids:
        policy, if_train = policy_map_fn(agent_id)
        policies[agent_id] = all_policies[policy]
        if if_train:
            policies_to_train.append(agent_id)

    dqn_trainer = Trainer(env, policies, policies_to_train, dqn_config)
    max_mean_reward = -1000
    debug = False
    for i in range(args.num_iterations):
        result = dqn_trainer.train(i)
        now_mean_reward = print_result(result, writer, i)

        if now_mean_reward > max_mean_reward or debug:
            max_mean_reward = max(max_mean_reward, now_mean_reward)
            dqn_trainer.save(args.run_name, i)
            # print("checkpoint saved at", checkpoint)
            visualization(env, policies, i, args.run_name)
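# A minimal entry-point sketch for the example above (hypothetical, not part of the
# original snippet). It wires up only the two arguments train_dqn actually reads here,
# args.run_name and args.num_iterations; the real script likely defines more flags.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--run_name', type=str, default='dqn_run')
    parser.add_argument('--num_iterations', type=int, default=100)
    train_dqn(parser.parse_args())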
Example #2
        if Utils.is_producer_agent(agent_id):
            return ProducerBaselinePolicy(
                env.observation_space, env.action_space_producer,
                BaselinePolicy.get_config_from_env(env))
        elif isinstance(_facility, (SKUStoreUnit, SKUWarehouseUnit)):
            policy = ConsumerBaseStockPolicy(
                env.observation_space, env.action_space_consumer,
                BaselinePolicy.get_config_from_env(env))
            policy.base_stock = sku_base_stocks[Utils.agentid_to_fid(agent_id)]
            return policy
        else:
            return ConsumerBaselinePolicy(
                env.observation_space, env.action_space_consumer,
                BaselinePolicy.get_config_from_env(env))

    policies = {}
    for agent_id in env.agent_ids():
        policies[agent_id] = load_policy(agent_id)

    # Simulation loop
    if args.visualization:
        visualization(env, policies, 1, policy_mode)
    else:
        tracker = SimulationTracker(episod_duration, 1, env.agent_ids())
        if args.pt:
            loc_path = f"{os.environ['PT_OUTPUT_DIR']}/{policy_mode}/"
        else:
            loc_path = f'output/{policy_mode}/'
        tracker.run_and_render(loc_path)
def train_ppo(args):
    env_config_for_rendering.update({'init': args.init})
    ext_conf = ppo.DEFAULT_CONFIG.copy()
    ext_conf.update({
        "env": InventoryManageEnv,
        "framework": "torch",
        "num_workers": 4,
        "vf_share_layers": True,
        "vf_loss_coeff": 1.00,
        # estimated max value of the value function, used for normalization
        "vf_clip_param": 100.0,
        "clip_param": 0.2,
        "use_critic": True,
        "use_gae": True,
        "lambda": 1.0,
        "gamma": 0.99,
        'env_config': env_config_for_rendering.copy(),
        # Number of steps after which the episode is forced to terminate. Defaults
        # to `env.spec.max_episode_steps` (if present) for Gym envs.
        "horizon": args.episod,
        # Calculate rewards but don't reset the environment when the horizon is
        # hit. This allows value estimation and RNN state to span across logical
        # episodes denoted by horizon. This only has an effect if horizon != inf.
        "soft_horizon": False,
        # Minimum env steps to optimize for per train call. This value does
        # not affect learning, only the length of train iterations.
        'timesteps_per_iteration': 1000,
        'batch_mode': 'complete_episodes',
        # Size of batches collected from each worker
        "rollout_fragment_length": args.rollout_fragment_length,
        # Number of timesteps collected for each SGD round. This defines the size
        # of each SGD epoch.
        "train_batch_size": args.rollout_fragment_length * args.batch_size,
        # Whether to shuffle sequences in the batch when training (recommended).
        "shuffle_sequences": True,
        # Total SGD batch size across all devices for SGD. This defines the
        # minibatch size within each epoch.
        "sgd_minibatch_size":
        args.rollout_fragment_length * args.min_batch_size,
        # Number of SGD iterations in each outer loop (i.e., number of epochs to
        # execute per train batch).
        "num_sgd_iter": 50,
        "lr": 1e-4,
        "_fake_gpus": True,
        "num_gpus": 0,
        "explore": True,
        "exploration_config": {
            "type": StochasticSampling,
            "random_timesteps":
            0,  # args.rollout_fragment_length*args.batch_size*args.stop_iters // 2,
        },
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_map_fn,
            "policies_to_train": ['ppo_store_consumer']
        }
    })

    print(
        f"Environment: action space producer {env.action_space_producer}, "
        f"action space consumer {env.action_space_consumer}, "
        f"observation space {env.observation_space}",
        flush=True)

    if args.is_pretrained:
        ext_conf.update({
            'num_workers': 0  #, 'episod_duration':args.episod
        })
        ppo_trainer = ppo.PPOTrainer(env=InventoryManageEnv, config=ext_conf)
        env.env_config.update({'episod_duration': args.episod})
        ppo_trainer.restore(args.premodel)
        visualization(InventoryManageEnv(env_config.copy()),
                      get_policy(env, ppo_trainer), 1, args.run_name)
        return ppo_trainer

    # ppo_trainer.restore('/root/ray_results/PPO_InventoryManageEnv_2020-11-02_18-25-55cle_glgg/checkpoint_20/checkpoint-20')

    # stop = {
    #     "training_iteration": args.stop_iters,
    #     "timesteps_total": args.stop_timesteps,
    #     "episode_reward_min": args.stop_reward,
    # }

    # analysis = tune.run(args.run, config=ext_conf, stop=stop, mode='max', checkpoint_freq=1, verbose=1)
    # checkpoints = analysis.get_trial_checkpoints_paths(
    #                         trial=analysis.get_best_trial("episode_reward_max"),
    #                         metric="episode_reward_max")
    # ppo_trainer.restore(checkpoints[0][0])

    ext_conf['env_config'].update({
        'gamma': ext_conf['gamma'],
        'training': True,
        'policies': None
    })

    ppo_trainer = ppo.PPOTrainer(env=InventoryManageEnv, config=ext_conf)
    max_mean_reward = -100

    ppo_trainer.workers.foreach_worker(lambda ev: ev.foreach_env(
        lambda env: env.set_policies(get_policy(env, ev))))

    for i in range(args.stop_iters):
        print("== Iteration", i, "==", flush=True)

        ppo_trainer.workers.foreach_worker(lambda ev: ev.foreach_env(
            lambda env: env.set_iteration(i, args.stop_iters)))
        result = ppo_trainer.train()
        print_training_results(result)
        now_mean_reward = result['policy_reward_mean']['ppo_store_consumer']

        if ((i + 1) % args.visualization_frequency == 0
                or now_mean_reward > max_mean_reward):
            max_mean_reward = max(max_mean_reward, now_mean_reward)
            checkpoint = ppo_trainer.save()
            print("checkpoint saved at", checkpoint, flush=True)
            visualization(InventoryManageEnv(env_config.copy()),
                          get_policy(env, ppo_trainer), i, args.run_name)
            # exit(0)

    return ppo_trainer
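# A hedged sketch of the kind of policy_mapping_fn the "multiagent" block above expects
# (hypothetical; the repo's real policy_map_fn is defined elsewhere). It assumes producer
# agents map to a non-trainable baseline policy id, while store consumers map to the
# trainable 'ppo_store_consumer' policy listed in policies_to_train.
def example_policy_map_fn(agent_id):
    if Utils.is_producer_agent(agent_id):
        return 'baseline_producer'  # hypothetical baseline policy id
    return 'ppo_store_consumer'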
Example #4
    else:
        policy_mode = f"base_stock_"+("static" if is_static else "dynamic")+\
            f"_gap{args.buyin_gap}_updt{args.update_interval}_start{args.start_step}"
    
    # ConsumerBaseStockPolicy static fields

    def policy_oracle_setup():
        ConsumerBaseStockPolicy.time_hrz_len = env_config['evaluation_len']
        ConsumerBaseStockPolicy.oracle = True
        ConsumerBaseStockPolicy.has_order_cost = False

    ConsumerBaseStockPolicy.buyin_gap = args.buyin_gap
    
    # setup
    if is_oracle:
        policy_oracle_setup()
    else:
        ConsumerBaseStockPolicy.time_hrz_len = env_config['sale_hist_len']
    
    # always set these fields
    ConsumerBaseStockPolicy.update_interval = args.update_interval
    ConsumerBaseStockPolicy.start_step = args.start_step

    policies = get_policies(is_static)
    # ray.init()

    # Simulation loop
    vis_env = InventoryManageEnv(env_config.copy())

    visualization(vis_env, policies, 0, policy_mode, basestock=True)
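
# Illustrative base-stock ordering rule, shown only to clarify what ConsumerBaseStockPolicy
# is parameterized by above; this is the textbook formulation, not the repo's exact code.
def base_stock_order_qty(base_stock_level, on_hand, in_transit):
    inventory_position = on_hand + in_transit
    # Order up to the base-stock level whenever the inventory position falls below it.
    return max(0, base_stock_level - inventory_position)
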
def train_dqn(args):
    if not os.path.exists('train_log'):
        os.mkdir('train_log')
    writer = TensorBoard(f'train_log/{args.run_name}')

    dqn_config = dqn_config_default.copy()

    dqn_config.update({
        "batch_size": 1024,
        "min_replay_history": 10000,
        "training_steps": 1500,
        "lr": 0.0003,
        "target_update_period": 1000,
        "gamma": args.gamma,
        "use_unc_part": args.use_unc_part,
        "pretrain": args.pretrain,
        "fixed_uncontrollable_param": args.fixed_uncontrollable_param,
        "use_cnn_state": args.use_cnn_state,
        "pretrain_epoch": args.pretrain_epoch,
        "embeddingmerge": args.embeddingmerge,  # 'cat' or 'dot'
        "activation_func": args.activation_func,  # 'relu', 'sigmoid', 'tanh'
        "use_bn": args.use_bn,
        "weight_decay": args.weight_decay,

        # data augmentation setting
        "train_augmentation": args.train_augmentation,
        "noise_scale": args.noise_scale,
        "sparse_scale": args.sparse_scale,
        # ....
    })

    print(dqn_config)

    if dqn_config["use_cnn_state"]:
        global_config.use_cnn_state = True

    if args.training_length != 4 * 365:
        global_config.training_length = args.training_length

    if args.oracle:
        global_config.oracle = True
        if dqn_config['pretrain']:
            raise Exception("dqn oracle does not support pretrain")
        if dqn_config['train_augmentation'] != 'none':
            raise Exception("dqn oracle should not use augmentation")

    # if args.env_demand_noise != 'none':
    #     env.env_config['init'] = 'rst'

    # all_policies['dqn_store_consumer'] = ConsumerDQNTorchPolicy(env.observation_space, env.action_space_consumer, BaselinePolicy.get_config_from_env(env), dqn_config)
    all_policies['dqn_store_consumer'] = ConsumerRepresentationLearningDQNTorchPolicy(
        env.observation_space, env.action_space_consumer,
        BaselinePolicy.get_config_from_env(env), dqn_config)

    obss = env.reset()
    agent_ids = obss.keys()
    policies = {}
    policies_to_train = []

    for agent_id in agent_ids:
        policy, if_train = policy_map_fn(agent_id)
        policies[agent_id] = all_policies[policy]
        if if_train:
            policies_to_train.append(agent_id)

    dqn_trainer = Trainer(env, policies, policies_to_train, dqn_config)
    max_mean_reward = -10_000_000_000
    if dqn_config['pretrain']:
        # global_config.random_noise = 'none'

        print('start load data ...')
        dqn_trainer.load_data(eval=False)
        dqn_trainer.load_data(eval=True)
        # now_mean_reward = print_result(result, writer, -1)
        print('load success!')
        print('start pre-training ...')
        all_policies['dqn_store_consumer'].pre_train(writer)
        print('pre-training success!')

    debug = False

    for i in range(args.num_iterations):

        # if args.env_demand_noise != 'none':
        #     global_config.random_noise = args.env_demand_noise
        result = dqn_trainer.train(i)
        print_result(result, writer, i)
        # global_config.random_noise = 'none'

        eval_on_trainingset_result = dqn_trainer.eval(i,
                                                      eval_on_trainingset=True)
        print_eval_on_trainingset_result(eval_on_trainingset_result, writer, i)

        eval_result = dqn_trainer.eval(i)
        eval_mean_reward = print_eval_result(eval_result, writer, i)

        if eval_mean_reward > max_mean_reward or debug:
            max_mean_reward = max(max_mean_reward, eval_mean_reward)
            dqn_trainer.save(args.run_name, i)
            # print("checkpoint saved at", checkpoint)
            visualization(env, policies, i, args.run_name)