import numpy as np
import torch


def surrogate_loss(actor, advants, states, old_policy, actions, batch_index):
    # Log-density of the sampled actions under the current policy.
    mu, std = actor(states)
    new_policy = log_prob_density(actions, mu, std)
    old_policy = old_policy[batch_index]

    # Importance-sampling ratio pi_new(a|s) / pi_old(a|s) and the
    # unclipped surrogate objective ratio * advantage.
    ratio = torch.exp(new_policy - old_policy)
    surrogate_loss = ratio * advants

    # Mean entropy of the diagonal Gaussian policy, returned so that
    # train_actor_critic can add an entropy bonus to its loss.
    entropy = torch.distributions.Normal(mu, std).entropy().mean()

    return surrogate_loss, ratio, entropy
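
# `log_prob_density` is defined elsewhere in the project and is not part of
# this section. A minimal sketch, assuming a diagonal Gaussian policy whose
# per-dimension log-densities are summed so each sample keeps one
# log-probability (shape [batch, 1]), could look like this:
def log_prob_density(x, mu, std):
    # Sum the per-dimension Gaussian log-densities over the action dimension,
    # keeping the batch dimension so the ratio in surrogate_loss is per-sample.
    return torch.distributions.Normal(mu, std).log_prob(x).sum(1, keepdim=True)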

def train_actor_critic(actor, critic, memory, actor_optim, critic_optim, args):
    # Unpack the rollout buffer: each entry is (state, action, reward, mask).
    memory = np.array(memory, dtype=object)
    states = np.vstack(memory[:, 0])
    actions = list(memory[:, 1])
    rewards = list(memory[:, 2])
    masks = list(memory[:, 3])

    # Value estimates and GAE targets computed with the pre-update networks.
    old_values = critic(torch.Tensor(states))
    returns, advants = get_gae(rewards, masks, old_values, args)

    # Log-probabilities of the collected actions under the pre-update policy.
    mu, std = actor(torch.Tensor(states))
    old_policy = log_prob_density(torch.Tensor(actions), mu, std)

    criterion = torch.nn.MSELoss()
    n = len(states)
    arr = np.arange(n)

    for _ in range(args.actor_critic_update_num):
        np.random.shuffle(arr)

        for i in range(n // args.batch_size):
            batch_index = arr[args.batch_size * i : args.batch_size * (i + 1)]
            batch_index = torch.LongTensor(batch_index)

            inputs = torch.Tensor(states)[batch_index]
            actions_samples = torch.Tensor(actions)[batch_index]
            returns_samples = returns.unsqueeze(1)[batch_index]
            advants_samples = advants.unsqueeze(1)[batch_index]
            oldvalue_samples = old_values[batch_index].detach()

            # Clipped value loss: penalize whichever of the clipped or
            # unclipped value prediction is worse.
            values = critic(inputs)
            clipped_values = oldvalue_samples + \
                             torch.clamp(values - oldvalue_samples,
                                         -args.clip_param,
                                         args.clip_param)
            critic_loss1 = criterion(clipped_values, returns_samples)
            critic_loss2 = criterion(values, returns_samples)
            critic_loss = torch.max(critic_loss1, critic_loss2).mean()

            # PPO clipped surrogate objective with an entropy bonus.
            loss, ratio, entropy = surrogate_loss(actor, advants_samples, inputs,
                                                  old_policy.detach(),
                                                  actions_samples, batch_index)
            clipped_ratio = torch.clamp(ratio,
                                        1.0 - args.clip_param,
                                        1.0 + args.clip_param)
            clipped_loss = clipped_ratio * advants_samples
            actor_loss = -torch.min(loss, clipped_loss).mean()

            loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy

            # One backward pass covers both networks (the actor and critic
            # losses touch disjoint parameters), then each optimizer steps;
            # this also avoids backpropagating a second time through a graph
            # whose critic weights were already updated in place.
            actor_optim.zero_grad()
            critic_optim.zero_grad()
            loss.backward()
            critic_optim.step()
            actor_optim.step()
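
# `get_gae` is also defined outside this section. The sketch below assumes
# standard Generalized Advantage Estimation with discount `args.gamma` and
# GAE parameter `args.lamda` (attribute names assumed), where `masks[t]` is
# 0 at episode boundaries and 1 otherwise, and advantages are normalized
# before use:
def get_gae(rewards, masks, values, args):
    rewards = torch.Tensor(rewards)
    masks = torch.Tensor(masks)
    returns = torch.zeros_like(rewards)
    advants = torch.zeros_like(rewards)

    running_returns = 0
    previous_value = 0
    running_advants = 0

    # Walk the rollout backwards, resetting at terminal steps (mask == 0).
    for t in reversed(range(len(rewards))):
        running_returns = rewards[t] + args.gamma * running_returns * masks[t]
        running_tderror = rewards[t] + args.gamma * previous_value * masks[t] \
                          - values.data[t]
        running_advants = running_tderror + args.gamma * args.lamda * \
                          running_advants * masks[t]

        returns[t] = running_returns
        previous_value = values.data[t]
        advants[t] = running_advants

    # Normalizing advantages is a common stabilizer; assumed here.
    advants = (advants - advants.mean()) / advants.std()
    return returns, advants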

def train_model(actor, critic, memory, actor_optim, critic_optim, args):
    # Same PPO update as train_actor_critic, but without the entropy bonus
    # and driven by args.model_update_num epochs.
    memory = np.array(memory, dtype=object)
    states = np.vstack(memory[:, 0])
    actions = list(memory[:, 1])
    rewards = list(memory[:, 2])
    masks = list(memory[:, 3])

    old_values = critic(torch.Tensor(states))
    returns, advants = get_gae(rewards, masks, old_values, args)

    mu, std = actor(torch.Tensor(states))
    old_policy = log_prob_density(torch.Tensor(actions), mu, std)

    criterion = torch.nn.MSELoss()
    n = len(states)
    arr = np.arange(n)

    for _ in range(args.model_update_num):
        np.random.shuffle(arr)

        for i in range(n // args.batch_size):
            batch_index = arr[args.batch_size * i : args.batch_size * (i + 1)]
            batch_index = torch.LongTensor(batch_index)

            inputs = torch.Tensor(states)[batch_index]
            actions_samples = torch.Tensor(actions)[batch_index]
            returns_samples = returns.unsqueeze(1)[batch_index]
            advants_samples = advants.unsqueeze(1)[batch_index]
            oldvalue_samples = old_values[batch_index].detach()

            # Clipped value loss (same form as in train_actor_critic).
            values = critic(inputs)
            clipped_values = oldvalue_samples + \
                             torch.clamp(values - oldvalue_samples,
                                         -args.clip_param,
                                         args.clip_param)
            critic_loss1 = criterion(clipped_values, returns_samples)
            critic_loss2 = criterion(values, returns_samples)
            critic_loss = torch.max(critic_loss1, critic_loss2).mean()

            # PPO clipped surrogate objective; the entropy term is unused here.
            loss, ratio, _ = surrogate_loss(actor, advants_samples, inputs,
                                            old_policy.detach(),
                                            actions_samples, batch_index)
            clipped_ratio = torch.clamp(ratio,
                                        1.0 - args.clip_param,
                                        1.0 + args.clip_param)
            clipped_loss = clipped_ratio * advants_samples
            actor_loss = -torch.min(loss, clipped_loss).mean()

            loss = actor_loss + 0.5 * critic_loss

            # One backward pass covers both networks (their losses touch
            # disjoint parameters), then each optimizer takes its step.
            actor_optim.zero_grad()
            critic_optim.zero_grad()
            loss.backward()
            critic_optim.step()
            actor_optim.step()
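
# Illustrative smoke test only: the tiny Actor/Critic modules, the random
# rollout data, and every hyperparameter value below are assumptions made for
# demonstration; they are not part of the original code.
if __name__ == "__main__":
    from types import SimpleNamespace
    import torch.nn as nn
    import torch.optim as optim

    state_dim, action_dim, steps = 4, 2, 256

    class Actor(nn.Module):
        def __init__(self):
            super().__init__()
            self.mu_head = nn.Linear(state_dim, action_dim)
            self.log_std = nn.Parameter(torch.zeros(action_dim))

        def forward(self, x):
            # Return the (mu, std) pair that surrogate_loss expects.
            mu = self.mu_head(x)
            std = self.log_std.exp().expand_as(mu)
            return mu, std

    actor = Actor()
    critic = nn.Linear(state_dim, 1)
    actor_optim = optim.Adam(actor.parameters(), lr=3e-4)
    critic_optim = optim.Adam(critic.parameters(), lr=3e-4)

    args = SimpleNamespace(gamma=0.99, lamda=0.98, clip_param=0.2,
                           batch_size=64, actor_critic_update_num=3,
                           model_update_num=3)

    # Fake rollout: (state, action, reward, mask) per step, mask = 0 at the end.
    memory = [(np.random.randn(state_dim).astype(np.float32),
               np.random.randn(action_dim).astype(np.float32),
               float(np.random.randn()),
               0.0 if t == steps - 1 else 1.0)
              for t in range(steps)]

    train_actor_critic(actor, critic, memory, actor_optim, critic_optim, args)
    train_model(actor, critic, memory, actor_optim, critic_optim, args)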