Code example #1
 def evaluate(self, state, action):
     action_mean = torch.tanh(self.net(state))
     # keep the standard deviation away from zero when building the distribution
     action_std = torch.clamp(self.action_std, min=0.01)
     dist = Normal(action_mean, action_std)
     action_logprobs = torch.sum(dist.log_prob(action), dim=1)
     dist_entropy = torch.sum(dist.entropy(), dim=1)
     return action_logprobs, dist_entropy
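As a quick aside (a standalone sketch, not taken from any of the projects in this listing): PyTorch's Normal.entropy() is computed per dimension, so summing over the action dimension, as most of the snippets here do, reproduces the analytic joint entropy of a diagonal Gaussian, 0.5 * log(2πe σ_i²) summed over dimensions i.

import math
import torch
from torch.distributions import Normal

mean = torch.zeros(4, 3)            # batch of 4, 3 action dimensions
std = torch.full((4, 3), 0.5)
dist = Normal(mean, std)

summed = dist.entropy().sum(dim=1)  # the pattern used throughout these examples
analytic = 0.5 * torch.log(2 * math.pi * math.e * std ** 2).sum(dim=1)
print(torch.allclose(summed, analytic))  # True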
Code example #2
 def get_action(self, x, action=None):
     mean, logstd = self.forward(x)
     std = torch.exp(logstd)
     probs = Normal(mean, std)
     if action is None:
         action = probs.sample()
     return action, probs.log_prob(action).sum(1), probs.entropy().sum(1)
Code example #3
    def learn(self, s, a, td):
        s = torch.from_numpy(s[np.newaxis, :]).float()
        td_no_grad = td.detach()  # treat the TD error as a constant for the actor update
        hidden = self.l1(s)  # run the shared layer once instead of twice
        mu, sigma = torch.squeeze(self.mu(hidden)), torch.squeeze(
            self.sigma(hidden))
        normal_dist = Normal(mu * 2, sigma + 0.1)  # scaled mean, sigma kept strictly positive
        # action = torch.clamp(normal_dist.sample(), self.action_bound[0], self.action_bound[1])
        log_prob = normal_dist.log_prob(torch.from_numpy(a).float())
        self.exp_v = log_prob * td_no_grad
        self.exp_v += 0.01 * normal_dist.entropy()  # entropy bonus
        self.exp_v = -self.exp_v  # negate: the optimizer minimizes
        # NOTE: recreating Adam on every call discards its moment estimates;
        # the more common pattern builds it once in __init__.
        optimizer = optim.Adam([{
            'params': self.l1.parameters()
        }, {
            'params': self.sigma.parameters()
        }, {
            'params': self.mu.parameters()
        }],
                               lr=self.lr)

        # optimize the model
        optimizer.zero_grad()
        self.exp_v.backward()
        optimizer.step()
        return -self.exp_v
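Note that example #3 rebuilds the Adam optimizer on every learn() call, which throws away Adam's running moment estimates. Below is a minimal sketch of the more common pattern; the class and names (Actor, n_features, lr, update) are assumed for illustration and are not taken from that project.

import torch.nn as nn
import torch.optim as optim

class Actor(nn.Module):
    def __init__(self, n_features, lr=1e-3):
        super().__init__()
        self.l1 = nn.Linear(n_features, 32)
        self.mu = nn.Linear(32, 1)
        self.sigma = nn.Linear(32, 1)
        # build the optimizer once so its state persists across updates
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def update(self, loss):
        # one gradient step using the persistent optimizer
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()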
Code example #4
 def get_action(self, x, action=None):
     action_mean = self.actor_mean(x)
     action_logstd = self.actor_logstd.expand_as(action_mean)
     action_std = torch.exp(action_logstd)
     probs = Normal(action_mean, action_std)
     if action is None:
         action = probs.sample()
     return action, probs.log_prob(action).sum(1), probs.entropy().sum(1)
Code example #5
File: utils.py Project: TianhongDai/esil-hindsight
def evaluate_actions(pi, actions, dist_type):
    if dist_type == 'gauss':
        mean, std = pi
        normal_dist = Normal(mean, std)
        log_prob = normal_dist.log_prob(actions).sum(dim=1, keepdim=True)
        entropy = normal_dist.entropy().mean()
    else:
        raise NotImplementedError
    return log_prob, entropy
Code example #6
File: ppo_continuous_gae.py Project: perfmjs/cleanrl
 def get_action(self, x):
     mean, std = self.forward(x)
     probs = Normal(mean, std)
     action = probs.sample()
     # clamp the sampled action into the environment's action bounds
     clipped_action = torch.clamp(
         action, torch.min(torch.Tensor(env.action_space.low)),
         torch.max(torch.Tensor(env.action_space.high)))
     # return the clipped action, the negative log-probability of the raw
     # sample, and the entropy (both summed over action dimensions)
     return clipped_action, -probs.log_prob(action).sum(1), probs.entropy().sum(1)
Code example #7
 def get_action(self, x, action=None):
     mean, logstd = self.forward(x)
     std = torch.exp(logstd)
     probs = Normal(mean, std)
     if action is None:
         action = probs.sample()
     else:
         if not isinstance(action, torch.Tensor):
             action = preprocess_obs_fn(action)
     return action, probs.log_prob(action).sum(1), probs.entropy().sum(1)
Code example #8
    def forward(self, x, a=None):
        policy = Normal(self.mu(x), self.log_std.exp())
        pi = policy.sample()
        # note: per-dimension log-probs and entropies are averaged over the
        # action dimension here; summing would give the joint quantities
        # (the two differ only by a factor of 1/action_dim)
        logp_pi = policy.log_prob(pi).mean(dim=1)
        if a is not None:
            logp = policy.log_prob(a).mean(dim=1)
        else:
            logp = None
        entropy = policy.entropy().mean(dim=1)

        return pi, logp, logp_pi, entropy
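Example #8 reduces the per-dimension log-probabilities with mean(dim=1) rather than sum(dim=1) as the other snippets do; the two differ only by a constant factor of 1/action_dim, which rescales the loss and its gradients. A quick standalone check (not project code):

import torch
from torch.distributions import Normal

dist = Normal(torch.zeros(2, 4), torch.ones(2, 4))
a = dist.sample()
joint = dist.log_prob(a).sum(dim=1)      # joint log-probability over 4 dims
averaged = dist.log_prob(a).mean(dim=1)  # what example #8 returns
print(torch.allclose(averaged, joint / 4))  # True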
Code example #9
File: agent.py Project: code4meplz/reacher_env
    def choose_action(self, state):
        state = T.Tensor(state)

        self.actor.eval()

        mus, sigmas = self.actor(state)

        normal = Normal(mus, sigmas)
        self.entropy = normal.entropy()
        self.actions = normal.sample()
        self.log_prob = normal.log_prob(self.actions)

        return self.actions.numpy()
Code example #10
File: gaussian.py Project: BY571/pytorchrl
    def forward(self, x, deterministic=False):
        """
        Predict distribution parameters from x (obs features) and return
        predicted values (sampled and clipped), sampled log
        probability and distribution entropy.

        Parameters
        ----------
        x : torch.tensor
            Feature maps extracted from environment observations.
        deterministic : bool
            If True, return the distribution mode (the mean) instead of a random sample.

        Returns
        -------
        pred: torch.tensor
            Predicted value.
        clipped_pred: torch.tensor
            Predicted value (clipped to be within [-1, 1] range).
        logp : torch.tensor
            Log probability of `pred` according to the predicted distribution.
        entropy_dist : torch.tensor
            Entropy of the predicted distribution.
        """
        # Predict distribution parameters
        mean = self.mean(x)
        logstd = self.log_std(x) if self.predict_log_std else torch.zeros(
            mean.size()).to(x.device) + self.log_std

        # logstd = torch.clamp(logstd, LOG_STD_MIN, LOG_STD_MAX)

        # Create distribution and sample
        dist = Normal(mean, logstd.exp())

        if deterministic:
            pred = mean
        else:
            pred = dist.sample()

        # Apply clipping to avoid being outside output space
        clipped_pred = torch.clamp(pred, -1, 1)

        # Action log probability
        logp = dist.log_prob(pred).sum(-1, keepdim=True)

        # Distribution entropy
        entropy_dist = dist.entropy().sum(-1).mean()

        return pred, clipped_pred, logp, entropy_dist
Code example #11
File: utils.py Project: DAIM-2020/DAIM
def evaluate_actions(pi, actions, dist_type, env_type):
    if env_type == 'atari':
        cate_dist = Categorical(pi)
        log_prob = cate_dist.log_prob(actions).unsqueeze(-1)
        entropy = cate_dist.entropy().mean()
    else:
        if dist_type == 'gauss':
            mean, std = pi
            normal_dist = Normal(mean, std)
            log_prob = normal_dist.log_prob(actions).sum(dim=1, keepdim=True)
            entropy = normal_dist.entropy().mean()
        elif dist_type == 'beta':
            alpha, beta = pi
            beta_dist = Beta(alpha, beta)
            log_prob = beta_dist.log_prob(actions).sum(dim=1, keepdim=True)
            entropy = beta_dist.entropy().mean()
        else:
            raise NotImplementedError
    return log_prob, entropy
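The helpers in examples #5 and #11 return a per-sample joint log-probability of shape (batch, 1) plus a scalar mean entropy. A hedged sketch of how such a helper is typically consumed in a clipped-surrogate (PPO-style) update; the names old_log_prob, advantages, clip_eps, and ent_coef are illustrative assumptions, not identifiers from either project.

import torch

def clipped_surrogate(log_prob, old_log_prob, advantages, entropy,
                      clip_eps=0.2, ent_coef=0.01):
    # probability ratio between the current and the old policy
    ratio = torch.exp(log_prob - old_log_prob)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # maximize the clipped objective plus the entropy bonus -> minimize the negative
    return -torch.min(unclipped, clipped).mean() - ent_coef * entropy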
Code example #12
File: gaussian.py Project: BY571/pytorchrl
    def evaluate_pred(self, x, pred):
        """
        Return log prob of `pred` under the distribution generated from
        x (obs features). Also return entropy of the generated distribution.

        Parameters
        ----------
        x : torch.tensor
            obs feature map obtained from a policy_net.
        pred : torch.tensor
            Prediction to evaluate.

        Returns
        -------
        logp : torch.tensor
            Log probability of `pred` according to the predicted distribution.
        entropy_dist : torch.tensor
            Entropy of the predicted distribution.
        """

        # Predict distribution parameters

        mean = self.mean(x)
        logstd = self.log_std(x) if self.predict_log_std else torch.zeros(
            mean.size()).to(x.device) + self.log_std

        # logstd = torch.clamp(logstd, LOG_STD_MIN, LOG_STD_MAX)

        # Create distribution
        dist = Normal(mean, logstd.exp())

        # Evaluate log prob under dist
        logp = dist.log_prob(pred).sum(-1, keepdim=True)

        # Distribution entropy
        entropy_dist = dist.entropy().sum(-1).mean()

        return logp, entropy_dist
Code example #13
File: agent.py Project: Ashish017/CASNET
class Agent(nn.Module):
    def __init__(self, envs):
        super(Agent, self).__init__()

        self.envs = envs
        self.fc = nn.Sequential(
            nn.Linear((settings.max_links * p.link_dims) + 2,
                      p.fc1_dims),  #2 is the goal dimensions
            nn.Tanh(),
            nn.Linear(p.fc1_dims, p.fc2_dims),
            nn.Tanh(),
        )
        self.action_layer = nn.Linear(p.fc2_dims, settings.max_links)
        self.value_layer = nn.Linear(p.fc2_dims, 1)
        self.logstd = nn.Parameter(torch.zeros(settings.max_links))

    def init_weights(self):
        for m in self.fc:
            if hasattr(m, 'weight') or hasattr(m, 'bias'):
                for name, param in m.named_parameters():
                    if name == "weight":
                        nn.init.orthogonal_(
                            param, gain=nn.init.calculate_gain('tanh'))
                    if name == "bias":
                        nn.init.constant_(param, 0.0)

        for name, param in self.value_layer.named_parameters():
            if name == "weight":
                nn.init.orthogonal_(param, gain=0.01)
            if name == "bias":
                nn.init.constant_(param, 0.0)

        for name, param in self.action_layer.named_parameters():
            if name == "weight":
                nn.init.orthogonal_(param, gain=0.01)
            if name == "bias":
                nn.init.constant_(param, 0.0)

        nn.init.constant_(self.logstd, 0.0)

    def forward(self, obs, goals):
        batch_size = obs.shape[0]
        fc_input = torch.cat((obs, goals), 1)
        fc_out = self.fc(fc_input)

        mean_actions = self.action_layer(fc_out)
        logstd = self.logstd.view(1, self.logstd.shape[0])
        logstd = logstd.repeat(batch_size, 1)  # broadcast the shared log-std across the batch
        std = torch.exp(logstd)

        self.pd = Normal(mean_actions, std)
        self.v = self.value_layer(fc_out).view(-1)

    def step(self, obs, goals):
        with torch.no_grad():
            self.forward(obs, goals)
            act = self.pd.sample()
            neglogp = torch.sum(-self.pd.log_prob(act), dim=1)
        return act.cpu(), self.v.cpu(), neglogp.cpu()

    def statistics(self, obs, goals, actions):
        obs = obs.view(obs.shape[0] * obs.shape[1], obs.shape[2])
        goals = goals.view(goals.shape[0] * goals.shape[1], goals.shape[2])
        actions = actions.view(actions.shape[0] * actions.shape[1],
                               actions.shape[2])

        self.forward(obs, goals)

        neglogps = torch.sum(-self.pd.log_prob(actions), dim=1)
        entropies = torch.sum(self.pd.entropy(), dim=1)
        values = self.v

        neglogps = neglogps.view(p.batch_size,
                                 int(neglogps.shape[0] / p.batch_size))
        entropies = entropies.view(p.batch_size,
                                   int(entropies.shape[0] / p.batch_size))
        values = values.view(p.batch_size, int(values.shape[0] / p.batch_size))

        return neglogps, entropies, values
Code example #14
 def entropy(self, x):
     mean, log_std, std = self.forward(x)
     m = Normal(mean, std)
     return m.entropy().mean()
Code example #15
File: a2c_continous.py Project: etendue/move37
        L_critic = F.smooth_l1_loss(predicted_states_v, target_states_v)

        optimizer_critic.zero_grad()
        L_critic.backward()
        optimizer_critic.step()

        # actor loss
        mu, var = actor(states_t)
        m = Normal(mu, torch.sqrt(var))  # Normal expects the standard deviation, not the variance
        log_probs_t = m.log_prob(actions_t)
        advantages_t = (target_states_v - predicted_states_v).detach()
        J_actor = (advantages_t.unsqueeze(-1) * log_probs_t).mean()

        # entropy
        entropy = m.entropy().mean()
        L_actor = -J_actor - entropy * BETA

        optimizer_actor.zero_grad()
        L_actor.backward()
        optimizer_actor.step()

        # removing invalid experience
        while not t_queue.empty():
            t_queue.get()

        # start worker again
        working_event.set()

        # smooth update target
        for target_param, new_param in zip(critic_target.parameters(),
Code example #16
        global_step += 1
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        logits, std = pg.forward(obs[step:step + 1])
        values[step] = vf.forward(obs[step:step + 1])

        # ALGO LOGIC: `env.action_space` specific logic
        probs = Normal(logits, std)
        action = probs.sample()
        clipped_action = torch.clamp(
            action, torch.min(torch.Tensor(env.action_space.low)),
            torch.max(torch.Tensor(env.action_space.high)))
        actions[step] = clipped_action.tolist()[0]
        neglogprobs[step] = -probs.log_prob(action).sum()
        entropys[step] = probs.entropy().sum()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], _ = env.step(actions[step])
        next_obs = np.array(next_obs)
        if dones[step]:
            break

    # ALGO LOGIC: training.
    # calculate the discounted rewards, or namely, returns
    returns = np.zeros_like(rewards)
    returns[-1] = rewards[-1]  # final step: no successor return to bootstrap from
    for t in reversed(range(rewards.shape[0] - 1)):
        returns[t] = rewards[t] + args.gamma * returns[t + 1] * (1 - dones[t])
    # advantages are returns - baseline, value estimates in our case
    advantages = returns - values.detach().cpu().numpy()
Code example #17
class Agent(nn.Module):

	def __init__(self, env):
		super(Agent, self).__init__()

		self.fc = nn.Sequential(
									nn.Linear(env.observation_space.shape[0], 64),
									nn.Tanh(),
									nn.Linear(64, 64),
									nn.Tanh(),
								)
		self.action_layer = nn.Linear(64, env.action_space.shape[0])
		self.value_layer = nn.Linear(64, 1)
		self.logstd = nn.Parameter(torch.zeros(env.action_space.shape[0]))
		
		self.init_weights()

	def init_weights(self):
		for m in self.fc:
			if hasattr(m, 'weight') or hasattr(m, 'bias'):
				for name, param in m.named_parameters():
					if name == "weight":
						nn.init.orthogonal_(param, gain=nn.init.calculate_gain('tanh'))
					if name == "bias":
						nn.init.constant_(param, 0.0)

		for name, param in self.value_layer.named_parameters():
			if name == "weight":
				nn.init.orthogonal_(param, gain=0.01)
			if name == "bias":
				nn.init.constant_(param, 0.0)

		for name, param in self.action_layer.named_parameters():
			if name == "weight":
				nn.init.orthogonal_(param, gain=0.01)
			if name == "bias":
				nn.init.constant_(param, 0.0)

		nn.init.constant_(self.logstd, 0.0)

	def forward(self, obs):
		fc_out = self.fc(obs)
		self.mean_actions = self.action_layer(fc_out)
		log_std = self.logstd
		std = torch.exp(log_std)
		self.pd = Normal(self.mean_actions, std)
		self.v = self.value_layer(fc_out)

	def step(self, obs):
		with torch.no_grad():
			self.forward(obs)
			act = self.pd.sample()
			neglogp = torch.sum(-self.pd.log_prob(act)).view(1)
		return act, self.v, neglogp

	def statistics(self, obs, actions):
		self.forward(obs)
		neglogps = torch.sum(-self.pd.log_prob(actions), dim=1)
		entropies = torch.sum(self.pd.entropy(), dim=1)
		values = self.v
		neglogps = neglogps.view(obs.shape[0],1)
		entropies = torch.mean(entropies).view(1)

		return neglogps, entropies, values
Code example #18
class MlpAgent(nn.Module):
    def __init__(self, ob_space, ac_space):  #pylint: disable=W0613

        super(MlpAgent, self).__init__()
        self.initial_state = None
        self.output_shape = ac_space
        self.input_shape = ob_space

        self.pi_h = nn.Sequential(
            nn.Linear(ob_space, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
        )
        self.mean = nn.Linear(64, ac_space)
        self.logstd = nn.Parameter(torch.zeros(ac_space))

        self.vf_h = nn.Sequential(
            nn.Linear(ob_space, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
        )
        self.vf = nn.Linear(64, 1)

        self.init_()
        self.pd = None
        self.v = None

    def init_(self):
        for m in self.pi_h:
            if hasattr(m, 'weight') or hasattr(m, 'bias'):
                for name, param in m.named_parameters():
                    if name == 'weight':
                        nn.init.orthogonal_(
                            param, gain=nn.init.calculate_gain('relu'))
                    if name == 'bias':
                        nn.init.constant_(param, 0.0)

        for name, param in self.mean.named_parameters():
            if name == 'weight':
                nn.init.orthogonal_(param, gain=0.01)
            if name == 'bias':
                nn.init.constant_(param, 0.0)

        nn.init.constant_(self.logstd, 0.0)

        for m in self.vf_h:
            if hasattr(m, 'weight') or hasattr(m, 'bias'):
                for name, param in m.named_parameters():
                    if name == 'weight':
                        nn.init.orthogonal_(
                            param, gain=nn.init.calculate_gain('relu'))
                    if name == 'bias':
                        nn.init.constant_(param, 0.0)

        for name, param in self.vf.named_parameters():
            if name == 'weight':
                nn.init.orthogonal_(param, gain=1.0)
            if name == 'bias':
                nn.init.constant_(param, 0.0)

    def forward(self, obs):
        mean = self.mean(self.pi_h(obs))
        logstd = self.logstd
        std = torch.exp(logstd)
        self.pd = Normal(mean, std)

        vf = self.vf(self.vf_h(obs))
        # vf has shape (nenv, 1); summing over dim=1 just drops the trailing
        # dimension of size 1, for compatibility between nenv == 1 and nenv > 1.
        self.v = torch.sum(vf, dim=1)

    def statistics(self, obs, A):
        self.forward(obs)
        return torch.sum(-self.pd.log_prob(A),
                         dim=1), torch.sum(self.pd.entropy(), dim=1), self.v

    def step(self, obs):
        if isinstance(obs, np.ndarray):
            obs = torch.tensor(obs).float()
        with torch.no_grad():
            self.forward(obs)
            act = self.pd.sample()
            neglogp = torch.sum(-self.pd.log_prob(act), dim=1)
        return act.numpy(), self.v.numpy(), self.initial_state, neglogp.numpy()
Code example #19
        # ALGO LOGIC: put action logic here
        logits, std = pg.forward(obs[step:step + 1])
        values[step] = vf.forward(obs[step:step + 1])

        # ALGO LOGIC: `env.action_space` specific logic
        if isinstance(env.action_space, Box):
            probs = Normal(logits, std)
            action = probs.sample()

            # squash the sampled action into the valid range with tanh,
            # then rescale by the action-space bound
            action = torch.tanh(action)
            action *= env.action_space.high[0]
            #print(pg.fc1.weight.grad)

            # clipped_action = torch.clamp(action, torch.min(torch.Tensor(env.action_space.low)), torch.min(torch.Tensor(env.action_space.high)))
            actions[step] = action.tolist()[0]
            entropys[step] = probs.entropy().sum()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], _ = env.step(action.tolist()[0])
        next_obs = np.array(next_obs)
        er.add(obs[step], actions[step], rewards[step], next_obs, dones[step])
        if dones[step]:
            break

        # ALGO LOGIC: training.
        if len(er._storage) > 2000:
            s_obs, s_actions, s_rewards, s_next_obses, s_dones = er.sample(
                args.batch_size)
            # soft value loss
            # TODO: Importantly, we do not use actions from the replay buffer here:
            # these actions are sampled fresh from the current version of the policy.
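One caveat on example #19: it squashes the sampled action with tanh but records the entropy of the pre-squash Gaussian, which ignores the change of variables. When log-probabilities of squashed actions are needed (as in SAC-style updates), the usual correction subtracts the log-determinant of the tanh Jacobian. A hedged, standalone sketch of that correction; this is a general technique, not code from that project.

import torch
from torch.distributions import Normal

dist = Normal(torch.zeros(1, 2), torch.ones(1, 2))
raw_action = dist.sample()
squashed = torch.tanh(raw_action)

log_prob = dist.log_prob(raw_action).sum(dim=1)
# subtract the log-determinant of the tanh Jacobian (small epsilon for numerical stability)
log_prob -= torch.log(1.0 - squashed.pow(2) + 1e-6).sum(dim=1)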
Code example #20
class Agent(nn.Module):
    def __init__(self, envs):
        super(Agent, self).__init__()

        self.envs = envs
        self.seq_lens = envs.envs_seq_lens
        self.robot_encoder = nn.GRU(p.link_dims,
                                    p.encoded_robot_dims,
                                    batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(p.encoded_robot_dims + 2,
                      p.fc1_dims),  #2 is the goal dimensions
            nn.Tanh(),
            nn.Linear(p.fc1_dims, p.fc2_dims),
            nn.Tanh(),
        )
        self.action_decoder = nn.GRU(p.encoded_robot_dims + p.fc2_dims,
                                     1,
                                     batch_first=True)
        self.value_layer = nn.Linear(p.fc2_dims, 1)
        self.logstd = nn.Parameter(torch.zeros(settings.max_links))

    def init_weights(self):
        for m in self.fc:
            if hasattr(m, 'weight') or hasattr(m, 'bias'):
                for name, param in m.named_parameters():
                    if name == "weight":
                        nn.init.orthogonal_(
                            param, gain=nn.init.calculate_gain('tanh'))
                    if name == "bias":
                        nn.init.constant_(param, 0.0)

        for name, param in self.value_layer.named_parameters():
            if name == "weight":
                nn.init.orthogonal_(param, gain=0.01)
            if name == "bias":
                nn.init.constant_(param, 0.0)

        for name, param in self.robot_encoder.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0.0)
            elif "weight" in name:
                nn.init.xavier_normal_(param)

        for name, param in self.action_decoder.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0.0)
            elif "weight" in name:
                nn.init.xavier_normal_(param)

        nn.init.constant_(self.logstd, 0.0)

    def init_hidden_robot_encoder(self, batch_size):
        return torch.zeros(1, batch_size, p.encoded_robot_dims).cuda()

    def init_hidden_action_encoder(self, batch_size):
        return torch.zeros(1, batch_size, 1).cuda()

    def forward(self, obs, goals):
        goals = goals.cuda()
        obs = obs.view(obs.shape[0], settings.max_links,
                       int(obs.shape[1] / settings.max_links)).cuda()
        batch_size = obs.shape[0]
        seq_lens = self.seq_lens * int(batch_size / len(self.envs.envs))
        robot_encoder_hidden = self.init_hidden_robot_encoder(batch_size)
        obs = pack_padded_sequence(obs,
                                   seq_lens,
                                   batch_first=True,
                                   enforce_sorted=False)
        encoded_robots_states, encoded_robots = self.robot_encoder(
            obs, robot_encoder_hidden)
        encoded_robots_states, _ = pad_packed_sequence(encoded_robots_states,
                                                       batch_first=True)
        fc_input = torch.cat((goals,
                              encoded_robots.view(encoded_robots.shape[1],
                                                  encoded_robots.shape[2])),
                             dim=1)
        fc_out = self.fc(fc_input)
        fc_out_shaped = fc_out.view(fc_out.shape[0], 1, fc_out.shape[1])
        decoder_input = torch.cat(
            (encoded_robots_states,
             fc_out_shaped.repeat(1, encoded_robots_states.shape[1], 1)),
            dim=2)
        mean_actions, _ = self.action_decoder(decoder_input)
        mean_actions = mean_actions.view(mean_actions.shape[0],
                                         mean_actions.shape[1])

        logstd = self.logstd[:encoded_robots_states.shape[1]]
        logstd = logstd.view(1, logstd.shape[0])
        logstd = logstd.repeat(batch_size, 1)
        std = torch.exp(logstd)

        self.pd = Normal(mean_actions, std)
        self.v = self.value_layer(fc_out).view(-1)

    def step(self, obs, goals):
        with torch.no_grad():
            self.forward(obs, goals)
            act = self.pd.sample()
            neglogp = torch.sum(-self.pd.log_prob(act), dim=1)
        return act.cpu(), self.v.cpu(), neglogp.cpu()

    def statistics(self, obs, goals, actions):
        obs = obs.view(obs.shape[0] * obs.shape[1], obs.shape[2])
        goals = goals.view(goals.shape[0] * goals.shape[1], goals.shape[2])
        actions = actions.view(actions.shape[0] * actions.shape[1],
                               actions.shape[2])

        self.forward(obs, goals)

        neglogps = torch.sum(-self.pd.log_prob(actions), dim=1)
        entropies = torch.sum(self.pd.entropy(), dim=1)
        values = self.v

        neglogps = neglogps.view(p.batch_size,
                                 int(neglogps.shape[0] / p.batch_size))
        entropies = entropies.view(p.batch_size,
                                   int(entropies.shape[0] / p.batch_size))
        values = values.view(p.batch_size, int(values.shape[0] / p.batch_size))

        return neglogps, entropies, values
Code example #21
File: a2c_gae.py Project: aditimavalankar/rl_basic
def main():
    args = parser.parse_args()
    num_training_steps = args.train_steps
    lr = args.learning_rate
    gamma = args.discount_factor
    n_test_episodes = args.n_test_episodes
    checkpoint_file = args.resume
    test_only = args.test_only
    env_name = args.environment
    seed = args.seed
    batch_size = args.batch_size
    horizon = args.horizon
    lam = args.gae
    visualize = args.visualize
    entropy_coeff = args.entropy_coeff
    use_lr_decay = args.use_lr_decay

    env = gym.make(env_name)
    set_global_seed(seed)
    env.seed(seed)

    input_shape = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    net = Network(input_shape, action_dim).to(device)
    total_steps = 0
    total_episodes = 0

    optimizer = Adam(net.parameters(), lr=lr)
    adv_rms = RunningMeanStd(dim=1)
    return_rms = RunningMeanStd(dim=1)
    state_rms = RunningMeanStd(dim=input_shape)

    if checkpoint_file:
        (total_steps, total_episodes, net, optimizer, state_info, adv_info,
         return_info) = load_checkpoint(checkpoint_file, net, optimizer,
                                        'state', 'adv', 'return')
        state_mean, state_var, state_min, state_max = state_info
        adv_mean, adv_var, adv_min, adv_max = adv_info
        return_mean, return_var, return_min, return_max = return_info
        state_rms.set_state(state_mean, state_var, state_min, state_max,
                            total_steps)
        adv_rms.set_state(adv_mean, adv_var, adv_min, adv_max, total_steps)
        return_rms.set_state(return_mean, return_var, return_min, return_max,
                             total_steps)

    checkpoint_dir = os.path.join(env_name, 'a2c_checkpoints_lr2e-3-b32-decay')
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    if test_only:
        avg_reward = test(env, action_dim, net, state_rms, n_test_episodes,
                          visualize)
        print('Average episode reward:', avg_reward)
        return

    # Summary writer for tensorboardX
    writer = {}
    writer['writer'] = SummaryWriter()

    s = env.reset()

    reward_buf = []
    ep_reward = 0
    ep_len = 0
    niter = 0
    done = False

    mean_indices = torch.LongTensor([2 * x for x in range(action_dim)])
    logstd_indices = torch.LongTensor([2 * x + 1 for x in range(action_dim)])
    mean_indices = mean_indices.to(device)
    logstd_indices = logstd_indices.to(device)

    prev_best = 0

    total_epochs = int(num_training_steps / batch_size) + 1

    while total_steps < num_training_steps:
        values = []
        rewards = []
        dones = []
        logps = []
        entropies = []
        niter += 1
        for _ in range(batch_size):
            s = state_rms.normalize(s, mode=MEAN_STD)
            out, v = net(prepare_input(s))
            mean = torch.index_select(out, 0, mean_indices)
            logstd = torch.index_select(out, 0, logstd_indices)
            action_dist = Normal(mean, torch.exp(logstd))
            a = action_dist.sample()
            s, r, done, _ = env.step(a.cpu().numpy())
            logp = action_dist.log_prob(a)
            entropy = action_dist.entropy()
            ep_reward += r
            ep_len += 1
            total_steps += 1

            if done:
                writer['iter'] = total_steps + 1
                writer['writer'].add_scalar('data/ep_reward', ep_reward,
                                            total_steps)
                writer['writer'].add_scalar('data/ep_len', ep_len, total_steps)
                reward_buf.append(ep_reward)
                ep_reward = 0
                ep_len = 0
                total_episodes += 1
                if len(reward_buf) > 100:
                    reward_buf = reward_buf[-100:]
                done = False
                s = env.reset()

            values.append(v)
            rewards.append(r)
            dones.append(done)
            logps.append(logp)
            entropies.append(entropy.sum())

        policy_loss, value_loss = batch_actor_critic(logps, rewards, values,
                                                     dones, gamma, lam,
                                                     horizon, adv_rms,
                                                     return_rms)
        optimizer.zero_grad()
        policy_entropy = torch.stack(entropies).mean()
        loss = policy_loss + 0.5 * value_loss - entropy_coeff * policy_entropy
        loss.backward()
        optimizer.step()

        if use_lr_decay:
            for param_group in optimizer.param_groups:
                lr = param_group['lr']
                param_group['lr'] = (
                    lr - lr *
                    (total_steps / num_training_steps) / total_epochs)

        writer['iter'] = total_steps
        writer['writer'].add_scalar('data/last_100_ret',
                                    np.array(reward_buf).mean(), total_steps)
        writer['writer'].add_scalar('data/policy_loss', policy_loss,
                                    total_steps)
        writer['writer'].add_scalar('data/value_loss', value_loss, total_steps)
        writer['writer'].add_scalar('data/loss', loss, total_steps)

        print(total_episodes, 'episodes,', total_steps, 'steps,',
              np.array(reward_buf).mean(), 'reward')

        save_checkpoint(
            {
                'total_steps':
                total_steps,
                'total_episodes':
                total_episodes,
                'state_dict':
                net.state_dict(),
                'optimizer':
                optimizer.state_dict(),
                'state':
                [state_rms.mean, state_rms.var, state_rms.min, state_rms.max],
                'adv': [adv_rms.mean, adv_rms.var, adv_rms.min, adv_rms.max],
                'return': [
                    return_rms.mean, return_rms.var, return_rms.min,
                    return_rms.max
                ]
            },
            filename=os.path.join(checkpoint_dir,
                                  str(niter) + '.pth.tar'))

        if np.array(reward_buf).mean() > prev_best:
            save_checkpoint(
                {
                    'total_steps': total_steps,
                    'total_episodes': total_episodes,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                filename=os.path.join(checkpoint_dir, 'best.pth.tar'))