コード例 #1
0
ファイル: scheduler.py プロジェクト: lab821/CSSim
 def __init__(self):
     """Create the DQN agent and zero-initialise the previous-step bookkeeping.

     ``last_state`` is an all-zero integer vector of length
     ``6 * NUM_A + 7 * NUM_F``.
     """
     self.agent = DQN()
     # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is exactly
     # the type it aliased, so the resulting dtype is unchanged.
     self.last_state = np.zeros(6 * NUM_A + 7 * NUM_F, dtype=int)
     self.last_action = 0
     self.last_througout = 0  # (sic) previous throughput; name kept for callers
     self.last_reward = 0
     self.key = []
コード例 #2
0
def init_agents(sess, info_state_size, num_actions, hidden_layers_sizes,
                **kwargs):
    """Create a DQN learner (player 0) and a random opponent (player 1).

    Extra keyword arguments are forwarded to the DQN constructor. All
    TensorFlow global variables are initialised in ``sess`` before the
    agents are returned.

    Returns:
        list: ``[dqn_agent, random_agent]``, indexed by player id.
    """
    dqn_agent = DQN(sess, 0, info_state_size, num_actions,
                    hidden_layers_sizes, **kwargs)
    random_agent = agent.RandomAgent(1)
    # Run the initializer after the DQN exists so its variables are covered.
    sess.run(tf.global_variables_initializer())
    return [dqn_agent, random_agent]
コード例 #3
0
ファイル: offline_metalearner.py プロジェクト: vitchyr/BOReL
 def initialize_policy(self):
     """Build ``self.agent`` from ``self.args``.

     If ``self.args.policy == 'dqn'`` the agent is a DQN over one FlattenMlp
     Q-network with one output per discrete action. For any other value it
     is a SAC agent with two FlattenMlp critics on (augmented obs, action)
     and a TanhGaussianPolicy actor. ``use_cql``/``alpha_cql`` are optional
     in ``self.args`` and default to ``False``/``None``. Every module is
     moved to ``ptu.device``.
     """
     args = self.args
     if args.policy == 'dqn':
         q_network = FlattenMlp(
             input_size=args.augmented_obs_dim,
             output_size=args.act_space.n,
             hidden_sizes=args.dqn_layers,
         ).to(ptu.device)
         self.agent = DQN(
             q_network,
             lr=args.policy_lr,
             gamma=args.gamma,
             tau=args.soft_target_tau,
         ).to(ptu.device)
         return

     # SAC branch. NOTE: no Box-action-space check is performed here.
     critic_input = args.augmented_obs_dim + args.action_dim
     # Two independently constructed critics (q1 first, then q2).
     q1_network, q2_network = [
         FlattenMlp(input_size=critic_input,
                    output_size=1,
                    hidden_sizes=args.dqn_layers).to(ptu.device)
         for _ in range(2)
     ]
     policy = TanhGaussianPolicy(
         obs_dim=args.augmented_obs_dim,
         action_dim=args.action_dim,
         hidden_sizes=args.policy_layers,
     ).to(ptu.device)
     use_cql = args.use_cql if 'use_cql' in args else False
     alpha_cql = args.alpha_cql if 'alpha_cql' in args else None
     self.agent = SAC(
         policy,
         q1_network,
         q2_network,
         actor_lr=args.actor_lr,
         critic_lr=args.critic_lr,
         gamma=args.gamma,
         tau=args.soft_target_tau,
         use_cql=use_cql,
         alpha_cql=alpha_cql,
         entropy_alpha=args.entropy_alpha,
         automatic_entropy_tuning=args.automatic_entropy_tuning,
         alpha_lr=args.alpha_lr,
         clip_grad_value=args.clip_grad_value,
     ).to(ptu.device)
コード例 #4
0
ファイル: run.py プロジェクト: ZikangXiong/ToyRLAlgorithms
def dqn_train(env_name, device="cpu", seed=0):
    """Train a DQN on *env_name* and save it to ``{ROOT}/pretrain/{env_name}/dqn.pth``.

    Hyper-parameters come from a deep copy of ``HYPER_PARAM["dqn"][env_name]``.
    ``pixel_input`` and ``buffer_size`` are passed to the DQN constructor,
    and every remaining entry is forwarded to ``DQN.learn``.
    """
    from algorithms.dqn import DQN

    # Deep-copy so that popping keys never mutates the shared config table.
    learn_kwargs = copy.deepcopy(HYPER_PARAM["dqn"][env_name])
    pixel_input = learn_kwargs.pop("pixel_input")
    buffer_size = learn_kwargs.pop("buffer_size")

    model = DQN(gym.make(env_name), pixel_input, buffer_size, device, seed=seed)
    model.learn(**learn_kwargs)
    model.save(f"{ROOT}/pretrain/{env_name}/dqn.pth")
