Example #1
import json
import os

import numpy as np

# `logger` is assumed to be provided by the project's logging utilities
# (not shown in this snippet).


def main(args):
    render = args.render
    if not render:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    from utils.utils import TabularPolicy, LookAheadPolicy, SimpleMaxPolicy
    from utils.value_function import CNNValueFun, FFNNValueFun, TabularValueFun
    from algos.function_approximate_value_iteration import FunctionApproximateValueIteration
    from envs import ASRSEnv, ProbDistEnv

    assert np.array(eval(args.storage_shape)).prod() == len(
        eval(args.dist_param)
    ), 'storage_shape should be consistent with dist_param length'
    env = ProbDistEnv(
        ASRSEnv(eval(args.storage_shape),
                origin_coord=eval(args.exit_coord),
                dist_param=eval(args.dist_param)))

    env_name = env.__name__
    exp_dir = os.getcwd() + '/data/version3/%s/policy_type%s_envsize_%s/' % (
        env_name, args.policy_type, np.array(eval(args.storage_shape)).prod())
    logger.configure(dir=exp_dir,
                     format_strs=['stdout', 'log', 'csv'],
                     level=eval(args.logger_level))
    args_dict = vars(args)
    args_dict['env'] = env_name
    with open(exp_dir + '/params.json', 'w') as f:
        json.dump(args_dict, f, indent=2, sort_keys=True)

    value_fun = FFNNValueFun(env)
    policy = SimpleMaxPolicy(env, value_fun, num_acts=args.num_acts)
    # policy = LookAheadPolicy(env,
    #                          value_fun,
    #                          horizon=args.horizon,
    #                          look_ahead_type=args.policy_type,
    #                          num_acts=args.num_acts)
    algo = FunctionApproximateValueIteration(env,
                                             value_fun,
                                             policy,
                                             learning_rate=args.learning_rate,
                                             batch_size=args.batch_size,
                                             num_acts=args.num_acts,
                                             render=render,
                                             num_rollouts=args.num_rollouts,
                                             max_itr=args.max_iter,
                                             log_itr=5)
    algo.train()
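For reference, a minimal sketch of the argparse wiring this main() expects, inferred from the args.* attributes it reads. The flag names match those attributes, but every default below is an illustrative assumption, not the project's actual CLI:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--render', action='store_true')
    # String arguments are eval'd inside main(), so they are passed as
    # Python literals.
    parser.add_argument('--storage_shape', type=str, default='(2, 5)')
    parser.add_argument('--dist_param', type=str, default='[0.1] * 10')
    parser.add_argument('--exit_coord', type=str, default='(0, 0)')
    parser.add_argument('--policy_type', type=str, default='tabular')
    parser.add_argument('--num_acts', type=int, default=20)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_rollouts', type=int, default=10)
    parser.add_argument('--max_iter', type=int, default=100)
    parser.add_argument('--logger_level', type=str, default='20')  # eval'd; 20 == logging.INFO
    main(parser.parse_args())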
Example #2
import json
import os

import numpy as np

# `logger` is assumed to be provided by the project's logging utilities
# (not shown in this snippet).


def main(args):
    render = args.render
    if not render:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    from utils.utils import TabularPolicy
    from utils.value_function import TabularValueFun
    from algos.tabular_value_iteration import ValueIteration
    from envs import ASRSEnv, TabularEnv, ProbDistEnv, DynamicProbEnv, StaticOrderProcess, SeasonalOrderProcess

    num_products = np.array(eval(args.storage_shape)).prod()
    assert (eval(args.dist_param) is None) or (num_products == len(
        eval(args.dist_param)
    )), 'storage_shape should be consistent with dist_param length'
    op = StaticOrderProcess(num_products=num_products,
                            dist_param=eval(args.dist_param))

    base_env = ASRSEnv(eval(args.storage_shape),
                       order_process=op,
                       origin_coord=eval(args.exit_coord))

    env = TabularEnv(base_env)

    env_name = env.__name__
    exp_dir = os.getcwd() + '/data/version1/%s/policy_type%s_temperature%s_envsize_%s/' % (
        env_name, args.policy_type, args.temperature, num_products)
    logger.configure(dir=exp_dir,
                     format_strs=['stdout', 'log', 'csv'],
                     level=eval(args.logger_level))
    args_dict = vars(args)
    args_dict['env'] = env_name
    with open(exp_dir + '/params.json', 'w') as f:
        json.dump(args_dict, f, indent=2, sort_keys=True)

    policy = TabularPolicy(env)
    value_fun = TabularValueFun(env)
    algo = ValueIteration(env,
                          value_fun,
                          policy,
                          policy_type=args.policy_type,
                          render=render,
                          temperature=args.temperature,
                          num_rollouts=args.num_rollouts)
    algo.train()
    value_fun.save(f'{exp_dir}/value_fun.npy')
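Because the tabular value function is saved as a .npy file, the learned table can presumably be inspected with plain NumPy afterwards. A minimal sketch, assuming value_fun.npy holds the value array itself (one entry per enumerated state):

import numpy as np

def inspect_value_table(exp_dir):
    # Assumption: value_fun.npy stores the tabular value array directly.
    values = np.load(f'{exp_dir}/value_fun.npy')
    print(values.shape)
    print(values.min(), values.max())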
Example #3
        for i in range(self.num_products):
            color = cmap((i + 1) / self.num_products)
            # Solid line: true demand probabilities; dotted: predictions.
            plt.plot(test_p_sequence[:, i], c=color, linestyle='-')
            plt.plot(test_p_sequence_hat[:, i], c=color, linestyle=':')
        plt.xlabel("t")
        plt.ylabel("p")
        if save_to:
            plt.savefig('%s/%s.png' % (save_to, figure_name))
        else:
            plt.show()

    def save(self, filepath):
        self.model.save(filepath)

    def load(self, filepath):
        self.model = load_model(filepath)


if __name__ == "__main__":
    from envs import ASRSEnv
    dynamic_order = False
    base_env1 = ASRSEnv((2,5),dist_param = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,0.95],dynamic_order = dynamic_order, beta = 1)
    rnn1 = RNNDemandPredictor(base_env1,look_back=1000, init_num_period = 10000, epochs = 2)
    rnn1.test_performance_plot(2000, save_to = 'data/',figure_name='rnn_performance_static', figsize=(11.4,4))

    dynamic_order = True
    base_env = ASRSEnv((2,5),dist_param = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,0.95],dynamic_order = dynamic_order, beta = 1)
    rnn = RNNDemandPredictor(base_env,look_back=1000, init_num_period = 10000, epochs = 2)
    rnn.test_performance_plot(2000, save_to = 'data/', figure_name='rnn_performance_dynamic', figsize=(11.4,4))
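The plot above draws 2 * num_products curves, with line style as the only cue separating true from predicted probabilities. If the figure needs labeling, a proxy-artist legend could mark the two styles once instead of labeling every curve; a minimal sketch, not part of the original class:

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Two proxy artists stand in for all solid (true) and dotted (predicted) lines.
style_handles = [Line2D([0], [0], color='k', linestyle='-', label='true p'),
                 Line2D([0], [0], color='k', linestyle=':', label='predicted p')]
plt.legend(handles=style_handles)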
Example #4
        Returns:
            attribute of the wrapped_env

        """
        # If the wrapped env is itself a wrapper, delegate through its own
        # __getattr__ so the lookup recurses down the wrapper chain;
        # otherwise read the attribute directly.
        if hasattr(self._wrapped_env, '_wrapped_env'):
            orig_attr = self._wrapped_env.__getattr__(attr)
        else:
            orig_attr = self._wrapped_env.__getattribute__(attr)

        if callable(orig_attr):

            def hooked(*args, **kwargs):
                result = orig_attr(*args, **kwargs)
                return result

            return hooked
        else:
            return orig_attr


if __name__ == '__main__':
    a = ASRSEnv((2, 1, 1))
    b = TabularEnv(a)
    print(b.step())
    for i in range(10):
        print(b.step())
        if i % 2 == 0:
            print(b.step((0, 5)))
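The __getattr__ above gives the wrapper transparent attribute delegation through arbitrarily nested wrappers. A self-contained sketch of the same pattern, with hypothetical Inner/Wrapper classes standing in for ASRSEnv and its wrappers:

class Inner:
    def greet(self):
        return 'hello from Inner'


class Wrapper:
    def __init__(self, wrapped_env):
        self._wrapped_env = wrapped_env

    def __getattr__(self, attr):
        # Invoked only when normal lookup fails, so the wrapper's own
        # attributes are never shadowed; getattr() recurses through
        # nested Wrapper instances automatically.
        return getattr(self._wrapped_env, attr)


w = Wrapper(Wrapper(Inner()))
print(w.greet())  # 'hello from Inner', resolved through both layers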
Example #5
        Returns:
            attribute of the wrapped_env

        """
        # If the wrapped env is itself a wrapper, delegate through its own
        # __getattr__ so the lookup recurses down the wrapper chain;
        # otherwise read the attribute directly.
        if hasattr(self._wrapped_env, '_wrapped_env'):
            orig_attr = self._wrapped_env.__getattr__(attr)
        else:
            orig_attr = self._wrapped_env.__getattribute__(attr)

        if callable(orig_attr):

            def hooked(*args, **kwargs):
                result = orig_attr(*args, **kwargs)
                return result

            return hooked
        else:
            return orig_attr


if __name__ == '__main__':
    a = ASRSEnv((2, 1, 1))
    b = MapAsPicEnv(a)
    print(b.step())
    for i in range(10):
        print(b.step())
        if i % 2 == 0:
            print(b.step((0, 5)))
Example #6
import json
import os

import numpy as np

# `logger` is assumed to be provided by the project's logging utilities
# (not shown in this snippet).


def main(args):
    render = args.render
    if not render:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    from utils.utils import TabularPolicy, LookAheadPolicy, SimpleMaxPolicy
    from utils.value_function import CNNValueFun, FFNNValueFun, TabularValueFun
    from algos import FunctionApproximateValueIteration, RNNDemandPredictor, TruePPredictor
    from envs import ASRSEnv, ProbDistEnv, DynamicProbEnv, StaticOrderProcess, SeasonalOrderProcess

    num_products = np.array(eval(args.storage_shape)).prod()
    assert (eval(args.dist_param) is None) or (num_products == len(
        eval(args.dist_param)
    )), 'storage_shape should be consistent with dist_param length'
    if args.dynamic_order:
        op = SeasonalOrderProcess(num_products=num_products,
                                  dist_param=eval(args.dist_param),
                                  season_length=500,
                                  beta=1,
                                  rho=0.99)
    else:
        op = StaticOrderProcess(num_products=num_products,
                                dist_param=eval(args.dist_param))

    base_env = ASRSEnv(eval(args.storage_shape),
                       order_process=op,
                       origin_coord=eval(args.exit_coord))
    if args.true_p:
        true_p = TruePPredictor(base_env,
                                look_back=10,
                                dynamic=args.dynamic_order,
                                init_num_period=args.rnn_init_num_period,
                                num_p_in_states=args.num_p_in_states)
        env = DynamicProbEnv(base_env,
                             demand_predictor=true_p,
                             alpha=1,
                             num_p_in_states=args.num_p_in_states)
    else:
        rnn = RNNDemandPredictor(base_env,
                                 look_back=args.rnn_lookback,
                                 init_num_period=args.rnn_init_num_period,
                                 epochs=args.rnn_epoch)
        env = DynamicProbEnv(base_env,
                             demand_predictor=rnn,
                             alpha=1,
                             num_p_in_states=args.num_p_in_states)

    env_name = env.__name__
    exp_dir = os.getcwd() + '/data/report/%s/policy_type%s_envsize_%s_dynamic_%s_p_hat_%s_%s/' % (
        env_name, args.policy_type, num_products, args.dynamic_order,
        not args.true_p, args.exp_name)
    logger.configure(dir=exp_dir,
                     format_strs=['stdout', 'log', 'csv'],
                     level=eval(args.logger_level))
    args_dict = vars(args)
    args_dict['env'] = env_name
    with open(exp_dir + '/params.json', 'w') as f:
        json.dump(args_dict, f, indent=2, sort_keys=True)
    if not args.true_p:
        rnn.test_performance_plot(2000, save_to=exp_dir)
        rnn.save(f'{exp_dir}/rnn_model.h5')

    value_fun = FFNNValueFun(env)
    policy = SimpleMaxPolicy(env,
                             value_fun,
                             num_acts=args.num_acts,
                             all_actions=args.all_actions)
    # policy = LookAheadPolicy(env,
    #                          value_fun,
    #                          horizon=args.horizon,
    #                          look_ahead_type=args.policy_type,
    #                          num_acts=args.num_acts)
    if args.dynamic_order:
        last_max_path_length = args.last_max_path_length
    else:
        last_max_path_length = args.max_path_length

    algo = FunctionApproximateValueIteration(env,
                                             value_fun,
                                             policy,
                                             learning_rate=args.learning_rate,
                                             batch_size=args.batch_size,
                                             num_acts=args.num_acts,
                                             all_actions=args.all_actions,
                                             render=render,
                                             num_rollouts=args.num_rollouts,
                                             max_itr=args.max_iter,
                                             log_itr=5,
                                             max_path_length=args.max_path_length,
                                             last_max_path_length=last_max_path_length)
    algo.train()
    value_fun.save(f'{exp_dir}/value_fun.h5')
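Both artifacts written here are .h5 files. Assuming FFNNValueFun and RNNDemandPredictor wrap Keras models (as the load_model call in Example #3 suggests), they could presumably be reloaded for later analysis:

from keras.models import load_model  # or tensorflow.keras.models, depending on the setup

def reload_models(exp_dir):
    # Assumption: the .h5 files are standard Keras model checkpoints.
    value_model = load_model(f'{exp_dir}/value_fun.h5')
    rnn_model = load_model(f'{exp_dir}/rnn_model.h5')
    value_model.summary()
    return value_model, rnn_model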
Example #7
        Returns:
            attribute of the wrapped_env

        """
        # If the wrapped env is itself a wrapper, delegate through its own
        # __getattr__ so the lookup recurses down the wrapper chain;
        # otherwise read the attribute directly.
        if hasattr(self._wrapped_env, '_wrapped_env'):
            orig_attr = self._wrapped_env.__getattr__(attr)
        else:
            orig_attr = self._wrapped_env.__getattribute__(attr)

        if callable(orig_attr):
            def hooked(*args, **kwargs):
                result = orig_attr(*args, **kwargs)
                return result

            return hooked
        else:
            return orig_attr


if __name__ == '__main__':
    # DynamicProbEnv must be imported too; the original snippet omitted it.
    from envs import ASRSEnv, DynamicProbEnv, StaticOrderProcess
    from algos import TruePPredictor

    op = StaticOrderProcess(num_products=9)
    base_env = ASRSEnv((3, 3), op)
    true_p = TruePPredictor(base_env,
                            look_back=1000,
                            dynamic=False,
                            init_num_period=1000,
                            num_p_in_states=1)
    env = DynamicProbEnv(base_env,
                         demand_predictor=true_p,
                         alpha=1,
                         num_p_in_states=1)
    env.sample_actions(5)
    env.sample_states(5)
    print(env.step(None))