Example #1
    def __init__(self, observation_space, action_space, config, dqn_config):
        BaselinePolicy.__init__(self, observation_space, action_space, config)

        self.dqn_config = dqn_config
        self.epsilon = 1
        self.Vmin = -1
        self.Vmax = 1
        self.atoms = 51
        self.device = torch.device('cpu')

        self.num_states = int(np.product(observation_space.shape))
        self.num_actions = int(action_space.n)
        print(
            f'dqn state space:{self.num_states}, action space:{self.num_actions}'
        )

        self.eval_net = c51_net(self.num_states, self.num_actions, self.atoms)
        self.target_net = c51_net(self.num_states, self.num_actions,
                                  self.atoms)

        # start the target network with the same weights as the online (eval) network
        self.target_net.load_state_dict(self.eval_net.state_dict())

        self.learn_step_counter = 0
        self.memory = replay_memory(dqn_config['replay_capacity'])

        self.optimizer = torch.optim.Adam(self.eval_net.parameters(),
                                          lr=dqn_config['lr'])
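        # C51: categorical support of self.atoms evenly spaced values in [Vmin, Vmax]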
        self.support = torch.linspace(self.Vmin, self.Vmax, self.atoms)
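Example #2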
    def __init__(self, observation_space, action_space, config, dqn_config):
        BaselinePolicy.__init__(self, observation_space, action_space, config)

        self.dqn_config = dqn_config
        self.epsilon = 1
        mixed_dqn_config = mixed_dqn_net_config_example.copy()
        mixed_dqn_config.update({
            "controllable_state_num" : int(np.product(observation_space.shape)),
            "action_num": int(action_space.n),
            "uncontrollable_state_num": 31,
            "uncontrollable_pred_num": 3,
        })
        self.num_states = int(np.product(observation_space.shape))
        self.num_actions = int(action_space.n)
        print(f'dqn state space:{self.num_states}, action space:{self.num_actions}')
        self.use_unc_part = dqn_config['use_unc_part']

        self.eval_net = mixed_dqn_net(mixed_dqn_config)
        self.target_net = mixed_dqn_net(mixed_dqn_config)

        self.target_net.load_state_dict(self.eval_net.state_dict())

        self.learn_step_counter = 0
        self.memory = replay_memory(dqn_config['replay_capacity'])

        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=dqn_config['lr'])
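        # smooth L1 (Huber) loss, the usual choice for the DQN TD error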
        self.loss_func = nn.SmoothL1Loss()

        self.rand_action = 0
        self.greedy_action = 0
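Example #3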
    def __init__(self, observation_space, action_space, config, dqn_config):
        BaselinePolicy.__init__(self, observation_space, action_space, config)

        self.dqn_config = dqn_config
        self.epsilon = 1

        self.num_states = int(np.product(observation_space.shape))
        self.num_actions = int(action_space.n)
        print(
            f'dqn state space:{self.num_states}, action space:{self.num_actions}'
        )
        self.pred_head = dqn_config['pred']

        self.eval_net = dqn_net(self.num_states, self.num_actions,
                                self.pred_head)
        self.target_net = dqn_net(self.num_states, self.num_actions,
                                  self.pred_head)

        self.target_net.load_state_dict(self.eval_net.state_dict())

        self.learn_step_counter = 0
        self.memory = replay_memory(dqn_config['replay_capacity'])

        self.optimizer = torch.optim.Adam(self.eval_net.parameters(),
                                          lr=dqn_config['lr'])
        self.loss_func = nn.SmoothL1Loss()

        self.rand_action = 0
        self.greedy_action = 0
Example #4
 def load_base_policy(agent_id):
     if Utils.is_producer_agent(agent_id):
         return ProducerBaselinePolicy(
             env.observation_space, env.action_space_producer,
             BaselinePolicy.get_config_from_env(env))
     else:
         return ConsumerBaselinePolicy(
             env.observation_space, env.action_space_consumer,
             BaselinePolicy.get_config_from_env(env))
Example #5
 def load_policy(agent_id):
     if Utils.is_producer_agent(agent_id):
         return ProducerBaselinePolicy(
             env.observation_space, env.action_space_producer,
             BaselinePolicy.get_config_from_env(env))
     elif Utils.is_consumer_agent(agent_id):
         return ConsumerMinMaxPolicy(
             env.observation_space, env.action_space_consumer,
             BaselinePolicy.get_config_from_env(env))
     else:
         raise Exception(f'Unknown agent type {agent_id}')
Example #6
 def load_policy(agent_id):
     _facility = env.world.facilities[Utils.agentid_to_fid(agent_id)]
     if Utils.is_producer_agent(agent_id):
         return ProducerBaselinePolicy(
             env.observation_space, env.action_space_producer,
             BaselinePolicy.get_config_from_env(env))
     elif isinstance(_facility, SKUStoreUnit) or isinstance(
             _facility, SKUWarehouseUnit):
         policy = ConsumerBaseStockPolicy(
             env.observation_space, env.action_space_consumer,
             BaselinePolicy.get_config_from_env(env))
         policy.base_stock = sku_base_stocks[Utils.agentid_to_fid(agent_id)]
         return policy
     else:
         return ConsumerBaselinePolicy(
             env.observation_space, env.action_space_consumer,
             BaselinePolicy.get_config_from_env(env))
Example #7
def train_dqn(args):
    if not os.path.exists('train_log'):
        os.mkdir('train_log')
    writer = TensorBoard(f'train_log/{args.run_name}')

    dqn_config = dqn_config_default.copy()

    dqn_config.update({
        "batch_size": 1024,
        "min_replay_history": 10000,
        "training_steps": 15000,
        "lr": 0.0003,
        # ....
    })
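    # keys not overridden above keep their values from dqn_config_default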

    all_policies['dqn_store_consumer'] = ConsumerC51TorchPolicy(
        env.observation_space, env.action_space_consumer,
        BaselinePolicy.get_config_from_env(env), dqn_config)

    obss = env.reset()
    agent_ids = obss.keys()
    policies = {}
    policies_to_train = []

    for agent_id in agent_ids:
        policy, if_train = policy_map_fn(agent_id)
        policies[agent_id] = all_policies[policy]
        if if_train:
            policies_to_train.append(agent_id)

    dqn_trainer = Trainer(env, policies, policies_to_train, dqn_config)
    max_mean_reward = -1000
    debug = False
    for i in range(args.num_iterations):
        result = dqn_trainer.train(i)
        now_mean_reward = print_result(result, writer, i)

        if now_mean_reward > max_mean_reward or debug:
            max_mean_reward = max(max_mean_reward, now_mean_reward)
            dqn_trainer.save(args.run_name, i)
            # print("checkpoint saved at", checkpoint)
            visualization(env, policies, i, args.run_name)
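Example #8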
        # == LSTM ==
        "use_lstm": False,
        "max_seq_len": 14,
        "lstm_cell_size": 128, 
        "lstm_use_prev_action_reward": False
    }
}

# Model Configuration ===============================================================================
models.ModelCatalog.register_custom_model("sku_store_net", SKUStoreDNN)
models.ModelCatalog.register_custom_model("sku_warehouse_net", SKUWarehouseDNN)

MyTorchPolicy = PPOTorchPolicy

policies = {
        'baseline_producer': (ProducerBaselinePolicy, env.observation_space, env.action_space_producer, BaselinePolicy.get_config_from_env(env)),
        'baseline_consumer': (ConsumerBaselinePolicy, env.observation_space, env.action_space_consumer, BaselinePolicy.get_config_from_env(env)),
        'ppo_producer': (MyTorchPolicy, env.observation_space, env.action_space_producer, ppo_policy_config_producer),
        'ppo_store_consumer': (MyTorchPolicy, env.observation_space, env.action_space_consumer, ppo_policy_config_store_consumer),
        'ppo_warehouse_consumer': (MyTorchPolicy, env.observation_space, env.action_space_consumer, ppo_policy_config_warehouse_consumer)
    }

def filter_keys(d, keys):
    return {k:v for k,v in d.items() if k in keys}
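# e.g. filter_keys(result, keys) keeps only the listed reporting fields from the result dict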

# Training Routines ===============================================================================

def print_training_results(result):
    keys = ['date', 'episode_len_mean', 'episodes_total', 'episode_reward_max', 'episode_reward_mean', 'episode_reward_min', 
            'timesteps_total', 'policy_reward_max', 'policy_reward_mean', 'policy_reward_min']
    for k in keys:
Example #9
    def __init__(self, observation_space, action_space, config, dqn_config):
        BaselinePolicy.__init__(self, observation_space, action_space, config)

        self.dqn_config = dqn_config
        self.epsilon = 1
        mixed_dqn_config = mixed_dqn_net_config_example.copy()
        mixed_dqn_config.update({
            "controllable_state_num": int(np.product(observation_space.shape)),
            "action_num": int(action_space.n),
            "uncontrollable_state_num": 21 * 7,  # 31
            "uncontrollable_pred_num": 6,
            'fixed_uncontrollable_param': dqn_config['fixed_uncontrollable_param'],
            'uncontrollable_use_cnn': dqn_config['use_cnn_state'],
            'embeddingmerge': dqn_config['embeddingmerge'],
            'activation_func': dqn_config['activation_func'],
            'use_bn': dqn_config['use_bn'],
        })
        self.num_states = int(np.product(observation_space.shape))
        self.num_actions = int(action_space.n)
        print(
            f'dqn state space:{self.num_states}, action space:{self.num_actions}'
        )

        self.use_unc_part = dqn_config['use_unc_part']
        #self.pre_train = dqn_config['pretrain']
        self.fixed_uncontrollable_param = dqn_config['pretrain']

        if dqn_config['use_cnn_state']:
            dqn_net = mixed_dqn_unc_cnn_net
        else:
            dqn_net = mixed_dqn_net
        self.eval_net = dqn_net(mixed_dqn_config)
        self.target_net = dqn_net(mixed_dqn_config)

        self.target_net.load_state_dict(self.eval_net.state_dict())

        self.learn_step_counter = 0
        self.memory = replay_memory(dqn_config['replay_capacity'])
        self.eval_memory = replay_memory(dqn_config['replay_capacity'])

        self.optimizer = torch.optim.Adam(
            self.eval_net.parameters(),
            lr=dqn_config['lr'],
            weight_decay=dqn_config['weight_decay'])
        self.loss_func = nn.SmoothL1Loss()

        self.rand_action = 0
        self.greedy_action = 0
        self.evaluation = False  # only for epsilon-greedy

        # augmentation setting
        self.train_augmentation = dqn_config['train_augmentation']
        self.demand_augmentation = demand_augmentation(
            noise_type=dqn_config['train_augmentation'],
            noise_scale=dqn_config['noise_scale'],
            sparse_scale=dqn_config['sparse_scale'])
Example #10
    # "num_iterations": 200,
    "training_steps": 25000,
    "max_steps_per_episode": 60,
    "replay_capacity": 1000000,
    "batch_size": 2048,
    "double_q": True,
    # "nstep": 1,
}
env_config_for_rendering = env_config.copy()
episod_duration = env_config_for_rendering['episod_duration']
env = InventoryManageEnv(env_config_for_rendering)

all_policies = {
    'baseline_producer':
    ProducerBaselinePolicy(env.observation_space, env.action_space_producer,
                           BaselinePolicy.get_config_from_env(env)),
    'baseline_consumer':
    ConsumerBaselinePolicy(env.observation_space, env.action_space_consumer,
                           BaselinePolicy.get_config_from_env(env)),
    'dqn_store_consumer':
    ConsumerBaselinePolicy(env.observation_space, env.action_space_consumer,
                           BaselinePolicy.get_config_from_env(env)),
}


def policy_map_fn(agent_id):
    if Utils.is_producer_agent(agent_id):
        return 'baseline_producer', False
    else:
        if agent_id.startswith('SKUStoreUnit') or agent_id.startswith(
                'OuterSKUStoreUnit'):
Example #11
 def __init__(self, observation_space, action_space, config):
     BaselinePolicy.__init__(self, observation_space, action_space, config)
 def load_policy(agent_id):
     if Utils.is_producer_agent(agent_id):
         return ProducerBaselinePolicy(env.observation_space, env.action_space_producer, BaselinePolicy.get_config_from_env(env))
     if agent_id.startswith('SKUStoreUnit') or agent_id.startswith('OuterSKUStoreUnit'):
         return ConsumerEOQPolicy(env.observation_space, env.action_space_consumer, BaselinePolicy.get_config_from_env(env))
     else:
         return ConsumerBaselinePolicy(env.observation_space, env.action_space_consumer, BaselinePolicy.get_config_from_env(env))
def train_dqn(args):
    if not os.path.exists('train_log'):
        os.mkdir('train_log')
    writer = TensorBoard(f'train_log/{args.run_name}')

    dqn_config = dqn_config_default.copy()

    dqn_config.update({
        "batch_size": 1024,
        "min_replay_history": 10000,
        "training_steps": 1500,
        "lr": 0.0003,
        "target_update_period": 1000,
        "gamma": args.gamma,
        "use_unc_part": args.use_unc_part,
        "pretrain": args.pretrain,
        "fixed_uncontrollable_param": args.fixed_uncontrollable_param,
        "use_cnn_state": args.use_cnn_state,
        "pretrain_epoch": args.pretrain_epoch,
        "embeddingmerge": args.embeddingmerge,  # 'cat' or 'dot'
        "activation_func": args.activation_func,  # 'relu', 'sigmoid', 'tanh'
        "use_bn": args.use_bn,
        "weight_decay": args.weight_decay,

        # data augmentation setting
        "train_augmentation": args.train_augmentation,
        "noise_scale": args.noise_scale,
        "sparse_scale": args.sparse_scale,
        # ....
    })

    print(dqn_config)

    if dqn_config["use_cnn_state"]:
        global_config.use_cnn_state = True

    if args.training_length != 4 * 365:
        global_config.training_length = args.training_length

    if args.oracle:
        global_config.oracle = True
        if dqn_config['pretrain']:
            raise Exception("dqn oracle does not support pretrain")
        if dqn_config['train_augmentation'] != 'none':
            raise Exception("dqn oracle should not use augmentation")

    # if args.env_demand_noise != 'none':
    #     env.env_config['init'] = 'rst'

    # all_policies['dqn_store_consumer'] = ConsumerDQNTorchPolicy(env.observation_space, env.action_space_consumer, BaselinePolicy.get_config_from_env(env), dqn_config)
    all_policies['dqn_store_consumer'] = ConsumerRepresentationLearningDQNTorchPolicy(
        env.observation_space, env.action_space_consumer,
        BaselinePolicy.get_config_from_env(env), dqn_config)

    obss = env.reset()
    agent_ids = obss.keys()
    policies = {}
    policies_to_train = []

    for agent_id in agent_ids:
        policy, if_train = policy_map_fn(agent_id)
        policies[agent_id] = all_policies[policy]
        if if_train:
            policies_to_train.append(agent_id)

    dqn_trainer = Trainer(env, policies, policies_to_train, dqn_config)
    max_mean_reward = -10000000000
    if dqn_config['pretrain']:
        # global_config.random_noise = 'none'

        print('start load data ...')
        dqn_trainer.load_data(eval=False)
        dqn_trainer.load_data(eval=True)
        # now_mean_reward = print_result(result, writer, -1)
        print('load success!')
        print('start pre-training ...')
        all_policies['dqn_store_consumer'].pre_train(writer)
        print('pre-training success!')

    debug = False

    for i in range(args.num_iterations):

        # if args.env_demand_noise != 'none':
        #     global_config.random_noise = args.env_demand_noise
        result = dqn_trainer.train(i)
        print_result(result, writer, i)
        # global_config.random_noise = 'none'

        eval_on_trainingset_result = dqn_trainer.eval(i,
                                                      eval_on_trainingset=True)
        print_eval_on_trainingset_result(eval_on_trainingset_result, writer, i)

        eval_result = dqn_trainer.eval(i)
        eval_mean_reward = print_eval_result(eval_result, writer, i)

        if eval_mean_reward > max_mean_reward or debug:
            max_mean_reward = max(max_mean_reward, eval_mean_reward)
            dqn_trainer.save(args.run_name, i)
            # print("checkpoint saved at", checkpoint)
            visualization(env, policies, i, args.run_name)