コード例 #5
0
ファイル: train_demo.py プロジェクト: lab821/SDN-DRL
def main():
    """Train a DQN agent on ``AutoEnv`` and periodically evaluate it.

    Runs ``EPISODE`` training episodes of at most ``STEP`` steps each, using
    e-greedy actions. Every 100th episode (starting with episode 0), the
    agent's direct ``action`` policy is run for ``TEST`` episodes and the
    average total reward is printed.
    """
    # initialize the environment and the DQN agent
    # (previously: env = gym.make(ENV_NAME))
    env = AutoEnv()
    agent = DQN(env)

    for episode in range(EPISODE):
        # initialize task
        state = env.reset()
        # Train
        for _ in range(STEP):
            action = agent.egreedy_action(state)  # e-greedy action for train
            # Translate the action into meter settings once. It used to be
            # computed twice (the first time only to print it).
            meter_l = toMeterList(action, env.action_space.content)
            print(action)
            print(meter_l)
            n_action = genAction(state, meter_l)
            next_state, reward, done, _ = env.step(n_action)
            print("State:")
            print_state(state)
            print("Next State: ")
            print_state(next_state)
            print("Env Reward: ", reward)
            # Transform the environment reward with rewardReg before the agent sees it
            reward = rewardReg(reward)
            print("Agent Reward: ", reward)
            agent.perceive(state, action, reward, next_state, done)
            state = next_state
            if done:
                break
        # Test every 100 episodes
        if episode % 100 == 0:
            total_reward = 0
            for _ in range(TEST):
                state = env.reset()
                for _ in range(STEP):
                    action = agent.action(state)  # direct action for test
                    meter_l = toMeterList(action, env.action_space.content)
                    print(action, meter_l)
                    n_action = genAction(state, meter_l)
                    state, reward, done, _ = env.step(n_action)
                    print("Reward: ", reward)
                    total_reward += reward
                    if done:
                        break
            ave_reward = total_reward / TEST
            print('episode: ', episode, 'Evaluation Average Reward:',
                  ave_reward)
コード例 #6
0
ファイル: run.py プロジェクト: ZikangXiong/ToyRLAlgorithms
def dqn_eval(env_name, n_episodes=5):
    """Load the pretrained DQN for *env_name* and render greedy rollouts.

    The model is read from ``{ROOT}/pretrain/{env_name}/dqn.pth``. Each
    episode's total reward is printed.

    Args:
        env_name: Gym environment id.
        n_episodes: number of evaluation episodes (default 5).
    """
    from algorithms.dqn import DQN

    env = gym.make(env_name)
    dqn = DQN.load(f"{ROOT}/pretrain/{env_name}/dqn.pth", env=env)
    # Use the environment held by the loaded model (presumably ``env`` or a
    # wrapper around it).
    env = dqn.env

    try:
        for _ in range(n_episodes):
            obs = env.reset()
            env.render()
            ep_reward = 0
            done = False
            while not done:
                action = dqn.predict(obs)
                obs, reward, done, _ = env.step(action)
                env.render()
                ep_reward += reward
            print(f"reward: {ep_reward}")
    finally:
        # Release the render window / simulator even if a rollout fails.
        env.close()
コード例 #7
0
# Hyper-parameters consumed outside this snippet. EPS_* presumably define an
# epsilon-greedy decay, and TARGET_UPDATE the target-network sync period.
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10

# Get screen size so that we can initialize layers correctly based on shape
# returned from AI gym. Typical dimensions at this point are close to 3x40x90
# which is the result of a clamped and down-scaled render buffer in get_screen()
init_screen = get_screen(env, device)
_, _, screen_height, screen_width = init_screen.shape

# Get number of actions from gym action space
n_actions = env.action_space.n

# Init the policy and the target net
policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
if resume_model:
    # map_location lets a checkpoint saved on another device (e.g. CUDA)
    # be loaded on `device`, such as a CPU-only machine.
    policy_net.load_state_dict(torch.load(model_path, map_location=device))

