Example 1
    def run_test_with_noise(self, num_test=10):
        state = self.env.reset()  #_for_test()
        state = Variable(torch.Tensor(state).unsqueeze(0))
        model_old = ActorCriticNet(self.num_inputs, self.num_outputs,
                                   self.hidden_layer)
        model_old.load_state_dict(self.model.state_dict())
        ave_test_reward = 0

        total_rewards = []

        for i in range(num_test):
            total_reward = 0
            while True:
                state = self.shared_obs_stats.normalize(state)
                mu, log_std, v = self.model(state)
                eps = torch.randn(mu.size())
                action = (mu + 0.1 * Variable(eps))
                action = action.data.squeeze().numpy()
                state, reward, done, _ = self.env.step(action)
                total_reward += reward

                if done:
                    state = self.env.reset()  #_for_test()
                    state = Variable(torch.Tensor(state).unsqueeze(0))
                    ave_test_reward += total_reward / num_test
                    total_rewards.append(total_reward)
                    break
                state = Variable(torch.Tensor(state).unsqueeze(0))
        #print("avg test reward is", ave_test_reward)

        reward_mean = statistics.mean(total_rewards)
        reward_std = statistics.stdev(total_rewards)
        self.noisy_test_mean.append(reward_mean)
        self.noisy_test_std.append(reward_std)
        self.noisy_test_list.append((reward_mean, reward_std))
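
The `shared_obs_stats` object that `normalize(...)` is called on here (and that `observes(...)` is called on in `normalize_data` below) is not defined in these excerpts. A minimal sketch, assuming a running mean/variance observation normalizer of the kind commonly paired with PPO (the class name and internals are assumptions), might look like:

import torch

class SharedObsStats:  # hypothetical stand-in for the pickled normalizer
    def __init__(self, num_inputs):
        self.n = 0
        self.mean = torch.zeros(1, num_inputs)
        self.mean_diff = torch.zeros(1, num_inputs)
        self.var = torch.ones(1, num_inputs)

    def observes(self, obs):
        # Welford-style running update of the per-dimension mean and variance.
        self.n += 1
        last_mean = self.mean.clone()
        self.mean += (obs.data - self.mean) / self.n
        self.mean_diff += (obs.data - last_mean) * (obs.data - self.mean)
        self.var = (self.mean_diff / max(self.n, 1)).clamp(min=1e-2)

    def normalize(self, obs):
        return (obs - self.mean) / self.var.sqrt()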
Example 2
def test_model(model_file: str):
    net = ActorCriticNet(4, 2)
    net.load_state_dict(torch.load(model_file))
    net.eval()

    env = gym.make("CartPole-v1")
    env = gym.wrappers.Monitor(env,
                               f"./cart",
                               video_callable=lambda episode_id: True,
                               force=True)

    observation = env.reset()

    R = 0
    while True:
        env.render()
        cleaned_observation = torch.tensor(observation).unsqueeze(dim=0)
        action_logits = net.forward_actor(cleaned_observation)
        action = Categorical(logits=action_logits).sample()
        observation, r, done, _ = env.step(action.item())
        R += r
        if done:
            break

    env.close()

    print(R)
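
For reference, a minimal entry point for the evaluation function above might look like the sketch below; the checkpoint path is a placeholder:

if __name__ == "__main__":
    # Hypothetical invocation; point this at an actual checkpoint file.
    test_model("cartpole_actor_critic.pt")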
    def validation(self):
        batch_states, batch_actions, batch_next_states, batch_rewards, batch_q_values = self.validation_trajectory.sample(300)
        model_old = ActorCriticNet(self.num_inputs, self.num_outputs, self.hidden_layer)
        model_old.load_state_dict(self.model.state_dict())
        batch_states = Variable(torch.Tensor(batch_states))
        batch_q_values = Variable(torch.Tensor(batch_q_values))
        batch_actions = Variable(torch.Tensor(batch_actions))
        mu_old, log_std_old, v_pred_old = model_old(batch_states)
        loss = torch.mean((batch_actions - mu_old)**2)
        if loss.data < self.current_best_validation:
            self.current_best_validation = loss.data
        print("validation error", self.current_best_validation)
    def update_actor(self, batch_size, num_epoch, supervised=False):
        model_old = ActorCriticNet(self.num_inputs, self.num_outputs, self.hidden_layer)
        model_old.load_state_dict(self.model.state_dict())
        model_old.set_noise(self.model.noise)
        self.model.train()
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        for k in range(num_epoch):
            batch_states, batch_actions, batch_next_states, batch_rewards, batch_q_values = self.memory.sample(batch_size)

            batch_states = Variable(torch.Tensor(batch_states))
            batch_q_values = Variable(torch.Tensor(batch_q_values))
            batch_actions = Variable(torch.Tensor(batch_actions))
            mu_old, log_std_old, v_pred_old = model_old(batch_states)
            #mu_old_next, log_std_old_next, v_pred_old_next = model_old(batch_next_states)
            mu, log_std, v_pred = self.model(batch_states)
            batch_advantages = batch_q_values - v_pred_old
            probs_old = normal(batch_actions, mu_old, log_std_old)
            probs = normal(batch_actions, mu, log_std)
            ratio = (probs - (probs_old)).exp()
            ratio = ratio.unsqueeze(1)
            #print(model_old.noise)
            #print(ratio)
            batch_advantages = batch_q_values - v_pred_old
            surr1 = ratio * batch_advantages
            surr2 = ratio.clamp(1-self.params.clip, 1+self.params.clip) * batch_advantages
            loss_clip = -torch.mean(torch.min(surr1, surr2))

            #expert loss
            if supervised is True:
                if k % 1000 == 999:
                    batch_expert_states, batch_expert_actions, _, _, _ = self.expert_trajectory.sample(len(self.expert_trajectory.memory))
                else:
                    batch_expert_states, batch_expert_actions, _, _, _ = self.expert_trajectory.sample(batch_size)
                batch_expert_states = Variable(torch.Tensor(batch_expert_states))
                batch_expert_actions = Variable(torch.Tensor(batch_expert_actions))
                mu_expert, _, _ = self.model(batch_expert_states)
                mu_expert_old, _, _ = model_old(batch_expert_states)
                loss_expert1 = torch.mean((batch_expert_actions-mu_expert)**2)
                clip_expert_action = torch.max(torch.min(mu_expert, mu_expert_old + 0.1), mu_expert_old-0.1)
                loss_expert2 = torch.mean((clip_expert_action-batch_expert_actions)**2)
                loss_expert = loss_expert1#torch.min(loss_expert1, loss_expert2)
            else:
                loss_expert = 0

            total_loss = self.policy_weight * loss_clip + self.weight*loss_expert
            print(k, loss_expert)
            optimizer.zero_grad()
            total_loss.backward(retain_graph=True)
            #print(torch.nn.utils.clip_grad_norm(self.model.parameters(),1))
            optimizer.step()
        if self.lr > 1e-4:
            self.lr *= 0.99
    def update_critic(self, batch_size, num_epoch):
        self.model.train()
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr * 10)
        model_old = ActorCriticNet(self.num_inputs, self.num_outputs, self.hidden_layer)
        model_old.load_state_dict(self.model.state_dict())
        for k in range(num_epoch):
            batch_states, batch_actions, batch_next_states, batch_rewards, batch_q_values = self.memory.sample(batch_size)
            batch_states = Variable(torch.Tensor(batch_states))
            batch_q_values = Variable(torch.Tensor(batch_q_values))
            batch_next_states = Variable(torch.Tensor(batch_next_states))
            _, _, v_pred_next = model_old(batch_next_states)
            _, _, v_pred = self.model(batch_states)
            loss_value = (v_pred - batch_q_values)**2
            #loss_value = (v_pred_next * self.params.gamma + batch_rewards - v_pred)**2
            loss_value = 0.5 * torch.mean(loss_value)
            optimizer.zero_grad()
            loss_value.backward(retain_graph=True)
            optimizer.step()
    def normalize_data(self, num_iter=50000, file='shared_obs_stats.pkl'):
        state = self.env.reset_for_normalization()
        state = Variable(torch.Tensor(state).unsqueeze(0))
        model_old = ActorCriticNet(self.num_inputs, self.num_outputs,self.hidden_layer)
        model_old.load_state_dict(self.model.state_dict())
        for i in range(num_iter):
            self.shared_obs_stats.observes(state)
            state = self.shared_obs_stats.normalize(state)
            mu, log_std, v = model_old(state)
            eps = torch.randn(mu.size())
            action = (mu + log_std.exp()*Variable(eps))
            env_action = action.data.squeeze().numpy()
            state, reward, done, _ = self.env.step(env_action)

            if done:
                state = self.env.reset()

            state = Variable(torch.Tensor(state).unsqueeze(0))

        with open(file, 'wb') as output:
            pickle.dump(self.shared_obs_stats, output, pickle.HIGHEST_PROTOCOL)
    def run_test(self, num_test=1):
        state = self.env.reset_for_test()
        state = Variable(torch.Tensor(state).unsqueeze(0))
        model_old = ActorCriticNet(self.num_inputs, self.num_outputs,self.hidden_layer)
        model_old.load_state_dict(self.model.state_dict())
        ave_test_reward = 0

        total_rewards = []
        '''self.fig2.clear()
        circle1 = plt.Circle((0, 0), 0.5, edgecolor='r', facecolor='none')
        circle2 = plt.Circle((0, 0), 0.01, edgecolor='r', facecolor='none')
        plt.axis('equal')'''
        
        for i in range(num_test):
            total_reward = 0
            while True:
                state = self.shared_obs_stats.normalize(state)
                mu, log_std, v = self.model(state)
                action = mu.data.squeeze().numpy()
                state, reward, done, _ = self.env.step(action)
                total_reward += reward
                #print(state)
                #print("done", done, "state", state)

                if done:
                    state = self.env.reset_for_test()
                    #print(self.env.position)
                    #print(self.env.time)
                    state = Variable(torch.Tensor(state).unsqueeze(0))
                    ave_test_reward += total_reward / num_test
                    total_rewards.append(total_reward)
                    break
                state = Variable(torch.Tensor(state).unsqueeze(0))
        #print("avg test reward is", ave_test_reward)

        reward_mean = statistics.mean(total_rewards)
        reward_std = statistics.stdev(total_rewards)
        self.test_mean.append(reward_mean)
        self.test_std.append(reward_std)
        self.test_list.append((reward_mean, reward_std))
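
The `normal(...)` helper called in `update_actor` above is not shown in these excerpts. Since the PPO ratio is computed as `(probs - probs_old).exp()`, it must return per-sample log-densities; a minimal sketch for a diagonal Gaussian policy, under that assumption, is:

import math

def normal(actions, mu, log_std):  # assumed signature, matching the calls above
    # Log-density of a diagonal Gaussian, summed over action dimensions, so that
    # (log p_new - log p_old).exp() gives the PPO probability ratio.
    std = log_std.exp()
    log_density = -0.5 * ((actions - mu) / std)**2 - log_std - 0.5 * math.log(2 * math.pi)
    return log_density.sum(dim=-1)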
Example 8
env = cassieRLEnvMirror()

u = pd_in_t()
u.leftLeg.motorPd.torque[3] = 0  # Feedforward torque
u.leftLeg.motorPd.pTarget[3] = -2
u.leftLeg.motorPd.pGain[3] = 1000
u.leftLeg.motorPd.dTarget[3] = -2
u.leftLeg.motorPd.dGain[3] = 100
u.rightLeg.motorPd = u.leftLeg.motorPd

num_inputs = env.observation_space.shape[0]
num_outputs = env.action_space.shape[0]

model = ActorCriticNet(num_inputs, num_outputs, [256, 256])
#model.load_state_dict(torch.load("torch_model/StablePelvisForwardBackward256X256Jan25.pt"))
model.load_state_dict(torch.load("torch_model/corl_demo.pt"))
with open('torch_model/cassie3dMirror2kHz_shared_obs_stats.pkl',
          'rb') as input:
    shared_obs_stats = pickle.load(input)

state_list = []
env.visualize = True


def run_test():
    t.sleep(1)
    state = env.reset()
    total_reward = 0
    done = False
    total_10_reward = 0
    current_scale = 0
Example 9
    def update_actor(self, batch_size, num_epoch, supervised=False):
        model_old = ActorCriticNet(self.num_inputs, self.num_outputs,
                                   self.hidden_layer)
        model_old.load_state_dict(self.model.state_dict())
        model_old.set_noise(self.noise.value)
        self.model.train()
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        for k in range(num_epoch):
            batch_states, batch_actions, batch_next_states, batch_rewards, batch_q_values = self.memory.sample(
                batch_size)

            #mirror
            batch_mirror_states = np.copy(batch_states)
            #batch_mirror_actions = np.copy(batch_actions)

            batch_states = Variable(torch.Tensor(batch_states))
            batch_q_values = Variable(torch.Tensor(batch_q_values))
            batch_actions = Variable(torch.Tensor(batch_actions))
            mu_old, log_std_old, v_pred_old = model_old(batch_states)
            #mu_old_next, log_std_old_next, v_pred_old_next = model_old(batch_next_states)
            mu, log_std, v_pred = self.model(batch_states)
            batch_advantages = batch_q_values - v_pred_old
            probs_old = normal(batch_actions, mu_old, log_std_old)
            probs = normal(batch_actions, mu, log_std)
            ratio = (probs - (probs_old)).exp()
            ratio = ratio.unsqueeze(1)
            #print(model_old.noise)
            #print(ratio)
            batch_advantages = batch_q_values - v_pred_old
            surr1 = ratio * batch_advantages
            surr2 = ratio.clamp(1 - self.params.clip,
                                1 + self.params.clip) * batch_advantages
            loss_clip = -torch.mean(torch.min(surr1, surr2))

            #expert loss
            if supervised is True:
                if k % 1000 == 999:
                    batch_expert_states, batch_expert_actions, _, _, _ = self.expert_trajectory.sample(
                        len(self.expert_trajectory.memory))
                else:
                    batch_expert_states, batch_expert_actions, _, _, _ = self.expert_trajectory.sample(
                        batch_size)
                batch_expert_states = Variable(
                    torch.Tensor(batch_expert_states))
                batch_expert_actions = Variable(
                    torch.Tensor(batch_expert_actions))
                mu_expert, _, _ = self.model(batch_expert_states)
                mu_expert_old, _, _ = model_old(batch_expert_states)
                loss_expert1 = torch.mean(
                    (batch_expert_actions - mu_expert)**2)
                clip_expert_action = torch.max(
                    torch.min(mu_expert, mu_expert_old + 0.1),
                    mu_expert_old - 0.1)
                loss_expert2 = torch.mean(
                    (clip_expert_action - batch_expert_actions)**2)
                loss_expert = loss_expert1  #torch.min(loss_expert1, loss_expert2)
            else:
                loss_expert = 0

            #mirror loss
            (
                negation_obs_indices,
                right_obs_indices,
                left_obs_indices,
                negation_action_indices,
                right_action_indices,
                left_action_indices,
            ) = self.env.get_mirror_indices()

            batch_mirror_states[:, negation_obs_indices] *= -1
            rl = np.concatenate((right_obs_indices, left_obs_indices))
            lr = np.concatenate((left_obs_indices, right_obs_indices))
            batch_mirror_states[:, rl] = batch_mirror_states[:, lr]

            #with torch.no_grad():
            batch_mirror_actions, _, _ = self.model(batch_states)
            batch_mirror_actions_clone = batch_mirror_actions.clone()
            batch_mirror_actions_clone[:, negation_action_indices] = \
                batch_mirror_actions[:, negation_action_indices] * -1
            rl = np.concatenate((right_action_indices, left_action_indices))
            lr = np.concatenate((left_action_indices, right_action_indices))
            batch_mirror_actions_clone[:, rl] = batch_mirror_actions[:, lr]
            #batch_mirror_actions_v2[:,]
            #print(vars(batch_mirror_actions))

            batch_mirror_states = Variable(torch.Tensor(batch_mirror_states))
            #batch_mirror_actions = Variable(torch.Tensor(batch_mirror_actions))
            mirror_mu, _, _ = self.model(batch_mirror_states)
            mirror_loss = torch.mean(
                (mirror_mu - batch_mirror_actions_clone)**2)

            total_loss = 1.0 * loss_clip + self.weight * loss_expert + mirror_loss
            #print(k, loss_expert)
            #print(k)
            '''self.validation()
            if k % 1000 == 999:
                #self.run_test(num_test=2)
                #self.run_test_with_noise(num_test=2)
                #self.plot_statistics()
                self.save_model("expert_model/SupervisedModel16X16Jan11.pt")
                if (self.current_best_validation - self.best_validation)  > -1e-5:
                    break
                if self.best_validation > self.current_best_validation:
                    self.best_validation = self.current_best_validation
                self.current_best_validation = 1.0
            print(k, loss_expert)'''
            #print(loss_clip)
            optimizer.zero_grad()
            total_loss.backward(retain_graph=True)
            #print(torch.nn.utils.clip_grad_norm(self.model.parameters(),1))
            optimizer.step()
        if self.lr > 1e-4:
            self.lr *= 0.99
        if self.weight > 10:
            self.weight *= 0.99
        if self.weight < 10:
            self.weight = 10.0
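
To make the mirror-symmetry augmentation above concrete, here is a tiny standalone illustration with made-up indices (the real index arrays come from `env.get_mirror_indices()`):

import numpy as np

obs = np.array([[0.1, -0.2, 0.3, 0.4]])
negation_obs_indices = np.array([0])   # e.g. a lateral velocity that flips sign
right_obs_indices = np.array([1])      # right-leg observation slot
left_obs_indices = np.array([2])       # left-leg observation slot

mirror = np.copy(obs)
mirror[:, negation_obs_indices] *= -1
rl = np.concatenate((right_obs_indices, left_obs_indices))
lr = np.concatenate((left_obs_indices, right_obs_indices))
mirror[:, rl] = mirror[:, lr]          # swap left/right slots
print(mirror)                          # [[-0.1  0.3 -0.2  0.4]]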
Example 10
    def collect_expert_samples(self,
                               num_samples,
                               filename,
                               noise=-2.0,
                               speed=0,
                               y_speed=0,
                               validation=False):
        expert_env = cassieRLEnvMirrorWithTransition()
        start_state = expert_env.reset_by_speed(speed, y_speed)
        samples = 0
        done = False
        states = []
        next_states = []
        actions = []
        rewards = []
        values = []
        q_values = []
        self.model.set_noise(self.noise.value)
        model_expert = ActorCriticNet(85, 10, [256, 256])

        model_expert.load_state_dict(torch.load(filename))
        model_expert.set_noise(self.noise.value)

        with open('torch_model/cassie3dMirror2kHz_shared_obs_stats.pkl',
                  'rb') as input:
            expert_shared_obs_stats = pickle.load(input)

        residual_model = ActorCriticNet(85, 10, [256, 256])
        residual_model.load_state_dict(
            torch.load("torch_model/StablePelvisNov14_v2.pt"))

        state = start_state
        virtual_state = np.concatenate([np.copy(state[0:46]), np.zeros(39)])
        state = Variable(torch.Tensor(state).unsqueeze(0))
        virtual_state = Variable(torch.Tensor(virtual_state).unsqueeze(0))
        total_reward = 0
        total_sample = 0
        #q_value = Variable(torch.zeros(1, 1))
        if validation:
            max_sample = 300
        else:
            max_sample = 3000
        while total_sample < max_sample:
            model_expert.set_noise(self.noise.value)
            score = 0
            while samples < num_samples and not done:
                state = expert_shared_obs_stats.normalize(state)
                virtual_state = expert_shared_obs_stats.normalize(
                    virtual_state)

                states.append(state.data.numpy())
                mu, log_std, v = model_expert(state)
                mu_residual, _, _ = residual_model(state)
                #print(log_std.exp())
                action = (mu + mu_residual * 0)
                pos_index = [7, 8, 9, 14, 20, 21, 22, 23, 28, 34]
                vel_index = [6, 7, 8, 12, 18, 19, 20, 21, 25, 31]
                ref_pos, ref_vel = expert_env.get_kin_next_state()
                saved_action = action.data.numpy() + ref_pos[pos_index]

                actions.append(action.data.numpy())
                #actions.append(saved_action)
                values.append(v.data.numpy())
                eps = torch.randn(mu.size())
                if validation:
                    weight = 0.1
                else:
                    weight = 0.1
                mu = (action + np.exp(-2) * Variable(eps))
                env_action = mu.data.squeeze().numpy()

                state, reward, done, _ = expert_env.step(env_action)
                reward = 1
                rewards.append(Variable(reward * torch.ones(1)).data.numpy())
                #q_value = self.gamma * q_value + Variable(reward * torch.ones(1))
                virtual_state = np.concatenate(
                    [np.copy(state[0:46]), np.zeros(39)])
                virtual_state = Variable(
                    torch.Tensor(virtual_state).unsqueeze(0))
                state = Variable(torch.Tensor(state).unsqueeze(0))

                next_state = expert_shared_obs_stats.normalize(state)
                next_states.append(next_state.data.numpy())

                samples += 1
                #total_sample += 1
                score += reward
            print("expert score", score)

            state = expert_shared_obs_stats.normalize(state)
            #print(state)
            _, _, v = model_expert(state)
            if done:
                R = torch.zeros(1, 1)
            else:
                R = v.data
                R = torch.ones(1, 1) * 100
            R = Variable(R)
            for i in reversed(range(len(rewards))):
                R = self.params.gamma * R + Variable(
                    torch.from_numpy(rewards[i]))
                q_values.insert(0, R.data.numpy())

            if not validation and score >= 299:
                self.expert_trajectory.push(
                    [states, actions, next_states, rewards, q_values])
                total_sample += 300
            elif score >= 299:
                self.validation_trajectory.push(
                    [states, actions, next_states, rewards, q_values])
            expert_env.reset_by_speed(speed, y_speed)
            start_state = expert_env.reset_by_speed(speed, y_speed)
            state = start_state
            state = Variable(torch.Tensor(state).unsqueeze(0))
            total_reward = 0
            samples = 0
            done = False
            states = []
            next_states = []
            actions = []
            rewards = []
            values = []
            q_values = []
Example 11
    def collect_samples(self,
                        num_samples,
                        start_state=None,
                        noise=-2.0,
                        env_index=0,
                        random_seed=1):

        random.seed(random_seed)
        torch.manual_seed(random_seed + 1)
        np.random.seed(random_seed + 2)

        if start_state is None:
            start_state = self.env.reset()
        samples = 0
        done = False
        states = []
        next_states = []
        actions = []
        rewards = []
        values = []
        q_values = []
        real_rewards = []
        self.model.set_noise(self.noise.value)
        #print("soemthing 1")
        model_old = ActorCriticNet(self.num_inputs, self.num_outputs,
                                   self.hidden_layer)
        model_old.load_state_dict(self.model.state_dict())
        #print("something 2")
        model_old.set_noise(self.noise.value)

        state = start_state
        state = Variable(torch.Tensor(state).unsqueeze(0))
        total_reward = 0
        #q_value = Variable(torch.zeros(1, 1))
        while True:
            self.model.set_noise(self.noise.value)
            model_old.set_noise(self.noise.value)
            signal_init = self.traffic_light.get()
            score = 0
            while samples < num_samples and not done:

                state = self.shared_obs_stats.normalize(state)

                states.append(state.data.numpy())
                mu, log_std, v = model_old(state)
                eps = torch.randn(mu.size())
                #print(log_std.exp())
                #print(log_std.exp())
                action = (mu + log_std.exp() * Variable(eps))
                actions.append(action.data.numpy())

                values.append(v.data.numpy())

                env_action = action.data.squeeze().numpy()
                state, reward, done, _ = self.env.step(env_action)
                score += reward
                rewards.append(Variable(reward * torch.ones(1)).data.numpy())

                # rewards.append(Variable(reward * torch.ones(1)).data.numpy())
                real_rewards.append(
                    Variable(reward * torch.ones(1)).data.numpy())

                state = Variable(torch.Tensor(state).unsqueeze(0))

                next_state = self.shared_obs_stats.normalize(state)
                next_states.append(next_state.data.numpy())

                samples += 1

            state = self.shared_obs_stats.normalize(state)

            _, _, v = model_old(state)
            if done:
                R = torch.zeros(1, 1)
            else:
                R = v.data
            R = Variable(R)
            for i in reversed(range(len(real_rewards))):
                R = self.params.gamma * R + Variable(
                    torch.from_numpy(real_rewards[i]))
                q_values.insert(0, R.data.numpy())

            self.queue.put([states, actions, next_states, rewards, q_values])
            self.counter.increment()
            self.env.reset()
            while self.traffic_light.get() == signal_init:
                pass
            start_state = self.env.reset()
            state = start_state
            state = Variable(torch.Tensor(state).unsqueeze(0))
            total_reward = 0
            samples = 0
            done = False
            states = []
            next_states = []
            actions = []
            rewards = []
            values = []
            q_values = []
            real_rewards = []
            model_old = ActorCriticNet(self.num_inputs, self.num_outputs,
                                       self.hidden_layer)
            model_old.load_state_dict(self.model.state_dict())
            model_old.set_noise(self.noise.value)
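
The `traffic_light`, `counter`, and `queue` objects coordinate this sampler process with the learner; they are not defined in these excerpts. A minimal sketch of the two custom helpers, assuming the usual shared-flag pattern (names and internals are assumptions), might be:

import torch.multiprocessing as mp

class TrafficLight:  # hypothetical: a flag the learner flips to release waiting samplers
    def __init__(self):
        self.val = mp.Value("b", False)
        self.lock = mp.Lock()

    def get(self):
        with self.lock:
            return self.val.value

    def switch(self):
        with self.lock:
            self.val.value = not self.val.value


class Counter:  # hypothetical: counts how many samplers have pushed a batch
    def __init__(self):
        self.val = mp.Value("i", 0)
        self.lock = mp.Lock()

    def increment(self):
        with self.lock:
            self.val.value += 1

    def get(self):
        with self.lock:
            return self.val.value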
Example 12
class TrainerProcess:
    def __init__(self, global_net, global_opt):
        self.proc_net = ActorCriticNet(4, 2, training=True)
        self.proc_net.load_state_dict(global_net.state_dict())
        self.proc_net.train()

        self.global_net = global_net
        self.optimizer = global_opt
        self.env = gym.make("CartPole-v1")

        print(f"Starting process...")
        sys.stdout.flush()

    def play_episode(self):
        episode_actions = torch.empty(size=(0, ), dtype=torch.long)
        episode_logits = torch.empty(size=(0, self.env.action_space.n),
                                     dtype=torch.float)  # logits are floats, not longs
        episode_observs = torch.empty(size=(0,
                                            *self.env.observation_space.shape),
                                      dtype=torch.float)
        episode_rewards = np.empty(shape=(0, ), dtype=np.float64)  # np.float is deprecated

        observation = self.env.reset()

        t = 0
        done = False
        while not done:
            # Prepare observation
            cleaned_observation = torch.tensor(observation).unsqueeze(dim=0)
            episode_observs = torch.cat((episode_observs, cleaned_observation),
                                        dim=0)

            # Get action from policy net
            action_logits = self.proc_net.forward_actor(cleaned_observation)
            action = Categorical(logits=action_logits).sample()

            # Save observation and the action from the net
            episode_logits = torch.cat((episode_logits, action_logits), dim=0)
            episode_actions = torch.cat((episode_actions, action), dim=0)

            # Get new observation and reward from action
            observation, r, done, _ = self.env.step(action.item())

            # Save reward from net_action
            episode_rewards = np.concatenate(
                (episode_rewards, np.asarray([r])), axis=0)

            t += 1

        discounted_R = self.get_discounted_rewards(episode_rewards, GAMMA)
        discounted_R -= episode_rewards.mean()

        mask = F.one_hot(episode_actions, num_classes=self.env.action_space.n)
        episode_log_probs = torch.sum(mask.float() *
                                      F.log_softmax(episode_logits, dim=1),
                                      dim=1)

        values = self.proc_net.forward_critic(episode_observs)
        action_advantage = (discounted_R.float() - values).detach()
        episode_weighted_log_probs = episode_log_probs * action_advantage
        sum_weighted_log_probs = torch.sum(
            episode_weighted_log_probs).unsqueeze(dim=0)
        sum_action_advantages = torch.sum(action_advantage).unsqueeze(dim=0)

        return (
            sum_weighted_log_probs,
            sum_action_advantages,
            episode_logits,
            np.sum(episode_rewards),
            t,
        )

    def get_discounted_rewards(self, rewards: np.array,
                               GAMMA: float) -> torch.Tensor:
        """
        Calculates the sequence of discounted rewards-to-go.
        Args:
            rewards: the sequence of observed rewards
            GAMMA: the discount factor
        Returns:
            discounted_rewards: the sequence of the rewards-to-go

        AXEL: Directly from
        https://towardsdatascience.com/breaking-down-richard-suttons-policy-gradient-9768602cb63b
        """
        discounted_rewards = np.empty_like(rewards, dtype=np.float64)
        for i in range(rewards.shape[0]):
            GAMMAs = np.full(shape=(rewards[i:].shape[0]), fill_value=GAMMA)
            discounted_GAMMAs = np.power(GAMMAs,
                                         np.arange(rewards[i:].shape[0]))
            discounted_reward = np.sum(rewards[i:] * discounted_GAMMAs)
            discounted_rewards[i] = discounted_reward
        return torch.from_numpy(discounted_rewards)

    def calculate_policy_loss(self, epoch_logits: torch.Tensor,
                              weighted_log_probs: torch.Tensor):
        policy_loss = -torch.mean(weighted_log_probs)
        p = F.softmax(epoch_logits, dim=1)
        log_p = F.log_softmax(epoch_logits, dim=1)
        entropy = -1 * torch.mean(torch.sum(p * log_p, dim=-1), dim=0)
        entropy_bonus = -1 * BETA * entropy
        return policy_loss + entropy_bonus, entropy

    def share_grads(self):
        # Point the global parameters' gradients at the local gradients
        # (A3C-style gradient sharing); bail out if grads are already set.
        for gp, lp in zip(self.global_net.parameters(),
                          self.proc_net.parameters()):
            if gp.grad is not None:
                return
            gp._grad = lp.grad

    def train(self):
        epoch, episode = 0, 0
        total_rewards = []
        epoch_action_advantage = torch.empty(size=(0, ))
        epoch_logits = torch.empty(size=(0, self.env.action_space.n))
        epoch_weighted_log_probs = torch.empty(size=(0, ), dtype=torch.float)

        while True:
            (
                episode_weighted_log_probs,
                action_advantage_sum,
                episode_logits,
                total_episode_reward,
                t,
            ) = self.play_episode()

            episode += 1
            total_rewards.append(total_episode_reward)
            epoch_weighted_log_probs = torch.cat(
                (epoch_weighted_log_probs, episode_weighted_log_probs), dim=0)
            epoch_action_advantage = torch.cat(
                (epoch_action_advantage, action_advantage_sum), dim=0)
            # Accumulate the episode logits so the entropy bonus is computed
            # over the whole epoch (otherwise epoch_logits stays empty).
            epoch_logits = torch.cat((epoch_logits, episode_logits), dim=0)

            if episode > BATCH_SIZE:

                episode = 0
                epoch += 1

                policy_loss, entropy = self.calculate_policy_loss(
                    epoch_logits=epoch_logits,
                    weighted_log_probs=epoch_weighted_log_probs,
                )
                value_loss = torch.square(epoch_action_advantage).mean()
                total_loss = policy_loss + VALUE_LOSS_CONSTANT * value_loss

                self.optimizer.zero_grad()
                self.share_grads()
                total_loss.backward()
                self.optimizer.step()

                self.proc_net.load_state_dict(self.global_net.state_dict())

                print(
                    f"{os.getpid()} Epoch: {epoch}, Avg Return per Epoch: {np.mean(total_rewards):.3f}"
                )
                sys.stdout.flush()

                # reset the epoch arrays, used for entropy calculation
                epoch_logits = torch.empty(size=(0, self.env.action_space.n))
                epoch_weighted_log_probs = torch.empty(size=(0, ),
                                                       dtype=torch.float)

                # check if solved
                if np.mean(total_rewards) > 200:
                    print("\nSolved!")
                    break

        self.env.close()
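
`TrainerProcess` expects a shared `global_net` and optimizer, but the excerpt does not show how workers are launched. A minimal launcher sketch, assuming `torch.multiprocessing` and using a plain `Adam` over shared parameters as a placeholder (A3C-style setups typically use a shared-state optimizer variant), could look like this:

import torch
import torch.multiprocessing as mp

def worker(global_net, global_opt):
    TrainerProcess(global_net, global_opt).train()

if __name__ == "__main__":
    global_net = ActorCriticNet(4, 2, training=True)
    global_net.share_memory()  # place parameters in shared memory for all workers
    global_opt = torch.optim.Adam(global_net.parameters(), lr=1e-3)
    procs = [mp.Process(target=worker, args=(global_net, global_opt)) for _ in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()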