Example #1
def main(_):
    FLAGS.agent = model(params=FLAGS)
    FLAGS.environment = get_env(FLAGS)
    FLAGS.act = action()

    FLAGS.step_max = FLAGS.environment.data_len()
    FLAGS.train_freq = 40
    FLAGS.update_q_freq = 50
    FLAGS.gamma = 0.97
    FLAGS.show_log_freq = 5
    FLAGS.memory = []  #Experience(FLAGS.memory_size)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    # Create the directory used to save the model
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)
    start = time.time()

    with tf.Session() as sess:
        sess.run(init)
        eval = evaluation(FLAGS, sess)
        ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
        if ckpt:
            print('Loading Model...')
            saver.restore(sess, ckpt.model_checkpoint_path)
        total_step = 1
        print('\t'.join(
            map(str, [
                "epoch", "epsilon", "total_step", "rewardPerEpoch", "profits",
                "lossPerBatch", "elapsed_time"
            ])))
        for epoch in range(FLAGS.epoch_num):
            avg_loss_per_batch, total_reward, total_step, profits = run_epch(
                FLAGS, sess, total_step)
            # total_rewards.append(total_reward)
            # total_losses.append(total_loss)

            if (epoch + 1) % FLAGS.show_log_freq == 0:
                # log_reward = sum(total_rewards[((epoch+1)-FLAGS.show_log_freq):])/FLAGS.show_log_freq
                # log_loss = sum(total_losses[((epoch+1)-FLAGS.show_log_freq):])/FLAGS.show_log_freq
                elapsed_time = time.time() - start
                #print('\t'.join(map(str, [epoch+1, FLAGS.act.epsilon, total_step, log_reward, log_loss, elapsed_time])))
                print('\t'.join(
                    map(str, [
                        epoch + 1, FLAGS.act.epsilon, total_step, total_reward,
                        profits, avg_loss_per_batch, elapsed_time
                    ])))
                start = time.time()

                saver.save(
                    sess,
                    os.path.join(FLAGS.model_dir,
                                 'model-' + str(epoch + 1) + '.ckpt'))
                eval.eval()
Example #2
def evaluate():
    if ARGS.replay == 'PER':
        filename = ARGS.results_path + "/" + 'weights_' + str(ARGS.replay) + \
                   '_' + ARGS.pmethod + '_' + ARGS.env + '_.pt'
    else:
        filename = ARGS.results_path + "/" + 'weights_' + str(
            ARGS.replay) + '_' + ARGS.env + '_.pt'

    env, (input_size, output_size) = get_env(ARGS.env)
    # set env seed
    env.seed(ARGS.seed_value)

    network = {
        'CartPole-v1':
        CartNetwork(input_size, output_size, ARGS.num_hidden).to(device),
        'MountainCar-v0':
        MountainNetwork(input_size, output_size, ARGS.num_hidden).to(device),
        'LunarLander-v2':
        LanderNetwork(input_size, output_size, ARGS.num_hidden).to(device)
    }

    model = network[ARGS.env]
    model.eval()
    if os.path.isfile(filename):
        print(f"Loading weights from {filename}")
        # weights = torch.load(filename)
        weights = torch.load(filename,
                             map_location=lambda storage, loc: storage)
        model.load_state_dict(weights['policy'])
    else:
        print("Please train the model or provide the saved 'weights.pt' file")
    episode_durations = []
    for i in range(20):
        state = env.reset()
        done = False
        steps = 0
        while not done:
            steps += 1
            with torch.no_grad():
                state = torch.tensor(state, dtype=torch.float).to(device)
                action = get_action(state, model).item()
                state, reward, done, _ = env.step(action)
                env.render()
        episode_durations.append(steps)
        print(i)
    env.close()

    plt.plot(episode_durations)
    plt.title('Episode durations')
    plt.show()
Example #3
def main(args):
    __log.setLevel(level=getattr(logging, args.log_level))
    os.makedirs(args.save_dir, exist_ok=True)

    env = get_env(args.env, app_path=args.app_path)
    actor = get_actor(args.actor,
                      input_shapes=[[None, env.state_size]],
                      output_shapes=[[None, env.action_size]],
                      load_path=args.ckpt_actor)

    if args.train:
        Algorithm = get_algorithm(args.algorithm)
        algorithm = Algorithm(env=env, actor=actor)
        algorithm.run(**vars(args))

    if args.eval:
        evaluate(env, actor, n_eval_episode=args.n_eval_episode)

    if args.play:
        show_agent_play(env, actor)
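
Example #3 reads a fixed set of attributes from the `args` namespace it receives. A hypothetical argparse setup that would produce such a namespace is sketched below; the flag names simply mirror the attributes the function uses, while the types and defaults are invented for illustration:

import argparse

# Hypothetical parser matching the fields Example #3 reads from `args`
# (env, app_path, actor, ckpt_actor, algorithm, train, eval, play,
#  n_eval_episode, save_dir, log_level). Defaults are illustrative only.
parser = argparse.ArgumentParser()
parser.add_argument("--env", required=True)
parser.add_argument("--app_path", default=None)
parser.add_argument("--actor", required=True)
parser.add_argument("--ckpt_actor", default=None)
parser.add_argument("--algorithm", default=None)
parser.add_argument("--save_dir", default="./checkpoints")
parser.add_argument("--log_level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"])
parser.add_argument("--n_eval_episode", type=int, default=10)
parser.add_argument("--train", action="store_true")
parser.add_argument("--eval", action="store_true")
parser.add_argument("--play", action="store_true")

if __name__ == "__main__":
    main(parser.parse_args())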
Example #4
def main():
    # update this dictionary as per the implementation of methods
    memory = {
        'NaiveReplayMemory': NaiveReplayMemory,
        'CombinedReplayMemory': CombinedReplayMemory,
        'PER': PrioritizedReplayMemory
    }

    if ARGS.adaptive_buffer:
        # Introduces the buffer manager for the adaptive buffer size.
        manage_memory = BufferSizeManager(initial_capacity=ARGS.buffer,
                                          size_change=ARGS.buffer_step_size)

    # environment
    env, (input_size, output_size) = get_env(ARGS.env)
    env.seed(ARGS.seed_value)

    network = {
        'CartPole-v1':
        CartNetwork(input_size, output_size, ARGS.num_hidden).to(device),
        'MountainCar-v0':
        MountainNetwork(input_size, output_size, ARGS.num_hidden).to(device),
        'LunarLander-v2':
        LanderNetwork(input_size, output_size, ARGS.num_hidden).to(device)
    }

    # create new file to store durations
    i = 0
    fd_name = ARGS.results_path + "/" + str(ARGS.buffer) + "_" + str(
        ARGS.replay) + "_" + str(
            ARGS.pmethod) + '_' + ARGS.env + "_durations0.txt"
    exists = os.path.isfile(fd_name)
    while exists:
        i += 1
        fd_name = ARGS.results_path + "/" + str(ARGS.buffer) + "_" + str(
            ARGS.replay) + "_" + str(
                ARGS.pmethod) + '_' + ARGS.env + "_durations%d.txt" % i
        exists = os.path.isfile(fd_name)
    fd = open(fd_name, "w+")

    # create new file to store rewards
    i = 0
    fr_name = ARGS.results_path + "/" + str(ARGS.buffer) + "_" + str(
        ARGS.replay) + "_" + str(
            ARGS.pmethod) + '_' + ARGS.env + "_rewards0.txt"
    exists = os.path.isfile(fr_name)
    while exists:
        i += 1
        fr_name = ARGS.results_path + "/" + str(ARGS.buffer) + "_" + str(
            ARGS.replay) + "_" + str(
                ARGS.pmethod) + '_' + ARGS.env + "_rewards%d.txt" % i
        exists = os.path.isfile(fr_name)
    fr = open(fr_name, "w+")

    # Save experiment hyperparams
    i = 0
    exists = os.path.isfile(ARGS.results_path + "/" + str(ARGS.buffer) + "_" +
                            str(ARGS.replay) + "_" + str(ARGS.pmethod) + '_' +
                            ARGS.env + "_info0.txt")
    while exists:
        i += 1
        exists = os.path.isfile(ARGS.results_path + "/" + str(ARGS.buffer) +
                                "_" + str(ARGS.replay) + "_" +
                                str(ARGS.pmethod) + '_' + ARGS.env +
                                "_info%d.txt" % i)
    fi = open(
        ARGS.results_path + "/" + str(ARGS.buffer) + "_" + str(ARGS.replay) +
        "_" + str(ARGS.pmethod) + '_' + ARGS.env + "_info%d.txt" % i, "w+")
    file_counter = i
    fi.write(str(ARGS))
    fi.close()

    # -----------initialization---------------
    if ARGS.replay == 'PER':
        replay = memory[ARGS.replay](ARGS.buffer, ARGS.pmethod)
        filename = ARGS.results_path + "/" + str(
            ARGS.buffer
        ) + "_" + 'weights_' + str(
            ARGS.replay
        ) + '_' + ARGS.pmethod + '_' + ARGS.env + "_%d.pt" % ARGS.seed_value  # file_counter  # +'_.pt'
    else:
        replay = memory[ARGS.replay](ARGS.buffer)
        filename = ARGS.results_path + "/" + str(
            ARGS.buffer
        ) + "_" + 'weights_' + str(
            ARGS.replay
        ) + '_' + ARGS.env + "_%d.pt" % ARGS.seed_value  # file_counter  # +'_.pt'

    model = network[ARGS.env]  # local network
    model_target = network[ARGS.env]  # target_network

    optimizer = optim.Adam(model.parameters(), ARGS.lr)

    # Count the steps (do not reset at episode start, to compute epsilon)
    global_steps = 0
    episode_durations = []
    rewards_per_episode = []
    buffer_sizes = []

    scores_window = deque(maxlen=100)
    eps = ARGS.EPS
    # -------------------------------------------------------

    for i_episode in tqdm(range(ARGS.num_episodes), ncols=50):
        # Sample a transition
        s = env.reset()
        done = False
        epi_duration = 0
        r_sum = 0
        buffer_sizes.append(len(replay))

        # for debugging purposes:
        if (ARGS.debug_mode):
            print(
                f"buffer size: {len(replay)}, r: {episode_durations[-1] if len(episode_durations) >= 1 else 0}"
            )

        render_env_bool = False
        if (ARGS.render_env > 0) and not (i_episode % ARGS.render_env):
            render_env_bool = True
            env.render()

        max_steps = 1000
        for t in range(max_steps):
            # eps = get_epsilon(global_steps)  # Uncomment to use linear epsilon decay instead

            model.eval()
            a = select_action(model, s, eps)

            model.train()
            s_next, r, done, _ = env.step(a)

            beta = None

            # The TD-error is necessary if replay == PER OR if we are using adaptive buffer and the memory is full
            get_td_error = (ARGS.replay == 'PER') or (ARGS.adaptive_buffer
                                                      and replay.memory_full())

            if get_td_error:
                state = torch.tensor(s,
                                     dtype=torch.float).to(device).unsqueeze(0)
                action = torch.tensor(
                    a, dtype=torch.int64).to(device).unsqueeze(
                        0)  # Need 64 bit to use them as index
                next_state = torch.tensor(
                    s_next, dtype=torch.float).to(device).unsqueeze(0)
                reward = torch.tensor(
                    r, dtype=torch.float).to(device).unsqueeze(0)
                done_ = torch.tensor(done,
                                     dtype=torch.uint8).to(device).unsqueeze(0)
                with torch.no_grad():
                    q_val = compute_q_val(model, state, action)
                    target = compute_target(model_target, reward, next_state,
                                            done_, ARGS.discount_factor)
                td_error = F.smooth_l1_loss(q_val, target)

                if ARGS.adaptive_buffer and replay.memory_full():
                    new_buffer_size = manage_memory.update_memory_size(
                        td_error.item())
                    replay.resize_memory(new_buffer_size)

            if ARGS.replay == 'PER':
                replay.push(abs(td_error), (s, a, r, s_next, done))
                beta = get_beta(i_episode, ARGS.num_episodes, ARGS.beta0)
            else:
                replay.push((s, a, r, s_next, done))

            loss = train(model,
                         model_target,
                         replay,
                         optimizer,
                         ARGS.batch_size,
                         ARGS.discount_factor,
                         ARGS.TAU,
                         global_steps,
                         beta=beta)

            s = s_next
            epi_duration += 1
            global_steps += 1

            if done:
                break

            r_sum += r
            # visualize
            if render_env_bool:
                env.render()

        eps = max(0.01, ARGS.eps_decay * eps)
        rewards_per_episode.append(r_sum)
        episode_durations.append(epi_duration)
        scores_window.append(r_sum)

        # store episode data in files
        fr.write("%d\n" % r_sum)
        fr.close()
        fr = open(fr_name, "a")

        fd.write("%d\n" % epi_duration)
        fd.close()
        fd = open(fd_name, "a")

        if i_episode % 100 == 0:
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(
                i_episode, np.mean(scores_window)))
        # if np.mean(scores_window)>=200.0:
        #     print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(i_episode-100, np.mean(scores_window)))
        # break

        # if epi_duration >= 500: # this value is environment dependent
        #     print("Failed to complete in trial {}".format(i_episode))

        # else:
        # print("Completed in {} trials".format(i_episode))
        # break

    # close files
    fd.close()
    fr.close()
    env.close()

    # TODO: save all stats in numpy (pickle)
    b_name = ARGS.results_path + "/" + str(ARGS.buffer) + "_" + str(
        ARGS.replay) + "_" + str(
            ARGS.pmethod) + '_' + ARGS.env + "_buffers_sizes_" + str(
                ARGS.seed_value)
    np.save(b_name, buffer_sizes)

    print(f"max episode duration {max(episode_durations)}")
    print(f"Saving weights to {filename}")
    torch.save(
        {
            # You can add more here if you need, e.g. critic
            'policy':
            model.state_dict()  # Always save weights rather than objects
        },
        filename)

    plt.plot(smooth(episode_durations, 10))
    plt.title('Episode durations per episode')
    # plt.show()
    plt.savefig(ARGS.images_path + "/" + str(ARGS.buffer) + "_" +
                str(ARGS.replay) + '_' + ARGS.pmethod + '_' + ARGS.env +
                '_Episode' + "%d.png" % ARGS.seed_value)  # file_counter)

    plt.plot(smooth(rewards_per_episode, 10))
    plt.title("Rewards per episode")
    # plt.show()
    plt.savefig(ARGS.images_path + "/" + str(ARGS.buffer) + "_" +
                str(ARGS.replay) + '_' + ARGS.pmethod + '_' + ARGS.env +
                '_Rewards' + "%d.png" % ARGS.seed_value)  # file_counter)
    return episode_durations
Example #5
def main():

    # update this dictionary as per the implementation of methods
    memory = {
        'NaiveReplayMemory': NaiveReplayMemory,
        'CombinedReplayMemory': CombinedReplayMemory,
        'PER': PrioritizedReplayMemory
    }

    # environment
    env, (input_size, output_size) = get_env(ARGS.env)
    # env.seed(seed_value)

    network = {
        'CartPole-v1':
        CartNetwork(input_size, output_size, ARGS.num_hidden).to(device),
        'MountainCar-v0':
        MountainNetwork(input_size, output_size, ARGS.num_hidden).to(device),
        'LunarLander-v2':
        LanderNetwork(input_size, output_size, ARGS.num_hidden).to(device)
    }

    # create new file to store durations
    i = 0
    fd_name = "results/" + str(ARGS.buffer) + "_" + str(
        ARGS.replay) + "_" + str(
            ARGS.pmethod) + '_' + ARGS.env + "_durations0.txt"
    exists = os.path.isfile(fd_name)
    while exists:
        i += 1
        fd_name = "results/" + str(ARGS.buffer) + "_" + str(
            ARGS.replay) + "_" + str(
                ARGS.pmethod) + '_' + ARGS.env + "_durations%d.txt" % i
        exists = os.path.isfile(fd_name)
    fd = open(fd_name, "w+")

    # create new file to store rewards
    i = 0
    fr_name = "results/" + str(ARGS.buffer) + "_" + str(
        ARGS.replay) + "_" + str(
            ARGS.pmethod) + '_' + ARGS.env + "_rewards0.txt"
    exists = os.path.isfile(fr_name)
    while exists:
        i += 1
        fr_name = "results/" + str(ARGS.buffer) + "_" + str(
            ARGS.replay) + "_" + str(
                ARGS.pmethod) + '_' + ARGS.env + "_rewards%d.txt" % i
        exists = os.path.isfile(fr_name)
    fr = open(fr_name, "w+")

    # Save experiment hyperparams
    i = 0
    exists = os.path.isfile("results/" + str(ARGS.buffer) + "_" +
                            str(ARGS.replay) + "_" + str(ARGS.pmethod) + '_' +
                            ARGS.env + "_info0.txt")
    while exists:
        i += 1
        exists = os.path.isfile("results/" + str(ARGS.buffer) + "_" +
                                str(ARGS.replay) + "_" + str(ARGS.pmethod) +
                                '_' + ARGS.env + "_info%d.txt" % i)
    fi = open(
        "results/" + str(ARGS.buffer) + "_" + str(ARGS.replay) + "_" +
        str(ARGS.pmethod) + '_' + ARGS.env + "_info%d.txt" % i, "w+")
    file_counter = i
    fi.write(str(ARGS))
    fi.close()

    #-----------initialization---------------
    if ARGS.replay == 'PER':
        replay = memory[ARGS.replay](ARGS.buffer, ARGS.pmethod)
        filename = "results/" + str(ARGS.buffer) + "_" + 'weights_' + str(
            ARGS.replay
        ) + '_' + ARGS.pmethod + '_' + ARGS.env + "_%d.pt" % file_counter  # +'_.pt'
    else:
        replay = memory[ARGS.replay](ARGS.buffer)
        filename = "results/" + str(ARGS.buffer) + "_" + 'weights_' + str(
            ARGS.replay) + '_' + ARGS.env + "_%d.pt" % file_counter  #+'_.pt'

    model = network[ARGS.env]  # local network
    model_target = network[ARGS.env]  # target_network

    optimizer = optim.Adam(model.parameters(), ARGS.lr)

    global_steps = 0  # Count the steps (do not reset at episode start, to compute epsilon)
    episode_durations = []
    rewards_per_episode = []

    scores_window = deque(maxlen=100)
    eps = ARGS.EPS
    #-------------------------------------------------------

    for i_episode in tqdm(range(ARGS.num_episodes)):
        # YOUR CODE HERE
        # Sample a transition
        s = env.reset()
        done = False
        epi_duration = 0
        r_sum = 0
        for t in range(1000):
            # eps = get_epsilon(global_steps)  # Uncomment to use linear epsilon decay instead

            model.eval()
            a = select_action(model, s, eps)

            model.train()
            s_next, r, done, _ = env.step(a)

            beta = None
            if ARGS.replay == 'PER':
                state = torch.tensor(s,
                                     dtype=torch.float).to(device).unsqueeze(0)
                action = torch.tensor(
                    a, dtype=torch.int64).to(device).unsqueeze(
                        0)  # Need 64 bit to use them as index
                next_state = torch.tensor(
                    s_next, dtype=torch.float).to(device).unsqueeze(0)
                reward = torch.tensor(
                    r, dtype=torch.float).to(device).unsqueeze(0)
                done_ = torch.tensor(done,
                                     dtype=torch.uint8).to(device).unsqueeze(0)
                with torch.no_grad():
                    q_val = compute_q_val(model, state, action)
                    target = compute_target(model_target, reward, next_state,
                                            done_, ARGS.discount_factor)
                td_error = F.smooth_l1_loss(q_val, target)
                replay.push(td_error, (s, a, r, s_next, done))
                beta = get_beta(i_episode, ARGS.num_episodes, ARGS.beta0)
            else:
                replay.push((s, a, r, s_next, done))

            loss = train(model,
                         model_target,
                         replay,
                         optimizer,
                         ARGS.batch_size,
                         ARGS.discount_factor,
                         ARGS.TAU,
                         global_steps,
                         beta=beta)

            s = s_next
            epi_duration += 1
            global_steps += 1

            if done:
                break

            r_sum += r
            #visualize
            # env.render()

        eps = max(0.01, ARGS.eps_decay * eps)
        rewards_per_episode.append(r_sum)
        episode_durations.append(epi_duration)
        scores_window.append(r_sum)

        # store episode data in files
        fr.write("%d\n" % r_sum)
        fr.close()
        fr = open(fr_name, "a")

        fd.write("%d\n" % epi_duration)
        fd.close()
        fd = open(fd_name, "a")

        if i_episode % 100 == 0:
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(
                i_episode, np.mean(scores_window)))
        # if np.mean(scores_window)>=200.0:
        #     print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(i_episode-100, np.mean(scores_window)))
        # break

        # if epi_duration >= 500: # this value is environment dependent
        #     print("Failed to complete in trial {}".format(i_episode))

        # else:
        # print("Completed in {} trials".format(i_episode))
        # break

    # close files
    fd.close()
    fr.close()

    env.close()

    print(f"Saving weights to {filename}")
    torch.save(
        {
            # You can add more here if you need, e.g. critic
            'policy':
            model.state_dict()  # Always save weights rather than objects
        },
        filename)

    plt.plot(smooth(episode_durations, 10))
    plt.title('Episode durations per episode')
    #plt.show()
    plt.savefig("images/" + str(ARGS.buffer) + "_" + str(ARGS.replay) + '_' +
                ARGS.pmethod + '_' + ARGS.env + '_Episode' +
                "%d.png" % file_counter)

    plt.plot(smooth(rewards_per_episode, 10))
    plt.title("Rewards per episode")
    #plt.show()
    plt.savefig("images/" + str(ARGS.buffer) + "_" + str(ARGS.replay) + '_' +
                ARGS.pmethod + '_' + ARGS.env + '_Rewards' +
                "%d.png" % file_counter)
    return episode_durations
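
Examples #4 and #5 repeat the same "find the next unused results file" loop for the durations, rewards, and info files. A hypothetical helper that captures that pattern is sketched below (the name `next_free_path` does not appear in either example):

import os


def next_free_path(directory, prefix, suffix):
    """Return the first '<directory>/<prefix><i><suffix>' path that does not exist yet."""
    i = 0
    while os.path.isfile(os.path.join(directory, f"{prefix}{i}{suffix}")):
        i += 1
    return os.path.join(directory, f"{prefix}{i}{suffix}")


# Hypothetical usage mirroring the duration-file naming in Example #5:
# fd_name = next_free_path("results", f"{ARGS.buffer}_{ARGS.replay}_{ARGS.pmethod}_{ARGS.env}_durations", ".txt")
# fd = open(fd_name, "w+")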
Example #6
    def sample(self, n):
        return self.container.get_batch(n)

    def update(self, idx, error):
        self.container.update(idx, error)

    def __len__(self):
        return self.container.get_len()


# sanity check
if __name__ == "__main__":

    capacity = 10
    memory = PrioritizedReplayMemory(capacity)  # or CombinedReplayMemory(capacity) / NaiveReplayMemory(capacity)

    env, _ = get_env("Acrobot-v1")

    # Sample a transition
    s = env.reset()
    a = env.action_space.sample()
    s_next, r, done, _ = env.step(a)

    # Push a transition
    err = 0.5
    memory.push(err, (s, a, r, s_next, done))

    # Sample a batch size of 1
    print(memory.sample(1))
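
The sanity check above only exercises `push` and `sample`; a short hedged continuation that also touches `__len__` is sketched below. It reuses the `memory`, `env`, and `s` objects from the example; how the index argument for `update` is obtained depends on the underlying container, so that call is only shown as a comment:

# Continuation sketch, not part of the original file: push a few more
# transitions with different priorities, then inspect the memory.
for priority in (0.1, 1.0, 2.5):
    a = env.action_space.sample()
    s_next, r, done, _ = env.step(a)
    memory.push(priority, (s, a, r, s_next, done))
    s = s_next

print(len(memory))       # __len__ -> container.get_len()
print(memory.sample(3))  # sample -> container.get_batch(3)
# After recomputing TD-errors for a sampled batch, priorities would be refreshed
# with memory.update(idx, new_error), where idx comes back from the container.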
Example #7
from collections import OrderedDict
import json
import sys
from environment import get_env

Lock = get_env("Lock")


class EOF(Exception):
    pass


class InvalidInput(Exception):
    pass


# TODO error handling


class Connection(object):
    def __init__(self, read, write):
        self._read = read
        self._write = write
        self._write_lock = Lock()

    def serve(self, once=False):
        state = 0
        while not self._read.closed:
            if state == 0:
                headers = {}
                state = 1
Example #8
    def __init__(self, params, sess):
        self.env = get_env(params, tf.estimator.ModeKeys.EVAL)
        self.params = params
        self.sess = sess
Example #9
    def forward(self, state):
        
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# TODO: implement a network based on observations. Also use CNNs




if __name__=="__main__":
    # Let's instantiate and test if it works
    num_hidden = 128
    torch.manual_seed(1234)
    env, _size = get_env("Acrobot-v1")
    input_size, output_size = _size
    # Sample a transition
    s = env.reset()
    a = env.action_space.sample()
    s_next, r, done, _ = env.step(a)

    model = QNetwork(input_size, output_size, num_hidden)

    torch.manual_seed(1234)
    test_model = nn.Sequential(
        nn.Linear(input_size, num_hidden), 
        nn.ReLU(), 
        nn.Linear(num_hidden, output_size)
    )
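
A minimal way to finish the comparison above is to push the sampled state through both networks and check the output shapes; a sketch that assumes only what the snippet already defines (note the forward() shown earlier uses three linear layers, so the two architectures are not expected to produce identical values, only identically shaped outputs):

# Shape check for the two networks built above (sketch, not from the original file).
x = torch.tensor(s, dtype=torch.float).unsqueeze(0)  # shape: (1, input_size)
with torch.no_grad():
    out_model = model(x)
    out_test = test_model(x)
print(out_model.shape, out_test.shape)  # both expected to be (1, output_size)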
Example #10
from collections import deque
import json
import os
import time
from connection import Resolver, EOF
from environment import get_env

Lock, sleep, spawn, subprocess, Popen = get_env(
    "Lock", "sleep", "spawn", "subprocess", "Popen"
)


class CannotRunError(Exception):
    def __init__(self, message, originalException=None):
        self.originalException = originalException
        super(CannotRunError, self).__init__(message)


# The requests implementation is more complex here if you don't want to send
# a different request per plugin based on when you last got results from it.
class PluginBridge(object):
    def __init__(self, name, metadata_path):
        self._name = name
        self._metadata_dir = os.path.dirname(os.path.abspath(metadata_path))
        with open(metadata_path, "rb") as f:
            self._metadata = json.load(f)
        self._proc = None
        self._resolver = False
        self._requests = deque()
        self._requests_lock = Lock()
        self._request_start_time = None
Example #11
lams = np.zeros(shape = (len(demand_points_list), time_horizon))                            # hospitalisation arrival rates for each demand point over the time horizon
for i in range(len(demand_points_list)):
  lams[i] = np.random.normal(hospitalisation_mean[i], hospitalisation_sd[i],  time_horizon)
burn_rate_mean = [6, 6, 6, 6, 6]                                                            # burn rate mean (burn rate is calculated per day per patient)
burn_rate_sd = [2, 2, 2, 2, 2]                                                              # burn rate sd             
death_rate = 0.04                                                                           # Proportion of people hospitalised that die
mean_LOS_death = 7                                                                          # Average LOS (Length of Stay) for patients that die (in days)
mean_LOS_discharge = 10                                                                     # Average LOS for patients that get discharged (in days)
avg_interarrival_time_supply = 7                                                            # On average, weekly interarrival time between supplies
supply_amount = 10000 
vehicles = 100



# Setting up the environment
env = get_env(demand_points_list, vehicles, time_horizon)
env.set(initial_inv = initial_inv , lams = lams, hospitalisation_trend = hospitalisation_trend, 
                    burn_rate_mean = burn_rate_mean, burn_rate_sd = burn_rate_sd, death_rate = death_rate , mean_LOS_death = mean_LOS_death, 
                    mean_LOS_discharge = mean_LOS_discharge, average_interarrival_time = avg_interarrival_time_supply, supply_amount = supply_amount)



# Simulating states/state variables if no action is taken 
env.actionlesss_simulation()

## Plotting the results of the simulation

# Plotting hospitalisations, deaths, discharges and active cases for each demand point by day
plot.plot_cases(time_horizon, env, demand_points_list)

# Plotting PPE inventory for each demand point by day