import os

import numpy as np
import torch
import torch.optim as optim
from torch.distributions import Normal

# SoftQNetwork, PolicyNetwork, ReplayBuffer and soft_update are assumed to be
# defined elsewhere in this project.


class SAC_Agent:
    def __init__(self, env, batch_size=256, gamma=0.99, tau=0.005, actor_lr=3e-4, critic_lr=3e-4, alpha_lr=3e-4):
        #Environment
        self.env = env
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]        
        
        #Hyperparameters
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau

        #Entropy
        self.alpha = 1
        self.target_entropy = -np.prod(env.action_space.shape).item()  # heuristic value
        self.log_alpha = torch.zeros(1, requires_grad=True, device="cuda")
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)

        #Networks
        self.Q1 = SoftQNetwork(state_dim, action_dim).cuda()
        self.Q1_target = SoftQNetwork(state_dim, action_dim).cuda()
        self.Q1_target.load_state_dict(self.Q1.state_dict())
        self.Q1_optimizer = optim.Adam(self.Q1.parameters(), lr=critic_lr)

        self.Q2 = SoftQNetwork(state_dim, action_dim).cuda()
        self.Q2_target = SoftQNetwork(state_dim, action_dim).cuda()
        self.Q2_target.load_state_dict(self.Q2.state_dict())
        self.Q2_optimizer = optim.Adam(self.Q2.parameters(), lr=critic_lr)

        self.actor = PolicyNetwork(state_dim, action_dim).cuda()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        
        self.loss_function = torch.nn.MSELoss()
        self.replay_buffer = ReplayBuffer()

    def act(self, state, deterministic=True):
        state = torch.tensor(state, dtype=torch.float, device="cuda")
        mean, log_std = self.actor(state)  
        if deterministic:
            action = torch.tanh(mean)
        else:
            std = log_std.exp()
            normal = Normal(mean, std)
            z = normal.sample()
            action = torch.tanh(z)
        action = action.detach().cpu().numpy()
        return action

    def update(self, state, action, next_state, reward, done):

        self.replay_buffer.add_transition(state, action, next_state, reward, done)

        # Sample next batch and perform batch update: 
        batch_states, batch_actions, batch_next_states, batch_rewards, batch_dones = \
            self.replay_buffer.next_batch(self.batch_size)
        
        #Map to tensor
        batch_states = torch.tensor(batch_states, dtype=torch.float, device="cuda") #B,S_D
        batch_actions = torch.tensor(batch_actions, dtype=torch.float, device="cuda") #B,A_D
        batch_next_states = torch.tensor(batch_next_states, dtype=torch.float, device="cuda", requires_grad=False) #B,S_D
        batch_rewards = torch.tensor(batch_rewards, dtype=torch.float, device="cuda", requires_grad=False).unsqueeze(-1) #B,1
        batch_dones = torch.tensor(batch_dones, dtype=torch.uint8, device="cuda", requires_grad=False).unsqueeze(-1) #B,1

        #Policy evaluation
        with torch.no_grad():
            policy_actions, log_pi = self.actor.sample(batch_next_states)
            Q1_next_target = self.Q1_target(batch_next_states, policy_actions)
            Q2_next_target = self.Q2_target(batch_next_states, policy_actions)
            Q_next_target = torch.min(Q1_next_target, Q2_next_target)
            td_target = batch_rewards + (1 - batch_dones) * self.gamma * (Q_next_target - self.alpha * log_pi)

        Q1_value = self.Q1(batch_states, batch_actions)
        self.Q1_optimizer.zero_grad()
        loss = self.loss_function(Q1_value, td_target)
        loss.backward()
        #torch.nn.utils.clip_grad_norm_(self.Q1.parameters(), 1)
        self.Q1_optimizer.step()

        Q2_value = self.Q2(batch_states, batch_actions)
        self.Q2_optimizer.zero_grad()
        loss = self.loss_function(Q2_value, td_target)
        loss.backward()
        #torch.nn.utils.clip_grad_norm_(self.Q2.parameters(), 1)
        self.Q2_optimizer.step()

        #Policy improvement
        policy_actions, log_pi = self.actor.sample(batch_states)
        Q1_value = self.Q1(batch_states, policy_actions)
        Q2_value = self.Q2(batch_states, policy_actions)
        Q_value = torch.min(Q1_value, Q2_value)
        
        self.actor_optimizer.zero_grad()
        loss = (self.alpha * log_pi - Q_value).mean()
        loss.backward()
        #torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1)
        self.actor_optimizer.step()

        #Update entropy parameter 
        alpha_loss = (self.log_alpha * (-log_pi - self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.exp().detach()  # detach so the actor loss does not backpropagate into log_alpha
        
        #Update target networks
        soft_update(self.Q1_target, self.Q1, self.tau)
        soft_update(self.Q2_target, self.Q2, self.tau)

    def save(self, file_name):
        torch.save({'actor_dict': self.actor.state_dict(),
                    'Q1_dict' : self.Q1.state_dict(),
                    'Q2_dict' : self.Q2.state_dict(),
                }, file_name)

    def load(self, file_name):
        if os.path.isfile(file_name):
            print("=> loading checkpoint... ")
            checkpoint = torch.load(file_name)
            self.actor.load_state_dict(checkpoint['actor_dict'])
            self.Q1.load_state_dict(checkpoint['Q1_dict'])
            self.Q2.load_state_dict(checkpoint['Q2_dict'])
            print("done !")
        else:
            print("no checkpoint found...")
Example #2
def calc_po_best_response_PER(poacher, target_poacher, po_copy_op, po_good_copy_op, patrollers, pa_s, pa_type,
       iteration, sess, env, args, final_utility, starting_e, train_episode_num = None):
    '''
    Given a list of patrollers and their types (DQN, PARAM, RS),
    train a DQN poacher as an approximate best response.
    Args:
        poacher: DQN poacher
        target_poacher: target DQN poacher
        po_copy_op: TensorFlow copy operations that copy the weights from the DQN to the target DQN
        po_good_copy_op: TensorFlow copy operations that save the best poacher DQN trained so far
        patrollers: a list of patrollers
        pa_s: the patroller mixed strategy over the list of patrollers
        pa_type: a list specifying the type of each patroller, {'DQN', 'PARAM', 'RS'}
        iteration: the current double oracle (DO) iteration
        sess: the TensorFlow session
        env: the game environment
        args: the experiment arguments
        final_utility: a shared list that records the best response utility
        starting_e: the starting training episode index
    Return:
        Nothing is explicitly returned, due to multithreading.
        The best response utility is written to $final_utility$,
        and the best response DQN is copied out through $po_good_copy_op$.
    '''

    #print('FIND_poacher_best_response iteration: ' + str(iteration))
    if train_episode_num is None:
        train_episode_num = args.po_episode_num

    decrease_time = 1.0 / args.epsilon_decrease
    epsilon_decrease_every = train_episode_num // decrease_time

    if not args.PER:
        replay_buffer = ReplayBuffer(args, args.po_replay_buffer_size)
    else:
        replay_buffer = PERMemory(args)
    pa_strategy = pa_s
    best_utility = -10000.0
    test_utility = []

    if starting_e == 0:
        log = open(args.save_path + 'po_log_train_iter_' + str(iteration) + '.dat', 'w')
        test_log = open(args.save_path + 'po_log_test_iter_' + str(iteration) +  '.dat', 'w')
    else:
        log = open(args.save_path + 'po_log_train_iter_' + str(iteration) + '.dat', 'a')
        test_log = open(args.save_path + 'po_log_test_iter_' + str(iteration) +  '.dat', 'a')

    epsilon = 1.0
    learning_rate = args.po_initial_lr
    global_step = 0
    action_id = {
        ('still', 1): 0,
        ('up', 0): 1,
        ('down', 0): 2,
        ('left', 0): 3,
        ('right', 0): 4
    }

    sess.run(po_copy_op)

    for e in range(starting_e, starting_e + train_episode_num):
        if e > 0 and e % epsilon_decrease_every == 0:
            epsilon = max(0.1, epsilon - args.epsilon_decrease)
        if e % args.mix_every_episode == 0 or e == starting_e:
            pa_chosen_strat = np.argmax(np.random.multinomial(1, pa_strategy))
            patroller = patrollers[pa_chosen_strat]
            type = pa_type[pa_chosen_strat]
        # if args.gui == 1 and e > 0 and e % args.gui_every_episode == 0:
        #     test_gui(poacher, patroller, sess, args, pah = heurestic_flag, poh = False)

        ### reset the environment
        poacher.reset_snare_num()
        pa_state, po_state = env.reset_game()
        episode_reward = 0.0
        pa_action = 'still'

        for t in range(args.max_time):
            global_step += 1
            transition = []

            ### transition adds current state
            transition.append(po_state)

            ### poacher chooses an action, if it has not been caught/returned home
            if not env.catch_flag and not env.home_flag: 
                po_state = np.array([po_state])
                snare_flag, po_action = poacher.infer_action(sess=sess, states=po_state, policy="epsilon_greedy", epsilon=epsilon, po_loc=env.po_loc, animal_density=env.animal_density)
            else:
                snare_flag = True
                po_action = 'still'
            
            transition.append(action_id[(po_action, snare_flag)])

            ### patroller chooses an action
            ### Note that heuristic and DQN agents have different APIs
            if type == 'DQN':
                pa_state = np.array([pa_state])  # Make it 2-D, i.e., [batch_size(1), state_size]
                pa_action = patroller.infer_action(sess=sess, states=pa_state, policy="greedy", pa_loc=env.pa_loc, animal_density=env.animal_density)
            elif type == 'PARAM':
                pa_loc = env.pa_loc
                pa_action = patroller.infer_action(pa_loc, env.get_local_po_trace(pa_loc), 1.5, -2.0, 8.0)
            elif type == 'RS':
                pa_loc = env.pa_loc
                footprints = []
                actions = ['up', 'down', 'left', 'right']
                for i in range(4,8):
                    if env.po_trace[pa_loc[0], pa_loc[1]][i] == 1:
                        footprints.append(actions[i - 4])
                pa_action = patroller.infer_action(pa_loc, pa_action, footprints)


            pa_state, _, po_state, po_reward, end_game = \
              env.step(pa_action, po_action, snare_flag)

           
            ### transition adds reward, and the new state
            transition.append(po_reward)
            transition.append(po_state)
           
            episode_reward += po_reward
            
            ### Add transition to replay buffer
            replay_buffer.add_transition(transition)

            ### Start training
            ### Sample a minibatch
            if replay_buffer.size >= args.batch_size:

                if not args.PER:
                    train_state, train_action, train_reward, train_new_state = \
                        replay_buffer.sample_batch(args.batch_size)
                else:
                    train_state, train_action, train_reward,train_new_state, \
                      idx_batch, weight_batch = replay_buffer.sample_batch(args.batch_size)

                ### Double DQN get target
                max_index = poacher.get_max_q_index(sess=sess, states=train_new_state)
                max_q = target_poacher.get_q_by_index(sess=sess, states=train_new_state, index=max_index)

                q_target = train_reward + args.reward_gamma * max_q

                if args.PER:
                    q_pred = sess.run(poacher.output, {poacher.input_state: train_state})
                    q_pred = q_pred[np.arange(args.batch_size), train_action]
                    TD_error_batch = np.abs(q_target - q_pred)
                    replay_buffer.update(idx_batch, TD_error_batch)

                if not args.PER:
                    weight = np.ones(args.batch_size) 
                else:
                    weight = weight_batch 

                ### Update parameter
                feed = {
                    poacher.input_state: train_state,
                    poacher.actions: train_action,
                    poacher.q_target: q_target,
                    poacher.learning_rate: learning_rate,
                    poacher.loss_weight: weight
                }
                sess.run(poacher.train_op, feed_dict=feed)

            ### Update target network
            if global_step > 0 and global_step % args.target_update_every == 0:
                sess.run(po_copy_op)

            ### game ends: 1) the patroller catches the poacher and removes all the snares; 
            ###            2) the maximum time step is achieved
            if end_game or (t == args.max_time - 1):
                info = str(e) + "\tepisode\t%s\tlength\t%s\ttotal_reward\t%s\taverage_reward\t%s" % \
                       (e, t + 1, episode_reward, 1. * episode_reward / (t + 1))
                if  e % args.print_every == 0:
                    log.write(info + '\n')
                    print('po ' + info)
                    #log.flush()
                break

        ### save model
        if  e > 0 and e % args.save_every_episode == 0 or e == train_episode_num - 1:
            save_name = args.save_path + 'iteration_' + str(iteration) +  '_epoch_'+ str(e) +  "_po_model.ckpt"
            poacher.save(sess=sess, filename=save_name)
            #print('Save model to ' + save_name)

        ### test 
        if e == train_episode_num - 1 or ( e > 0 and e % args.test_every_episode  == 0):
            po_utility = 0.0
            test_total_reward = np.zeros(len(pa_strategy))

            ### test against each patroller strategy in the current strategy set
            for pa_strat in range(len(pa_strategy)):
                if pa_strategy[pa_strat] > 1e-10:
                    _, test_total_reward[pa_strat], _ = test_(patrollers[pa_strat], poacher, \
                        env, sess,args, iteration, e, poacher_type = 'DQN', patroller_type = pa_type[pa_strat])
                    po_utility += pa_strategy[pa_strat] * test_total_reward[pa_strat]

            test_utility.append(po_utility)

            if po_utility > best_utility and (e > min(50000, train_episode_num / 2) or args.row_num == 3):
                best_utility = po_utility
                sess.run(po_good_copy_op)
                final_utility[1] = po_utility
            
            info = [str(po_utility)] + [str(x)  for x in test_total_reward]
            info = 'test   '  + str(e) + '   ' +  '\t'.join(info) + '\n'
            #print('reward is: ', info)
            print('po ' + info)
            test_log.write(info)
            test_log.flush()

    test_log.close()
    log.close()
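This routine (and its patroller counterpart below) builds a Double DQN target and, when PER is enabled, updates the priorities with the absolute TD error. Since the actual loss lives inside the TensorFlow graph of the poacher/patroller networks, the following is only a small self-contained NumPy sketch of that target and error computation; all names are illustrative.

import numpy as np

def double_dqn_targets(rewards, q_online_next, q_target_next, gamma):
    # Select the next action with the online network, evaluate it with the target network.
    next_actions = np.argmax(q_online_next, axis=1)
    max_q = q_target_next[np.arange(len(rewards)), next_actions]
    return rewards + gamma * max_q

def per_td_errors(q_targets, q_pred_all, actions):
    # Absolute TD errors used to update the priorities in the PER buffer.
    q_pred = q_pred_all[np.arange(len(actions)), actions]
    return np.abs(q_targets - q_pred)
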
Example #3
def calc_pa_best_response_PER(patroller, target_patroller, pa_copy_op, pa_good_copy_op, poachers, po_strategy, po_type,
        iteration, sess, env, args, final_utility, starting_e, train_episode_num = None, po_locations = None):
    
    '''
    po_locations: in the purely global mode, po_locations is None;
        otherwise (the local + global retrain mode), each entry of po_locations specifies the local mode of that poacher.
    Everything else is essentially the same as in 'calc_po_best_response_PER'.
    '''

    po_location = None

    #print('FIND_patroller_best_response iteration: ' + str(iteration))
    if train_episode_num is None:
        train_episode_num = args.pa_episode_num

    decrease_time = 1.0 / args.epsilon_decrease
    epsilon_decrease_every = train_episode_num // decrease_time

    if not args.PER:
        replay_buffer = ReplayBuffer(args, args.pa_replay_buffer_size)
    else:
        replay_buffer = PERMemory(args)
    best_utility = -10000.0
    test_utility = []

    if starting_e == 0:
        log = open(args.save_path + 'pa_log_train_iter_' + str(iteration) + '.dat', 'w')
        test_log = open(args.save_path + 'pa_log_test_iter_' + str(iteration) +  '.dat', 'w')
    else:
        log = open(args.save_path + 'pa_log_train_iter_' + str(iteration) + '.dat', 'a')
        test_log = open(args.save_path + 'pa_log_test_iter_' + str(iteration) +  '.dat', 'a')

    epsilon = 1.0
    learning_rate = args.po_initial_lr
    global_step = 0
    action_id = {
        'still': 0,
        'up': 1,
        'down': 2,
        'left': 3,
        'right': 4
    }

    sess.run(pa_copy_op)

    for e in range(starting_e, starting_e + train_episode_num):
        if e > 0 and e % epsilon_decrease_every == 0:
            epsilon = max(0.1, epsilon - args.epsilon_decrease)
        if e % args.mix_every_episode == 0 or e == starting_e:
            po_chosen_strat = np.argmax(np.random.multinomial(1, po_strategy))
            poacher = poachers[po_chosen_strat]
            type = po_type[po_chosen_strat]
            if po_locations is not None: # local + global mode, need to set the poacher's mode
                po_location = po_locations[po_chosen_strat]


        ### reset the environment
        poacher.reset_snare_num()
        pa_state, po_state = env.reset_game(po_location)
        episode_reward = 0.0
        pa_action = 'still'

        for t in range(args.max_time):
            global_step += 1

            ### transition records the (s,a,r,s) tuples
            transition = []

            ### poacher chooses an action
            ### heuristic and DQN agents have different infer_action APIs
            if type == 'DQN':
                if not env.catch_flag and not env.home_flag: # if the poacher has not been caught, it can still act
                    po_state = np.array([po_state])
                    snare_flag, po_action = poacher.infer_action(sess=sess, states=po_state, policy="greedy", po_loc=env.po_loc, animal_density=env.animal_density)
                else: ### however, if it has been caught, just make it stay still and do nothing
                    snare_flag = 0
                    po_action = 'still'
            elif type == 'PARAM':
                po_loc = env.po_loc
                if not env.catch_flag and not env.home_flag: 
                    snare_flag, po_action = poacher.infer_action(loc=po_loc,
                                                                local_trace=env.get_local_pa_trace(po_loc),
                                                                local_snare=env.get_local_snare(po_loc),
                                                                initial_loc=env.po_initial_loc)
                else:
                    snare_flag = 0
                    po_action = 'still'

            ### transition appends the current state
            transition.append(pa_state)

            ### patroller chooses an action
            pa_state = np.array([pa_state])
            pa_action = patroller.infer_action(sess=sess, states=pa_state, policy="epsilon_greedy", epsilon=epsilon, pa_loc=env.pa_loc, animal_density=env.animal_density)

            ### transition adds action
            transition.append(action_id[pa_action])

            ### the game moves on a step.
            pa_state, pa_reward, po_state, _, end_game = \
              env.step(pa_action, po_action, snare_flag)

            ### transition adds reward and the next state 
            episode_reward += pa_reward
            transition.append(pa_reward)
            transition.append(pa_state)

            ### Add transition to replay buffer
            replay_buffer.add_transition(transition)

            ### Start training
            ### Sample a minibatch once the replay buffer has enough transitions
            if replay_buffer.size >= args.batch_size:
                if not args.PER:
                    train_state, train_action, train_reward, train_new_state = \
                        replay_buffer.sample_batch(args.batch_size)
                else:
                    train_state, train_action, train_reward,train_new_state, \
                      idx_batch, weight_batch = replay_buffer.sample_batch(args.batch_size)

                ### Double DQN get target
                max_index = patroller.get_max_q_index(sess=sess, states=train_new_state)
                max_q = target_patroller.get_q_by_index(sess=sess, states=train_new_state, index=max_index)

                q_target = train_reward + args.reward_gamma * max_q

                if args.PER:
                    q_pred = sess.run(patroller.output, {patroller.input_state: train_state})
                    q_pred = q_pred[np.arange(args.batch_size), train_action]
                    TD_error_batch = np.abs(q_target - q_pred)
                    replay_buffer.update(idx_batch, TD_error_batch)

                if not args.PER:
                    weight = np.ones(args.batch_size)
                else:
                    weight = weight_batch 

                ### Update parameter
                feed = {
                    patroller.input_state: train_state,
                    patroller.actions: train_action,
                    patroller.q_target: q_target,
                    patroller.learning_rate: learning_rate,
                    patroller.weight_loss: weight
                }
                sess.run(patroller.train_op, feed_dict=feed)

            ### Update target network
            if global_step % args.target_update_every == 0:
                sess.run(pa_copy_op)

            ### game ends: 1) the patroller catches the poacher and removes all the snares; 
            ###            2) the maximum time step is achieved
            if end_game or (t == args.max_time - 1):
                info = str(e) + "\tepisode\t%s\tlength\t%s\ttotal_reward\t%s\taverage_reward\t%s" % \
                       (e, t + 1, episode_reward, 1. * episode_reward / (t + 1))
                if  e % args.print_every == 0:
                    log.write(info + '\n')
                    print('pa ' + info)
                    # log.flush()
                break


        ### save the models, and test if they are good 
        if  e > 0 and e % args.save_every_episode == 0 or e == train_episode_num - 1:
            save_name = args.save_path + 'iteration_' + str(iteration) + '_epoch_' + str(e) +  "_pa_model.ckpt"
            patroller.save(sess=sess, filename=save_name)

        ### test the agent
        if e == train_episode_num - 1 or ( e > 0 and e % args.test_every_episode  == 0):
            ### test against each strategy the poacher is using now, compute the expected utility
            pa_utility = 0.0
            test_total_reward = np.zeros(len(po_strategy))
            for po_strat in range(len(po_strategy)):
                if po_strategy[po_strat] > 1e-10:
                    if po_locations is None: ### indicates the purely global mode
                        tmp_po_location = None
                    else: ### indicates the local + global retrain mode, needs to set poacher mode
                        tmp_po_location = po_locations[po_strat]
                    test_total_reward[po_strat], _, _ = test_(patroller, poachers[po_strat], \
                            env, sess,args, iteration, e, patroller_type='DQN', poacher_type=po_type[po_strat],
                                po_location=tmp_po_location)
                    ### update the expected utility
                    pa_utility += po_strategy[po_strat] * test_total_reward[po_strat]

            test_utility.append(pa_utility)

            if pa_utility > best_utility and (e > min(50000, train_episode_num / 2) or args.row_num == 3):
                best_utility = pa_utility
                sess.run(pa_good_copy_op)
                final_utility[0] = pa_utility

            info = [str(pa_utility)] + [str(x)  for x in test_total_reward]
            info = 'test  ' + str(e) + '   ' +  '\t'.join(info) + '\n'
            #print('reward is: ', info)
            print('pa ' + info)
            test_log.write(info)
            test_log.flush()

    test_log.close()
    log.close()
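The epsilon schedule in both best-response routines is a simple step decay: epsilon starts at 1.0 and drops by args.epsilon_decrease every train_episode_num // (1 / args.epsilon_decrease) episodes, floored at 0.1. Below is an equivalent closed-form sketch, assuming starting_e == 0; the function name is illustrative and not part of the original code.

def epsilon_at(episode, train_episode_num, epsilon_decrease):
    # Equivalent closed form of the step decay used in the training loops above.
    decrease_every = int(train_episode_num // (1.0 / epsilon_decrease))
    if decrease_every <= 0:
        return 0.1
    return max(0.1, 1.0 - (episode // decrease_every) * epsilon_decrease)
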
class DQN:
    def __init__(self, state_dim, action_dim, gamma,
                 conf={'lr':0.001, 'bs':64, 'loss':nn.MSELoss, 'hidden_dim':64,
                       'activation':'relu',
                       'mem_size':50000, 'epsilon':1., 'eps_scheduler':'exp',
                       'n_episodes':1000, 'n_cycles':1, 'subtract':0.,
                      }):
        if conf['activation'] == 'relu':
            activation = torch.relu
        elif conf['activation'] == 'tanh':
            activation = torch.tanh
        else:
            raise ValueError("Unsupported activation: " + str(conf['activation']))

        self._q = Q(state_dim, action_dim,
                    non_linearity=activation, hidden_dim=conf['hidden_dim'],
                    dropout_rate=conf.get('dropout_rate', 0.0)).to(device)  # default conf has no 'dropout_rate' key
        self._q_target = Q(state_dim, action_dim,
                    non_linearity=activation, hidden_dim=conf['hidden_dim'],
                    dropout_rate=0.0).to(device)

        self._gamma = gamma
        ############################
        # exploration exploitation tradeoff
        self.epsilon = conf['epsilon']
        self.n_episodes = conf['n_episodes']
        self.n_cycles = conf['n_cycles']
        self.eps_scheduler = conf['eps_scheduler']
        ############################
        # Network
        self.bs = conf['bs']
        self._loss_function = nn.MSELoss()  # conf['loss']
        self._q_optimizer = optim.Adam(self._q.parameters(), lr=conf['lr'])
        self._action_dim = action_dim
        self._replay_buffer = ReplayBuffer(conf['mem_size'])
        self.scheduler = StepLR(self._q_optimizer, step_size=1, gamma=0.99)
        ############################
        # actions
        self.action = acton_discrete(action_dim)

    def get_action(self, x, epsilon):
        u = np.argmax(self._q(tt(x)).cpu().detach().numpy())
        r = np.random.uniform()
        if r < epsilon:
            return np.random.randint(self._action_dim)
        return u

    def train(self, episodes, time_steps, env, conf):
        # Statistics for each episode
        # start_time = time.time()
        stats = EpisodeStats(episode_lengths=np.zeros(episodes),
                             episode_rewards=np.zeros(episodes),
                             episode_loss=np.zeros(episodes),
                             episode_epsilon=np.zeros(episodes))
        ############################
        # opt: continuous plotting every 20 episodes (only works without BOHB)
        # cont_plot = continous_plot()
        ############################
        # Loop over episodes
        eps = self.epsilon
        for e in range(episodes):
            # reduce epsilon by decay rate
            if self.eps_scheduler == 'cos':
                eps = cos_ann_w_restarts(e, self.n_episodes,
                             self.n_cycles, self.epsilon)
            elif self.eps_scheduler == 'exp':
                eps = exponential_decay_w_restarts(e, self.n_episodes,
                                                   self.n_cycles,
                                                   conf['epsilon'], 0.03, conf['decay_rate'])
            stats.episode_epsilon[e] = eps
            ############################
            # opt: continuous plotting every 20 episodes (only works without BOHB)
            #if e % 20 == 0:
            #    cont_plot.plot_stats(stats)
            ############################
            stats.episode_lengths[e] = 0
            s = env.reset()
            for t in range(time_steps):
                ############################
                # opt: render env every 5 episodes
                #if e % 5 == 0:
                #    env.render()
                ############################
                # act and get results
                a = self.get_action(s, eps)
                ns, r, d, _ = env.step(self.action.act(a))
                ns[2] = angle_normalize(ns[2])
                stats.episode_rewards[e] += r
                self._replay_buffer.add_transition(s, a, ns, r, d)
                batch_states, batch_actions, batch_next_states, batch_rewards, batch_terminal_flags = self._replay_buffer.random_next_batch(self.bs)  # NOQA
                # Double DQN target: the target network evaluates the action picked by the acting network
                target = (batch_rewards
                          + (1 - batch_terminal_flags)
                             * self._gamma
                             * self._q_target(batch_next_states)[
                                    torch.arange(self.bs).long(),
                                    torch.argmax(self._q(batch_next_states),
                                    dim=1)])
                # Q-values of the acting network for the actions actually taken
                current_prediction = self._q(batch_states)[
                                        torch.arange(self.bs).long(),
                                        batch_actions.long()]
                ############################
                # Update acting network
                loss = self._loss_function(current_prediction, target.detach())
                stats.episode_loss[e] += loss.cpu().detach()
                self._q_optimizer.zero_grad()
                loss.backward()
                self._q_optimizer.step()
                # Update target network
                soft_update(self._q_target, self._q, 0.01)
                ############################
                # stop the episode if the cart leaves the boundaries
                if d:
                    stats.episode_lengths[e] = t
                    break
                s = ns
            ############################
            # if the episode didn't fail, its length is the maximum number of time steps
            if stats.episode_lengths[e] == 0:
                stats.episode_lengths[e] = time_steps
            self.scheduler.step()
        return stats
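A hedged instantiation sketch for the DQN class above. Note that train() reads conf['epsilon'] and, for the 'exp' scheduler, conf['decay_rate'], which the default conf dict does not include, so callers should supply them. Q, ReplayBuffer, acton_discrete, EpisodeStats, the epsilon schedulers, tt, angle_normalize, soft_update and the environment are assumed to be defined elsewhere in this project; every value below is a placeholder, not a tuned setting.

# Illustrative configuration and usage; not part of the original source.
conf = {'lr': 1e-3, 'bs': 64, 'loss': nn.MSELoss, 'hidden_dim': 64,
        'activation': 'relu', 'dropout_rate': 0.0,
        'mem_size': 50000, 'epsilon': 1.0, 'eps_scheduler': 'exp', 'decay_rate': 0.05,
        'n_episodes': 1000, 'n_cycles': 1, 'subtract': 0.}
# env = ...  # a continuous-state environment with discretized actions
# agent = DQN(state_dim=env.observation_space.shape[0], action_dim=7,
#             gamma=0.99, conf=conf)
# stats = agent.train(episodes=conf['n_episodes'], time_steps=500, env=env, conf=conf)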