# The target net starts as a copy of the policy net and is put in eval mode.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Init the optimizer
optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(10000, Transition)

steps_done = 0

episode_durations = []
コード例 #8
0
def main(unused_argv):
    """Train a DQN (player 1) against a random agent on Go, then evaluate it.

    During training, a DQN checkpoint is written to
    ``saved_model/random_vs_dqn`` every ``FLAGS.save_every`` episodes. The
    evaluation phase restores the episode-10000 checkpoint and prints the
    mean final reward of player 0 over ``FLAGS.num_eval`` games, followed by
    the elapsed wall-clock time.

    Args:
        unused_argv: ignored.
    """
    begin = time.time()
    env = Go()
    info_state_size = env.state_size
    num_actions = env.action_size

    hidden_layers_sizes = [int(l) for l in FLAGS.hidden_layers_sizes]
    kwargs = {
        "replay_buffer_capacity": FLAGS.replay_buffer_capacity,
        "epsilon_decay_duration": int(0.6 * FLAGS.num_train_episodes),
        "epsilon_start": 0.8,
        "epsilon_end": 0.001,
        "learning_rate": 1e-3,
        "learn_every": FLAGS.learn_every,
        "batch_size": 128,
        "max_global_gradient_norm": 10,
    }
    import agent.agent as agent
    # Ring buffer holding the last `max_len` final rewards of player 0.
    ret = [0]
    max_len = 2000
    checkpoint_dir = 'saved_model/random_vs_dqn'

    with tf.Session() as sess:
        # For self play:
        # agents = [DQN(sess, _idx, info_state_size, num_actions,
        #               hidden_layers_sizes, **kwargs) for _idx in range(2)]
        # NOTE(review): the random agent sits at index 0 (it acts when
        # current_player == 0) but is constructed with id 1, the same id as
        # the DQN. Confirm this is intended.
        agents = [
            agent.RandomAgent(1),
            DQN(sess, 1, info_state_size, num_actions, hidden_layers_sizes,
                **kwargs)
        ]
        sess.run(tf.global_variables_initializer())

        # train the agent
        for ep in range(FLAGS.num_train_episodes):
            if (ep + 1) % FLAGS.save_every == 0:
                # makedirs also creates a missing parent 'saved_model'
                # directory; the previous os.mkdir raised FileNotFoundError.
                os.makedirs(checkpoint_dir, exist_ok=True)
                agents[1].save(checkpoint_root=checkpoint_dir,
                               checkpoint_name='random_vs_dqn_{}'.format(ep +
                                                                         1))
                print('saved %d' % (ep + 1))
            time_step = env.reset()  # first time step of a new game
            while not time_step.last():
                player_id = time_step.observations["current_player"]
                agent_output = agents[player_id].step(time_step)
                time_step = env.step(agent_output.action)
            # Step every agent with the terminal time step. The loop variable
            # is not named `agent`, so it cannot shadow the imported module.
            for player in agents:
                player.step(time_step)
            if len(ret) < max_len:
                ret.append(time_step.rewards[0])
            else:
                ret[ep % max_len] = time_step.rewards[0]

        # evaluate the trained agent
        # NOTE(review): this checkpoint exists only if training reached
        # episode 10000 and 10000 is a multiple of FLAGS.save_every.
        agents[1].restore("saved_model/random_vs_dqn/random_vs_dqn_10000")
        ret = []
        for _ in range(FLAGS.num_eval):
            time_step = env.reset()
            while not time_step.last():
                player_id = time_step.observations["current_player"]
                if player_id == 0:
                    agent_output = agents[player_id].step(time_step)
                else:
                    agent_output = agents[player_id].step(
                        time_step,
                        is_evaluation=True,
                        add_transition_record=False)
                time_step = env.step(agent_output.action)

            # Episode is over, step all agents with final info state.
            agents[0].step(time_step)
            agents[1].step(time_step,
                           is_evaluation=True,
                           add_transition_record=False)
            ret.append(time_step.rewards[0])
        print(np.mean(ret))

    print('Time elapsed:', time.time() - begin)
コード例 #9
0
ファイル: run_ft.py プロジェクト: AndreaSoprani/T2VT-RL
n_actions = mdps[0].action_space.n

# Hidden-layer sizes; the second layer is included only when l2 is positive.
layers = [l1] + ([l2] if l2 > 0 else [])

if dqn:
    # DQN supplies both the Q-function and its Bellman operator.
    Q, operator = DQN(state_dim, action_dim, n_actions, mdps[0].gamma,
                      layers=layers)
else:
    # Mellow Bellman operator paired with an MLP Q-function.
    operator = MellowBellmanOperator(kappa, tau, xi, mdps[0].gamma,
                                     state_dim, action_dim)
    Q = MLPQFunction(state_dim, n_actions, layers=layers)


def run(mdp, seed=None, idx=0):
    Q._w = ws[idx]
    return learn(mdp,
                 Q,
                 operator,
                 max_iter=max_iter,
                 buffer_size=buffer_size,
                 batch_size=batch_size,
                 alpha=alpha,
                 train_freq=train_freq,
                 eval_freq=eval_freq,
コード例 #10
0
# Append the optional second hidden layer (only when l2 is positive).
if l2 > 0:
    layers.append(l2)

if dqn:
    # DQN supplies both the Q-function and its Bellman operator.
    Q, operator = DQN(state_dim, action_dim, n_actions, temp_mdp.gamma,
                      layers=layers)
else:
    # Mellow Bellman operator paired with an MLP Q-function.
    operator = MellowBellmanOperator(kappa, tau, xi, temp_mdp.gamma,
                                     state_dim, action_dim)
    Q = MLPQFunction(state_dim, n_actions, layers=layers,
                     activation=activation)


def run(data, seed=None):
    return learn(Q,
                 operator,
                 data,
                 demand,
                 min_env_flow,
                 max_iter=max_iter,
                 buffer_size=buffer_size,
                 batch_size=batch_size,
                 alpha=alpha,
                 train_freq=train_freq,
コード例 #11
0
def main(unused_argv):
    """Train a DQN (player 0) against a random-rollout MCTS agent on Go, then evaluate it.

    Every ``FLAGS.eval_every`` episodes, the DQN's ``loss`` and the mean
    recent reward of player 0 are logged and appended to
    ``log/log_<BOARD_SIZE>_<start time>``. Every ``FLAGS.save_every``
    episodes, a checkpoint is written to ``saved_model/<episode>``. The
    evaluation phase restores ``saved_model/10000`` and prints the mean
    final reward of player 0 over ``FLAGS.num_eval`` games.

    Args:
        unused_argv: ignored.
    """
    begin = time.time()
    env = Go()
    info_state_size = env.state_size
    num_actions = env.action_size

    hidden_layers_sizes = [int(l) for l in FLAGS.hidden_layers_sizes]
    kwargs = {
        "replay_buffer_capacity": FLAGS.replay_buffer_capacity,
        "epsilon_decay_duration": int(0.6*FLAGS.num_train_episodes),
        "epsilon_start": 0.8,
        "epsilon_end": 0.001,
        "learning_rate": 1e-3,
        "learn_every": FLAGS.learn_every,
        "batch_size": 128,
        "max_global_gradient_norm": 10,
    }
    import agent.agent as agent
    # Ring buffer holding the last `max_len` final rewards of player 0.
    ret = [0]
    max_len = 2000

    with tf.Session() as sess:
        agents = [
            DQN(sess, 0, info_state_size, num_actions, hidden_layers_sizes,
                **kwargs),
            agent.Random_Rollout_MCTS_Agent(max_simulations=50),
        ]
        sess.run(tf.global_variables_initializer())

        # train the agent
        for ep in range(FLAGS.num_train_episodes):
            if (ep + 1) % FLAGS.eval_every == 0:
                losses = agents[0].loss
                mean_reward = np.mean(ret)
                logging.info("Episodes: {}: Losses: {}, Rewards: {}".format(ep + 1, losses, mean_reward))
                # open() fails if the log directory does not exist yet.
                os.makedirs('log', exist_ok=True)
                log_path = 'log/log_{}_{}'.format(os.environ.get('BOARD_SIZE'), begin)
                with open(log_path, 'a+') as log_file:
                    # write(), not writelines(): this is one string, not a list of lines.
                    log_file.write("{}, {}\n".format(ep+1, mean_reward))
            if (ep + 1) % FLAGS.save_every == 0:
                os.makedirs('saved_model', exist_ok=True)
                agents[0].save(checkpoint_root='saved_model', checkpoint_name='{}'.format(ep+1))
            time_step = env.reset()  # first time step of a new game
            while not time_step.last():
                player_id = time_step.observations["current_player"]
                if player_id == 0:
                    agent_output = agents[player_id].step(time_step)
                else:
                    # The MCTS agent also needs the environment itself.
                    agent_output = agents[player_id].step(time_step, env)
                time_step = env.step(agent_output.action)
            # Episode is over: step both agents with the terminal time step.
            agents[0].step(time_step)
            agents[1].step(time_step, env)
            if len(ret) < max_len:
                ret.append(time_step.rewards[0])
            else:
                ret[ep % max_len] = time_step.rewards[0]

        # evaluate the trained agent
        # NOTE(review): this checkpoint exists only if training reached
        # episode 10000 and 10000 is a multiple of FLAGS.save_every.
        agents[0].restore("saved_model/10000")
        ret = []
        for _ in range(FLAGS.num_eval):
            time_step = env.reset()
            while not time_step.last():
                player_id = time_step.observations["current_player"]
                if player_id == 0:
                    agent_output = agents[player_id].step(time_step, is_evaluation=True, add_transition_record=False)
                else:
                    agent_output = agents[player_id].step(time_step, env)
                time_step = env.step(agent_output.action)

            # Episode is over, step all agents with final info state.
            agents[0].step(time_step, is_evaluation=True, add_transition_record=False)
            agents[1].step(time_step, env)
            ret.append(time_step.rewards[0])
        print(np.mean(ret))

    print('Time elapsed:', time.time()-begin)
コード例 #12
0
ファイル: test_dqn.py プロジェクト: foliag/RL_example
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 17 18:55:03 2019

@author: clytie
"""

if __name__ == "__main__":
    import numpy as np
    import time
    from tqdm import tqdm
    from env.dist_env import BreakoutEnv
    from algorithms.dqn import DQN

    # epsilon_schedule always returns 0 (presumably: no random exploration).
    DQNetwork = DQN(4, (84, 84, 4),
                    epsilon_schedule=lambda x: 0,
                    save_path="./dqn_log")
    env = BreakoutEnv(4999, num_envs=1, mode="test")
    try:
        env_ids, states, _, _ = env.start()
        for _ in tqdm(range(10000)):
            time.sleep(0.1)  # throttle to roughly 10 steps per second
            actions = DQNetwork.get_action(np.asarray(states))
            env_ids, states, _, _ = env.step(env_ids, actions)
    finally:
        # Always release the environment, even on an error or Ctrl-C.
        env.close()
コード例 #13
0
ファイル: learner.py プロジェクト: matants/OMRL_MER
    def initialize_policy(self):
        """Build ``self.agent`` according to ``self.args.policy``.

        * ``'dqn'``: a DQN over a FlattenMlp Q-network with one output per
          discrete action. Requires a ``Discrete`` action space.
        * ``'sac'``: SAC with two FlattenMlp critics on (obs, action) and a
          TanhGaussianPolicy actor. Requires a ``Box`` action space.

        The agent is moved to ``ptu.device``.

        Raises:
            AssertionError: if the action-space type does not match the
                policy. These are ``assert`` statements, so the check is
                skipped under ``python -O``.
            NotImplementedError: for any other ``self.args.policy`` value.
        """
        args = self.args
        if args.policy == 'dqn':
            assert args.act_space.__class__.__name__ == "Discrete", (
                "Can't train DQN with continuous action space!")
            q_network = FlattenMlp(input_size=args.obs_dim,
                                   output_size=args.act_space.n,
                                   hidden_sizes=args.dqn_layers)
            self.agent = DQN(
                q_network,
                lr=args.policy_lr,
                gamma=args.gamma,
                eps_init=args.dqn_epsilon_init,
                eps_final=args.dqn_epsilon_final,
                exploration_iters=args.dqn_exploration_iters,
                tau=args.soft_target_tau,
            ).to(ptu.device)
        # NOTE: a 'ddqn' (DoubleDQN) branch used to sit here as commented-out
        # code. It was never active and has been removed.
        elif args.policy == 'sac':
            assert args.act_space.__class__.__name__ == "Box", (
                "Can't train SAC with discrete action space!")
            q1_network = FlattenMlp(input_size=args.obs_dim +
                                    args.action_dim,
                                    output_size=1,
                                    hidden_sizes=args.dqn_layers)
            q2_network = FlattenMlp(input_size=args.obs_dim +
                                    args.action_dim,
                                    output_size=1,
                                    hidden_sizes=args.dqn_layers)
            policy = TanhGaussianPolicy(obs_dim=args.obs_dim,
                                        action_dim=args.action_dim,
                                        hidden_sizes=args.policy_layers)
            self.agent = SAC(
                policy,
                q1_network,
                q2_network,
                actor_lr=args.actor_lr,
                critic_lr=args.critic_lr,
                gamma=args.gamma,
                tau=args.soft_target_tau,
                entropy_alpha=args.entropy_alpha,
                automatic_entropy_tuning=args.automatic_entropy_tuning,
                alpha_lr=args.alpha_lr).to(ptu.device)
        else:
            # Name the offending value so misconfiguration is easy to diagnose.
            raise NotImplementedError(
                f"Unsupported policy {args.policy!r}; expected 'dqn' or 'sac'")
コード例 #14
0
ファイル: scheduler.py プロジェクト: lab821/CSSim
class DQNscheduler():
    """DQN-driven flow scheduler.

    Each call to ``train`` does the following:
    * encodes the active and completed flow tables as a fixed-length integer
      state vector;
    * asks the DQN agent for an e-greedy action;
    * derives a reward from the change in completed-flow throughput;
    * feeds the previous transition back to the agent.
    """

    def __init__(self):
        self.agent = DQN()
        # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is what it aliased.
        self.last_state = np.zeros(6 * NUM_A + 7 * NUM_F, dtype=int)
        self.last_action = 0
        self.last_througout = 0  # (sic) previous throughput; name kept for compatibility
        self.last_reward = 0
        self.key = []  # actq index labels of the active flows encoded in the state
        self.qindex_list = []  # their 'qindex' values, parallel to self.key

    def train(self, actq, cptq):
        '''
        Generate a control strategy and train the model from current flow information.

        Input:
            actq: table of active flows (iterated with .iterrows())
            cptq: table of completed flows (iterated with .iterrows())
        Returns:
            (ret, infostr): the per-flow control mapping from actionparser
            and a human-readable log string.
        '''
        # state
        state = self.stateparser(actq, cptq)

        # get action
        action = self.agent.egreedy_action(state)  # e-greedy action for train

        # Reward from the relative change in completed-flow throughput.
        current_throughout = self.throughout(cptq)
        if self.last_througout == 0:
            reward = 0 if current_throughout == 0 else 1
        else:
            ratio = current_throughout / self.last_througout
            # ratio > 1 -> ratio / 10 (> 0.1); ratio <= 1 -> ratio - 1 (<= 0)
            reward = ratio / 10 if ratio > 1 else ratio - 1

        done = False

        if reward != 0:
            # NOTE(review): the transition (last_state, last_action) -> state
            # is stored with the *previous* step's reward, not the reward just
            # computed from the throughput change. Confirm this one-step lag
            # is intended.
            self.agent.perceive(self.last_state, self.last_action,
                                self.last_reward, state, done)

        # record state, action, reward and throughput for the next call
        self.last_state = state
        self.last_action = action
        self.last_reward = reward
        self.last_througout = current_throughout

        # analyse the meaning of the action
        ret = self.actionparser(action)
        infostr = self.getinfo(state, action, reward)

        return ret, infostr

    def throughout(self, cptq):
        '''
        Return the aggregate bandwidth of the completed flows: sum of size / duration.

        Input:
            cptq: table of completed flows with 'size' and 'duration' columns
        '''
        res = 0.0
        for _, row in cptq.iterrows():
            res += row['size'] / row['duration']
        return res

    def stateparser(self, actq, cptq):
        '''
        Encode active and completed flows as a (6*NUM_A + 7*NUM_F,) int vector.

        The first 6*NUM_A slots hold at most NUM_A active flows, in ascending
        'sentsize' order: src, dst, protocol, sp, dp, priority. The next
        7*NUM_F slots hold at most NUM_F completed flows: src, dst, protocol,
        sp, dp, duration, size. Extra rows are ignored; unused slots stay 0.
        Also records the encoded active flows' index labels in self.key and
        their 'qindex' values in self.qindex_list.

        Input:
            actq: table of active flows
            cptq: table of completed flows
        '''
        active_num = NUM_A
        finished_num = NUM_F
        state = np.zeros(active_num * 6 + finished_num * 7, dtype=int)
        self.key = []
        self.qindex_list = []

        active_fields = ('src', 'dst', 'protocol', 'sp', 'dp', 'priority')
        ranked = actq.sort_values(by='sentsize')
        # Take exactly NUM_A flows. The former `i > active_num` test admitted
        # an (NUM_A+1)-th flow, which overwrote the completed-flow region and
        # made actionparser index past the action bit string.
        for i, (index, row) in enumerate(ranked.head(active_num).iterrows()):
            base = 6 * i
            for offset, field in enumerate(active_fields):
                state[base + offset] = row[field]
            self.key.append(index)
            self.qindex_list.append(row['qindex'])

        finished_fields = ('src', 'dst', 'protocol', 'sp', 'dp', 'duration', 'size')
        # Cap at NUM_F completed flows so extra rows cannot index past the end.
        for j, (_, row) in enumerate(cptq.head(finished_num).iterrows()):
            base = 6 * active_num + 7 * j
            for offset, field in enumerate(finished_fields):
                state[base + offset] = row[field]
        return state

    def actionparser(self, action):
        '''
        Map the action integer to per-flow control bits.

        Bit i of `action` (least significant first) is assigned to the i-th
        encoded active flow, keyed by its actq index label.

        Input:
            action: non-negative integer with (at least) NUM_A bits
        '''
        bstr = ('{:0%sb}' % (NUM_A)).format(action)
        return {key: int(bstr[-1 - i]) for i, key in enumerate(self.key)}

    def getinfo(self, state, action, reward):
        '''
        Build the log text for this training step.

        Input:
            state: state vector from stateparser
            action: action integer
            reward: current reward
        '''
        line = '%50s\n' % (50 * '*')
        rewardstr = 'Evaluation Reward: %f\n' % reward
        policy = 'State and action:\n'
        bstr = ('{:0%sb}' % (NUM_A)).format(action)
        for i in range(min(NUM_A, len(self.key))):
            policy += 'Queue index:%d, Five tuple={%d,%d,%d,%d,%d}, priority: %d, action:%s\n' % (
                self.qindex_list[i], state[6 * i], state[6 * i + 1],
                state[6 * i + 2], state[6 * i + 3], state[6 * i + 4],
                state[6 * i + 5], bstr[-1 - i])
        return line + rewardstr + policy + line
コード例 #15
0
    from algorithms.dqn import ReplayBuffer, DQN

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s|%(levelname)s|%(message)s')

    # Replay memory, seeded with random-action experience before training.
    memory = ReplayBuffer(max_size=500000)
    env = BreakoutEnv(49999, num_envs=20)
    env_ids, states, rewards, dones = env.start()

    print("pre-train: ")
    for _ in tqdm(range(5000)):
        # NOTE(review): `env.num_srd` is assumed to be the number of envs
        # awaiting an action (one random action each). Confirm.
        random_actions = np.random.randint(env.action_space, size=env.num_srd)
        env_ids, states, rewards, dones = env.step(env_ids, random_actions)
    memory.add(env.get_episodes())

    DQNetwork = DQN(env.action_space, env.state_space, save_path="./dqn_log")

    print("start train: ")
    for step in range(10000000):
        # Collect 20 environment steps with the current network...
        for _ in range(20):
            env_ids, states, rewards, dones = env.step(
                env_ids, DQNetwork.get_action(np.asarray(states)))
        if step % 10 == 0:
            logging.info(
                f'>>>>{env.mean_reward}, nth_step{step}, buffer{len(memory)}')
        memory.add(env.get_episodes())
        # ...then make 10 update calls on 32-sample minibatches.
        for _ in range(10):
            DQNetwork.update(memory.sample(32), sw_dir="dqn")
コード例 #16
0
def main(params):
    """Train ``params['algorithm']`` with the configured experience-replay buffer.

    Args:
        params: configuration dict. Keys read here include 'seed',
            'environment', 'buffer', 'buffer_size', 'loss_function',
            'PER_alpha'/'PER_beta' (prioritized buffers), 'algorithm',
            'optimizer', 'lr', 'gamma', 'epsilon_delta_end', 'epsilon_min',
            'train_steps', 'batch_size' and 'test_every'. The dict is also
            passed to ``update``.

    Returns:
        A tuple ``(episodes_length_test, returns, losses)``:
        * the result of every ``test`` call;
        * the summed reward of each training episode;
        * each episode's mean training loss (``nan`` if no update ran).

    Raises:
        ValueError: on an unknown buffer or algorithm type.
    """
    np.random.seed(params['seed'])
    torch.manual_seed(params['seed'])

    # declare environment
    is_goal = True
    if params['environment'] == 'acrobot_simple':
        env = SimpleAcrobotEnv(stochastic=False, max_steps=400, mean_goal=-1.5)
        s, goal = env.reset()
    elif params['environment'] == 'windy_grid_world':
        env = GridworldEnv()
        s, goal = env.perform_reset()
    else:
        env = gym.make(params['environment'])
        s = env.reset()
        # Plain gym envs have no goal; here `goal` only fixes the input size
        # (a zero goal is used during training).
        goal = s
        is_goal = False

    # The network input is the observation concatenated with the goal.
    state_shape = s.shape[0] + goal.shape[0]

    # Select the experience-replay buffer. Prioritized buffers also take the
    # PER exponents and get an unreduced (per-sample) loss.
    buffer_cls = params['buffer']
    if buffer_cls in (ReplayBuffer, HindsightReplayBuffer):
        buffer = buffer_cls(params['buffer_size'])
        loss_function = params['loss_function']()
    elif buffer_cls in (PrioritizedReplayBuffer, PrioritizedHindsightReplayBuffer):
        buffer = buffer_cls(params['buffer_size'], params['PER_alpha'], params['PER_beta'])
        loss_function = params['loss_function'](reduction='none')
    else:
        raise ValueError('Buffer type not found.')

    # select learning algorithm using the parameters
    if params['algorithm'] is DQN:
        algorithm = DQN(state_shape,
                        env.action_space.n,
                        loss_function=loss_function,
                        optimizer=params['optimizer'],
                        lr=params['lr'],
                        gamma=params['gamma'],
                        epsilon_delta=1 / (params['epsilon_delta_end'] * params['train_steps']),
                        epsilon_min=params['epsilon_min'])
    elif params['algorithm'] is algo_DQN:
        algorithm = algo_DQN()
    else:
        raise ValueError('Algorithm type not found.')

    losses = []
    returns = []
    train_steps = 0
    episodes_length = []
    episodes_length_test = []
    # Flag for add_transitions_to_buffer. env never changes, so compute it once.
    special_goal = isinstance(env, (CustomAcrobotEnv, SimpleAcrobotEnv))

    print('Starting to train:', type(buffer))
    episodes_length_test.append(test(algorithm, env))

    while train_steps < params['train_steps']:
        if isinstance(env, GridworldEnv):
            obs_t, goal = env.perform_reset()
        elif is_goal:
            obs_t, goal = env.reset()
        else:
            obs_t = env.reset()
            goal = np.zeros_like(obs_t)

        t = 0
        episode_loss = []
        episode_rewards = []
        episode_transitions = []
        while train_steps < params['train_steps']:
            action = algorithm.predict(np.hstack((obs_t, goal)))
            t += 1
            if isinstance(env, GridworldEnv):
                obs_tp1, reward, done, _ = env.perform_step(action)
                transition = (obs_t, goal, action, reward, obs_tp1, done)
            elif is_goal:
                # Goal envs return an extra value `gr`, stored in the transition.
                obs_tp1, reward, done, _, gr = env.step(action)
                transition = (obs_t, goal, action, reward, obs_tp1, gr, done)
            else:
                obs_tp1, reward, done, _ = env.step(action)
                transition = (obs_t, goal, action, reward, obs_tp1, done)
            episode_transitions.append(transition)
            episode_rewards.append(reward)
            # Updates begin once the buffer holds a full batch. Transitions of
            # the current episode are only added when the episode ends.
            if len(buffer) >= params['batch_size']:
                loss = update(algorithm, buffer, params, train_steps)
                train_steps += 1
                episode_loss.append(loss)
                if train_steps % params['test_every'] == 0:
                    # NOTE(review): test() runs on the same env mid-episode. If
                    # it resets or steps env, training continues from an env
                    # state that no longer matches obs_tp1. Confirm.
                    episodes_length_test.append(test(algorithm, env))
            # termination condition
            if done:
                episodes_length.append(t)
                break

            obs_t = obs_tp1

        add_transitions_to_buffer(episode_transitions, buffer, special_goal=special_goal)
        # np.mean([]) emits a RuntimeWarning; record the same nan without it.
        losses.append(np.mean(episode_loss) if episode_loss else np.nan)
        returns.append(np.sum(episode_rewards))

    env.close()
    return episodes_length_test, returns, losses