class SAC:
    """
    A class used to represent a SAC agent

    Attributes
    ----------
    device : cuda or cpu
        the device on which all the computation occurs
    gamma : float[0,1]
        discount factor
    state_dim : int
        dimension of the environment observation space
    action_dim : int
        dimension of the environment action space
    hidden_dim : int
        dimension of the hidden layers of the networks
    tau : float[0,1]
        coefficient of soft update of target networks
    lr : float
        learning rate of the optimizers
    target_update_interval : int
        number of updates in between soft updates of target networks
    q_net_1 : QNetwork
        soft Q value network 1
    q_net_2 : QNetwork
        soft Q value network 2
    target_q_net_1 : QNetwork
        target Q value network 1
    target_q_net_2 : QNetwork
        target Q value network 2
    policy_net : PolicyNetwork
        policy network
    q1_criterion : torch
        optimization criterion for q_net_1
    q2_criterion : torch
        optimization criterion for q_net_2
    q1_optim : torch
        optimizer for q_net_1
    q2_optim : torch
        optimizer for q_net_2
    policy_optim : torch
        optimizer for policy_net
    alpha : torch float scalar
        entropy temperature (controls policy stochasticity)
    entropy_target : torch float scalar
        entropy target for the environment (see Haarnoja et al. Section 5)

    Methods
    -------
    update(replay_buffer, batch_size, updates) : q1_loss, q2_loss, policy_loss, alpha_loss
        Performs a gradient step of the algorithm, optimizing Q networks and policy network and optimizing alpha
    choose_action(state) : action
        Returns the appropriate action in given state according to current policy
    save_networks_parameters(params_dir)
        Saves the relevant parameters (q1_net's, q2_net's, policy_net's, alpha) from the networks
    load_networks_parameters(params_dir)
        Loads the relevant parameters (q1_net's, q2_net's, policy_net's, alpha) into the networks
    """

    def __init__(self, observation_space, action_space, args):
        """
        Constructor

        :param observation_space: observation space of the environment
        :param action_space: action space of the environment
        :param args: command line args to set hyperparameters
        """
        # set hyperparameters
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = args.gamma
        self.state_dim = observation_space.shape[0]
        self.action_dim = action_space.shape[0]
        self.hidden_dim = args.hidden_units
        self.tau = args.tau
        self.lr = args.lr
        self.target_update_interval = args.target_update_interval

        # build and initialize networks
        self.q_net_1 = QNetwork(self.state_dim, self.action_dim,
                                self.hidden_dim).to(self.device)
        self.q_net_2 = QNetwork(self.state_dim, self.action_dim,
                                self.hidden_dim).to(self.device)
        self.target_q_net_1 = QNetwork(self.state_dim, self.action_dim,
                                       self.hidden_dim).to(self.device)
        self.target_q_net_2 = QNetwork(self.state_dim, self.action_dim,
                                       self.hidden_dim).to(self.device)
        hard_update(self.q_net_1, self.target_q_net_1)
        hard_update(self.q_net_2, self.target_q_net_2)
        self.policy_net = PolicyNetwork(self.state_dim, self.action_dim,
                                        self.hidden_dim,
                                        self.device).to(self.device)

        # build criterions and optimizers
        self.q1_criterion = nn.MSELoss()
        self.q2_criterion = nn.MSELoss()
        self.q1_optim = optim.Adam(self.q_net_1.parameters(), lr=self.lr)
        self.q2_optim = optim.Adam(self.q_net_2.parameters(), lr=self.lr)
        self.policy_optim = optim.Adam(self.policy_net.parameters(), lr=self.lr)

        # for optimizing alpha (see Haarnoja et al. Section 5)
        if args.initial_alpha is not None:
            self.alpha = torch.tensor(args.initial_alpha,
                                      requires_grad=True,
                                      device=self.device,
                                      dtype=torch.float)
        else:
            self.alpha = torch.rand(1,
                                    requires_grad=True,
                                    device=self.device,
                                    dtype=torch.float)
        if args.entropy_target is not None:
            self.entropy_target = torch.tensor(args.entropy_target,
                                               device=self.device,
                                               dtype=torch.float)
        else:
            self.entropy_target = -1. * torch.tensor(
                action_space.shape, device=self.device, dtype=torch.float)
        self.alpha_optim = optim.Adam([self.alpha], lr=self.lr)

    def update(self, replay_buffer, batch_size, updates):
        """
        Performs a gradient step of the algorithm, optimizing Q networks and policy network and optimizing alpha

        :param replay_buffer: replay buffer to sample batches of transitions from
        :param batch_size: size of the batches
        :param updates: number of updates so far
        :return: losses of the four optimizers (q1_optim, q2_optim, policy_optim, alpha_optim)
        :rtype: tuple of torch scalar floats
        """
        # sample a transition batch from the replay buffer and cast it to tensors of the correct shape
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = replay_buffer.sample(
            batch_size)
        state_batch = torch.from_numpy(state_batch).to(self.device,
                                                       dtype=torch.float)
        next_state_batch = torch.from_numpy(next_state_batch).to(
            self.device, dtype=torch.float)
        action_batch = torch.from_numpy(action_batch).to(self.device,
                                                         dtype=torch.float)
        reward_batch = torch.from_numpy(reward_batch).unsqueeze(1).to(
            self.device, dtype=torch.float)
        done_batch = torch.from_numpy(np.float32(done_batch)).unsqueeze(1).to(
            self.device, dtype=torch.float)

        # sample actions from the policy to be used for expectation updates
        sampled_action, log_prob, epsilon, mean, log_std = self.policy_net.sample(
            state_batch)

        ### evaluation step
        target_next_value = torch.min(
            self.target_q_net_1(next_state_batch, sampled_action),
            self.target_q_net_2(next_state_batch,
                                sampled_action)) - self.alpha * log_prob
        current_q_value_1 = self.q_net_1(state_batch, action_batch)
        current_q_value_2 = self.q_net_2(state_batch, action_batch)
        expected_next_value = reward_batch + (
            1 - done_batch) * self.gamma * target_next_value
        q1_loss = self.q1_criterion(current_q_value_1,
                                    expected_next_value.detach())
        q2_loss = self.q2_criterion(current_q_value_2,
                                    expected_next_value.detach())

        # optimize q1 and q2 nets
        self.q1_optim.zero_grad()
        q1_loss.backward()
        self.q1_optim.step()
        self.q2_optim.zero_grad()
        q2_loss.backward()
        self.q2_optim.step()

        ### improvement step
        sampled_q_value = torch.min(self.q_net_1(state_batch, sampled_action),
                                    self.q_net_2(state_batch, sampled_action))
        policy_loss = (self.alpha * log_prob - sampled_q_value).mean()

        # optimize policy net
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        # optimize alpha
        alpha_loss = (self.alpha *
                      (-log_prob - self.entropy_target).detach()).mean()
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()

        # update target Q networks
        if updates % self.target_update_interval == 0:
            soft_update(self.q_net_1, self.target_q_net_1, self.tau)
            soft_update(self.q_net_2, self.target_q_net_2, self.tau)

        return q1_loss.item(), q2_loss.item(), policy_loss.item(
        ), alpha_loss.item()

    def choose_action(self, state):
        """
        Returns the appropriate action in given state according to current policy

        :param state: state
        :return: action
        :rtype numpy float array
        """
        action = self.policy_net.get_action(state)
        # move to cpu, remove from gradient graph, cast to numpy
        return action.cpu().detach().numpy()

    def save_networks_parameters(self, params_dir=None):
        """
        Saves the relevant parameters (q1_net's, q2_net's, policy_net's, alpha) from the networks

        :param params_dir: directory where to save parameters to (optional)
        :return: the directory the parameters were saved to
        """
        if params_dir is None:
            params_dir = "SavedAgents/"

        # create a subfolder with current timestamp
        prefix = os.path.join(
            params_dir,
            datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
        if not os.path.exists(prefix):
            os.makedirs(prefix)

        policy_path = os.path.join(prefix, "policy_net_params")
        q1_path = os.path.join(prefix, "q1_net_params")
        q2_path = os.path.join(prefix, "q2_net_params")
        alpha_path = os.path.join(prefix, "alpha_param")

        print("Saving parameters to {}, {}, {}".format(q1_path, q2_path,
                                                       policy_path))
        torch.save(self.q_net_1.state_dict(), q1_path)
        torch.save(self.q_net_2.state_dict(), q2_path)
        torch.save(self.policy_net.state_dict(), policy_path)
        torch.save(self.alpha, alpha_path)

        return params_dir

    def load_networks_parameters(self, params_dir):
        """
        Loads the relevant parameters (q1_net's, q2_net's, policy_net's, alpha) into the networks

        :param params_dir: directory where to load parameters from
        :return: None
        """
        if params_dir is not None:
            print("Loading parameters from {}".format(params_dir))

            policy_path = os.path.join(params_dir, "policy_net_params")
            self.policy_net.load_state_dict(torch.load(policy_path))

            q1_path = os.path.join(params_dir, "q1_net_params")
            q2_path = os.path.join(params_dir, "q2_net_params")
            self.q_net_1.load_state_dict(torch.load(q1_path))
            self.q_net_2.load_state_dict(torch.load(q2_path))

            alpha_path = os.path.join(params_dir, "alpha_param")
            self.alpha = torch.load(alpha_path)
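
# Hedged sketch: hard_update and soft_update are called by the SAC class above but
# defined elsewhere. Assuming the (source, target) argument order used at those call
# sites, they are presumably the usual copy / Polyak-averaging helpers sketched below.
# Note that the SAC_Agent snippet further down calls soft_update with the target
# network first, i.e. the reverse order.
def hard_update(source, target):
    # target <- source (exact copy, used once at initialization)
    for s_param, t_param in zip(source.parameters(), target.parameters()):
        t_param.data.copy_(s_param.data)


def soft_update(source, target, tau):
    # target <- tau * source + (1 - tau) * target (Polyak averaging)
    for s_param, t_param in zip(source.parameters(), target.parameters()):
        t_param.data.copy_(tau * s_param.data + (1.0 - tau) * t_param.data)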
env.action_space.seed(args.seed)
env.observation_space.seed(args.seed)

inputShape = env.observation_space.shape[0]
outputShape = env.action_space.shape[0]
assert isinstance(env.action_space, gym.spaces.Box), "only continuous action space is supported"

svnMain = StateValueNetwork(sess, inputShape, args.learning_rate_state_value,
                            suffix="-main")  # params are psi
svnAux = StateValueNetwork(sess, inputShape, args.learning_rate_state_value,
                           suffix="-aux")  # params are psi hat
pn = PolicyNetwork(sess, inputShape, outputShape, args.learning_rate_policy,
                   args.squash)  # params are phi
sqf1 = SoftQNetwork(sess, inputShape, outputShape,
                    args.learning_rate_soft_Q_function, args.alpha,
                    suffix="1")  # params are theta
sqf2 = SoftQNetwork(sess, inputShape, outputShape,
                    args.learning_rate_soft_Q_function, args.alpha,
                    suffix="2")  # params are theta

# tf.global_variables_initializer() / tf.local_variables_initializer() in later TF1 releases
init = tf.initialize_all_variables()
init2 = tf.initialize_local_variables()
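
# Hedged sketch: in the original SAC formulation the auxiliary state-value network
# ("psi hat", svnAux above) tracks the main value network ("psi", svnMain) by Polyak
# averaging. A TF1-style update op could be built roughly as follows, assuming both
# networks expose their variables through a get_params() helper (hypothetical name)
# and that tau is the soft-update coefficient.
def make_value_target_update_op(main_params, aux_params, tau=0.005):
    # aux <- tau * main + (1 - tau) * aux
    return tf.group(*[
        tf.assign(aux_var, tau * main_var + (1.0 - tau) * aux_var)
        for main_var, aux_var in zip(main_params, aux_params)
    ])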
class SAC_Agent:
    def __init__(self, env, batch_size=256, gamma=0.99, tau=0.005,
                 actor_lr=3e-4, critic_lr=3e-4, alpha_lr=3e-4):
        # Environment
        self.env = env
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]

        # Hyperparameters
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau

        # Entropy
        self.alpha = 1
        self.target_entropy = -np.prod(env.action_space.shape).item()  # heuristic value
        self.log_alpha = torch.zeros(1, requires_grad=True, device="cuda")
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)

        # Networks
        self.Q1 = SoftQNetwork(state_dim, action_dim).cuda()
        self.Q1_target = SoftQNetwork(state_dim, action_dim).cuda()
        self.Q1_target.load_state_dict(self.Q1.state_dict())
        self.Q1_optimizer = optim.Adam(self.Q1.parameters(), lr=critic_lr)

        self.Q2 = SoftQNetwork(state_dim, action_dim).cuda()
        self.Q2_target = SoftQNetwork(state_dim, action_dim).cuda()
        self.Q2_target.load_state_dict(self.Q2.state_dict())
        self.Q2_optimizer = optim.Adam(self.Q2.parameters(), lr=critic_lr)

        self.actor = PolicyNetwork(state_dim, action_dim).cuda()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)

        self.loss_function = torch.nn.MSELoss()
        self.replay_buffer = ReplayBuffer()

    def act(self, state, deterministic=True):
        state = torch.tensor(state, dtype=torch.float, device="cuda")
        mean, log_std = self.actor(state)
        if deterministic:
            action = torch.tanh(mean)
        else:
            std = log_std.exp()
            normal = Normal(mean, std)
            z = normal.sample()
            action = torch.tanh(z)
        action = action.detach().cpu().numpy()
        return action

    def update(self, state, action, next_state, reward, done):
        self.replay_buffer.add_transition(state, action, next_state, reward, done)

        # Sample next batch and perform batch update:
        batch_states, batch_actions, batch_next_states, batch_rewards, batch_dones = \
            self.replay_buffer.next_batch(self.batch_size)

        # Map to tensors
        batch_states = torch.tensor(batch_states, dtype=torch.float, device="cuda")  # B,S_D
        batch_actions = torch.tensor(batch_actions, dtype=torch.float, device="cuda")  # B,A_D
        batch_next_states = torch.tensor(batch_next_states, dtype=torch.float,
                                         device="cuda", requires_grad=False)  # B,S_D
        batch_rewards = torch.tensor(batch_rewards, dtype=torch.float, device="cuda",
                                     requires_grad=False).unsqueeze(-1)  # B,1
        batch_dones = torch.tensor(batch_dones, dtype=torch.uint8, device="cuda",
                                   requires_grad=False).unsqueeze(-1)  # B,1

        # Policy evaluation
        with torch.no_grad():
            policy_actions, log_pi = self.actor.sample(batch_next_states)
            Q1_next_target = self.Q1_target(batch_next_states, policy_actions)
            Q2_next_target = self.Q2_target(batch_next_states, policy_actions)
            Q_next_target = torch.min(Q1_next_target, Q2_next_target)
            td_target = batch_rewards + (1 - batch_dones) * self.gamma * \
                (Q_next_target - self.alpha * log_pi)

        Q1_value = self.Q1(batch_states, batch_actions)
        self.Q1_optimizer.zero_grad()
        loss = self.loss_function(Q1_value, td_target)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.Q1.parameters(), 1)
        self.Q1_optimizer.step()

        Q2_value = self.Q2(batch_states, batch_actions)
        self.Q2_optimizer.zero_grad()
        loss = self.loss_function(Q2_value, td_target)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.Q2.parameters(), 1)
        self.Q2_optimizer.step()

        # Policy improvement
        policy_actions, log_pi = self.actor.sample(batch_states)
        Q1_value = self.Q1(batch_states, policy_actions)
        Q2_value = self.Q2(batch_states, policy_actions)
        Q_value = torch.min(Q1_value, Q2_value)
        self.actor_optimizer.zero_grad()
        loss = (self.alpha * log_pi - Q_value).mean()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1)
        self.actor_optimizer.step()

        # Update entropy parameter
        alpha_loss = (self.log_alpha * (-log_pi - self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.exp()

        # Update target networks
        soft_update(self.Q1_target, self.Q1, self.tau)
        soft_update(self.Q2_target, self.Q2, self.tau)

    def save(self, file_name):
        torch.save({'actor_dict': self.actor.state_dict(),
                    'Q1_dict': self.Q1.state_dict(),
                    'Q2_dict': self.Q2.state_dict(),
                    }, file_name)

    def load(self, file_name):
        if os.path.isfile(file_name):
            print("=> loading checkpoint... ")
            checkpoint = torch.load(file_name)
            self.actor.load_state_dict(checkpoint['actor_dict'])
            self.Q1.load_state_dict(checkpoint['Q1_dict'])
            self.Q2.load_state_dict(checkpoint['Q2_dict'])
            print("done !")
        else:
            print("no checkpoint found...")
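
# Hedged usage sketch: a minimal interaction loop driving SAC_Agent on a Gym-style
# continuous-control environment. The environment id is a placeholder, and in practice
# one would wait until the replay buffer holds at least batch_size transitions before
# calling update(); the sketch leaves that warm-up out for brevity.
import gym

def example_training_loop(env_name: str = "Pendulum-v1", total_steps: int = 100_000):
    env = gym.make(env_name)                      # placeholder environment
    agent = SAC_Agent(env)
    state = env.reset()
    for step in range(total_steps):
        action = agent.act(state, deterministic=False)   # stochastic actions while training
        next_state, reward, done, _ = env.step(action)
        agent.update(state, action, next_state, reward, done)
        state = env.reset() if done else next_state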
class DreamerV2Agent:
    def __init__(self, env_id: str, config: Config,
                 pid: int = None, epsilon: float = 0.,
                 summary_writer: tf.summary.SummaryWriter = None):

        self.env_id = env_id
        self.config = config
        self.pid = pid
        self.epsilon = epsilon
        self.summary_writer = summary_writer

        self.action_space = gym.make(self.env_id).action_space.n
        self.preprocess_func = util.get_preprocess_func(env_name=self.env_id)

        self.buffer = EpisodeBuffer(seqlen=self.config.sequence_length)

        self.world_model = WorldModel(config)
        self.wm_optimizer = tf.keras.optimizers.Adam(lr=self.config.lr_world,
                                                     epsilon=1e-4)

        self.policy = PolicyNetwork(action_space=self.action_space)
        self.policy_optimizer = tf.keras.optimizers.Adam(
            lr=self.config.lr_actor, epsilon=1e-5)

        self.value = ValueNetwork(action_space=self.action_space)
        self.target_value = ValueNetwork(action_space=self.action_space)
        self.value_optimizer = tf.keras.optimizers.Adam(
            lr=self.config.lr_critic, epsilon=1e-5)

        self.setup()

    def setup(self):
        """ Build network weights """
        env = gym.make(self.env_id)
        obs = self.preprocess_func(env.reset())
        prev_z, prev_h = self.world_model.get_initial_state(batch_size=1)
        prev_a = tf.one_hot([0], self.action_space)
        _outputs = self.world_model(obs, prev_z, prev_h, prev_a)
        (h, z_prior, z_prior_prob, z_post, z_post_prob,
         feat, img_out, reward_pred, disc_logit) = _outputs
        self.policy(feat)
        self.value(feat)
        self.target_value(feat)
        self.target_value.set_weights(self.value.get_weights())

    def save(self, savedir=None):
        savedir = Path(savedir) if savedir is not None else Path("./checkpoints")
        self.world_model.save_weights(str(savedir / "worldmodel"))
        self.policy.save_weights(str(savedir / "policy"))
        self.value.save_weights(str(savedir / "critic"))

    def load(self, loaddir=None):
        loaddir = Path(loaddir) if loaddir is not None else Path("checkpoints")
        self.world_model.load_weights(str(loaddir / "worldmodel"))
        self.policy.load_weights(str(loaddir / "policy"))
        self.value.load_weights(str(loaddir / "critic"))
        self.target_value.load_weights(str(loaddir / "critic"))

    def set_weights(self, weights):
        wm_weights, policy_weights, value_weights = weights
        self.world_model.set_weights(wm_weights)
        self.policy.set_weights(policy_weights)
        self.value.set_weights(value_weights)
        self.target_value.set_weights(value_weights)

    def get_weights(self):
        weights = (
            self.world_model.get_weights(),
            self.policy.get_weights(),
            self.value.get_weights(),
        )
        return weights

    def rollout(self, weights=None):
        if weights:
            self.set_weights(weights)

        env = gym.make(self.env_id)
        obs = self.preprocess_func(env.reset())

        episode_steps, episode_rewards = 0, 0
        prev_z, prev_h = self.world_model.get_initial_state(batch_size=1)
        prev_a = tf.convert_to_tensor([[0] * self.action_space], dtype=tf.float32)

        done = False
        lives = int(env.ale.lives())

        while not done:

            h = self.world_model.step_h(prev_z, prev_h, prev_a)
            feat, z = self.world_model.get_feature(obs, h)
            action = self.policy.sample_action(feat, self.epsilon)
            action_onehot = tf.one_hot([action], self.action_space)

            next_frame, reward, done, info = env.step(action)
            next_obs = self.preprocess_func(next_frame)

            #: Note: the DreamerV2 paper uses tanh clipping
            _reward = reward if reward <= 1.0 else 1.0

            #: Life loss as episode end
            if info["ale.lives"] != lives:
                _done = True
                lives = int(info["ale.lives"])
            else:
                _done = done

            #: (r_t-1, done_t-1, obs_t, action_t, done_t)
            self.buffer.add(obs, action_onehot, _reward, next_obs, _done,
                            prev_z, prev_h, prev_a)

            #: Update states
            obs = next_obs
            prev_z, prev_h, prev_a = z, h, action_onehot

            episode_steps += 1
            episode_rewards += reward

            if episode_steps > 4000:
                _ = self.buffer.get_episode()
                return self.pid, [], 0, 0

        sequences = self.buffer.get_sequences()

        return self.pid, sequences, episode_steps, episode_rewards

    def update_networks(self, minibatchs):

        for minibatch in minibatchs:

            z_posts, hs, info = self.update_worldmodel(minibatch)

            trajectory_in_dream = self.rollout_in_dream(z_posts, hs)

            info_ac = self.update_actor_critic(trajectory_in_dream)

            info.update(info_ac)

        return self.get_weights(), info

    def update_worldmodel(self, minibatch):
        """
        Inputs:
            minibatch = {
                "obs":     (L, B, 64, 64, 1)
                "action":  (L, B, action_space)
                "reward":  (L, B)
                "done":    (L, B)
                "prev_z":  (1, B, latent_dim * n_atoms)
                "prev_h":  (1, B, 600)
                "prev_a":  (1, B, action_space)
            }

        Note:
            1. Re-compute posterior and prior z by unrolling the sequences from
               the initial states, obs, prev_z, prev_h and prev_action
            2. Compute KL loss (post_z, prior_z)
            3. Compute reconstruction, reward and discount losses
        """

        (observations, actions, rewards, next_observations,
         dones, prev_z, prev_h, prev_a) = minibatch.values()

        discounts = (1. - dones) * self.config.gamma_discount

        prev_z, prev_h, prev_a = prev_z[0], prev_h[0], prev_a[0]

        last_obs = next_observations[-1][None, ...]
        observations = tf.concat([observations, last_obs], axis=0)

        #: dummy action to avoid IndexError at the last iteration
        last_action = tf.zeros((1, ) + actions.shape[1:])
        actions = tf.concat([actions, last_action], axis=0)

        L = self.config.sequence_length

        with tf.GradientTape() as tape:

            hs, z_prior_probs, z_posts, z_post_probs = [], [], [], []
            img_outs, r_means, disc_logits = [], [], []

            for t in tf.range(L + 1):

                _outputs = self.world_model(observations[t], prev_z, prev_h, prev_a)

                (h, z_prior, z_prior_prob, z_post, z_post_prob,
                 feat, img_out, reward_mu, disc_logit) = _outputs

                hs.append(h)
                z_prior_probs.append(z_prior_prob)
                z_posts.append(z_post)
                z_post_probs.append(z_post_prob)
                img_outs.append(img_out)
                r_means.append(reward_mu)
                disc_logits.append(disc_logit)

                prev_z, prev_h, prev_a = z_post, h, actions[t]

            #: Reshape outputs
            #: [(B, ...), (B, ...), ...] -> (L+1, B, ...) -> (L, B, ...)
            hs = tf.stack(hs, axis=0)[:-1]
            z_prior_probs = tf.stack(z_prior_probs, axis=0)[:-1]
            z_posts = tf.stack(z_posts, axis=0)[:-1]
            z_post_probs = tf.stack(z_post_probs, axis=0)[:-1]
            img_outs = tf.stack(img_outs, axis=0)[:-1]
            r_means = tf.stack(r_means, axis=0)[1:]
            disc_logits = tf.stack(disc_logits, axis=0)[1:]

            #: Compute loss terms
            kl_loss = self._compute_kl_loss(z_prior_probs, z_post_probs)

            img_log_loss = self._compute_img_log_loss(observations[:-1], img_outs)

            reward_log_loss = self._compute_log_loss(rewards, r_means, mode="reward")

            discount_log_loss = self._compute_log_loss(discounts, disc_logits, mode="discount")

            loss = - img_log_loss - reward_log_loss - discount_log_loss \
                   + self.config.kl_scale * kl_loss

            loss *= 1. / L

        grads = tape.gradient(loss, self.world_model.trainable_variables)
        grads, norm = tf.clip_by_global_norm(grads, 100.)
        self.wm_optimizer.apply_gradients(
            zip(grads, self.world_model.trainable_variables))

        info = {
            "wm_loss": L * loss,
            "img_log_loss": -img_log_loss,
            "reward_log_loss": -reward_log_loss,
            "discount_log_loss": -discount_log_loss,
            "kl_loss": kl_loss
        }

        return z_posts, hs, info

    @tf.function
    def _compute_kl_loss(self, post_probs, prior_probs):
        """ Compute KL divergence between two OneHotCategorical distributions

        Notes:
            KL[ Q(z_post) || P(z_prior) ]
                Q(z_post)  := Q(z | h, o)
                P(z_prior) := P(z | h)

        Scratch Impl.:
            qlogq = post_probs * tf.math.log(post_probs)
            qlogp = post_probs * tf.math.log(prior_probs)
            kl_div = tf.reduce_sum(qlogq - qlogp, [1, 2])

        Inputs:
            prior_probs (L, B, latent_dim, n_atoms)
            post_probs  (L, B, latent_dim, n_atoms)
        """

        #: Add a small value to prevent an infinite KL
        post_probs += 1e-5
        prior_probs += 1e-5

        #: KL Balancing: See 2.2 BEHAVIOR LEARNING, Algorithm 2
        kl_div1 = tfd.kl_divergence(
            tfd.Independent(
                tfd.OneHotCategorical(probs=tf.stop_gradient(post_probs)),
                reinterpreted_batch_ndims=1),
            tfd.Independent(tfd.OneHotCategorical(probs=prior_probs),
                            reinterpreted_batch_ndims=1))

        kl_div2 = tfd.kl_divergence(
            tfd.Independent(tfd.OneHotCategorical(probs=post_probs),
                            reinterpreted_batch_ndims=1),
            tfd.Independent(
                tfd.OneHotCategorical(probs=tf.stop_gradient(prior_probs)),
                reinterpreted_batch_ndims=1))

        alpha = self.config.kl_alpha

        kl_loss = alpha * kl_div1 + (1. - alpha) * kl_div2

        #: Batch mean
        kl_loss = tf.reduce_mean(kl_loss)

        return kl_loss

    @tf.function
    def _compute_img_log_loss(self, img_in, img_out):
        """
        Inputs:
            img_in:  (L, B, 64, 64, 1)
            img_out: (L, B, 64, 64, 1)
        """

        L, B, H, W, C = img_in.shape

        img_in = tf.reshape(img_in, (L * B, H * W * C))
        img_out = tf.reshape(img_out, (L * B, H * W * C))

        dist = tfd.Independent(tfd.Normal(loc=img_out, scale=1.))
        #dist = tfd.Independent(tfd.Bernoulli(logits=img_out))

        log_prob = dist.log_prob(img_in)

        loss = tf.reduce_mean(log_prob)

        return loss

    @tf.function
    def _compute_log_loss(self, y_true, y_pred, mode):
        """
        Inputs:
            y_true: (L, B, 1)
            y_pred: (L, B, 1)
            mode:   "reward" or "discount"
        """

        if mode == "discount":
            dist = tfd.Independent(tfd.Bernoulli(logits=y_pred),
                                   reinterpreted_batch_ndims=1)
        elif mode == "reward":
            dist = tfd.Independent(tfd.Normal(loc=y_pred, scale=1.),
                                   reinterpreted_batch_ndims=1)

        log_prob = dist.log_prob(y_true)

        loss = tf.reduce_mean(log_prob)

        return loss

    def rollout_in_dream(self, z_init, h_init, video=False):
        """
        Inputs:
            h_init:    (L, B, 1)
            z_init:    (L, B, latent_dim * n_atoms)
            done_init: (L, B, 1)
        """
        L, B = h_init.shape[:2]
        horizon = self.config.imagination_horizon

        z, h = tf.reshape(z_init, [L * B, -1]), tf.reshape(h_init, [L * B, -1])
        feats = tf.concat([z, h], axis=-1)

        #: s_t, a_t, s_t+1
        trajectory = {"state": [], "action": [], 'next_state': []}

        for t in range(horizon):

            actions = tf.cast(self.policy.sample(feats), dtype=tf.float32)

            trajectory["state"].append(feats)
            trajectory["action"].append(actions)

            h = self.world_model.step_h(z, h, actions)
            z, _ = self.world_model.rssm.sample_z_prior(h)
            z = tf.reshape(z, [L * B, -1])

            feats = tf.concat([z, h], axis=-1)
            trajectory["next_state"].append(feats)

        trajectory = {k: tf.stack(v, axis=0) for k, v in trajectory.items()}

        #: reward_head(s_t+1) -> r_t
        #: Distribution.mode() returns the most probable value, so for a Normal
        #: distribution trajectory["reward"] == rewards
        rewards = self.world_model.reward_head(trajectory['next_state'])
        trajectory["reward"] = rewards

        disc_logits = self.world_model.discount_head(trajectory['next_state'])
        trajectory["discount"] = tfd.Independent(
            tfd.Bernoulli(logits=disc_logits), reinterpreted_batch_ndims=1).mean()

        return trajectory

    def update_actor_critic(self, trajectory, batch_size=512, strategy="PPO"):
        """ Actor-Critic update using PPO & Generalized Advantage Estimator """

        #: adv: (L*B, 1)
        targets, weights = self.compute_target(trajectory['state'],
                                               trajectory['reward'],
                                               trajectory['next_state'],
                                               trajectory['discount'])

        #: (H, L*B, ...)
        states = trajectory['state']
        selected_actions = trajectory['action']

        N = weights.shape[0] * weights.shape[1]

        states = tf.reshape(states, [N, -1])
        selected_actions = tf.reshape(selected_actions, [N, -1])
        targets = tf.reshape(targets, [N, -1])
        weights = tf.reshape(weights, [N, -1])

        _, old_action_probs = self.policy(states)
        old_logprobs = tf.math.log(old_action_probs + 1e-5)

        for _ in range(10):

            indices = np.random.choice(N, batch_size)

            _states = tf.gather(states, indices)
            _targets = tf.gather(targets, indices)
            _selected_actions = tf.gather(selected_actions, indices)
            _old_logprobs = tf.gather(old_logprobs, indices)
            _weights = tf.gather(weights, indices)

            #: Update value network
            with tf.GradientTape() as tape1:
                v_pred = self.value(_states)
                advantages = _targets - v_pred
                value_loss = 0.5 * tf.square(advantages)
                discount_value_loss = tf.reduce_mean(value_loss * _weights)

            grads = tape1.gradient(discount_value_loss,
                                   self.value.trainable_variables)
            self.value_optimizer.apply_gradients(
                zip(grads, self.value.trainable_variables))

            #: Update policy network
            if strategy == "VanillaPG":
                with tf.GradientTape() as tape2:
                    _, action_probs = self.policy(_states)
                    action_probs += 1e-5

                    selected_action_logprobs = tf.reduce_sum(
                        _selected_actions * tf.math.log(action_probs),
                        axis=1, keepdims=True)

                    objective = selected_action_logprobs * advantages

                    dist = tfd.Independent(
                        tfd.OneHotCategorical(probs=action_probs),
                        reinterpreted_batch_ndims=0)
                    ent = dist.entropy()

                    policy_loss = objective + self.config.ent_scale * ent[..., None]
                    policy_loss *= -1

                    discounted_policy_loss = tf.reduce_mean(policy_loss * _weights)

            elif strategy == "PPO":
                with tf.GradientTape() as tape2:
                    _, action_probs = self.policy(_states)
                    action_probs += 1e-5
                    new_logprobs = tf.math.log(action_probs)

                    ratio = tf.reduce_sum(
                        _selected_actions * tf.exp(new_logprobs - _old_logprobs),
                        axis=1, keepdims=True)
                    ratio_clipped = tf.clip_by_value(ratio, 0.9, 1.1)

                    obj_unclipped = ratio * advantages
                    obj_clipped = ratio_clipped * advantages

                    objective = tf.minimum(obj_unclipped, obj_clipped)

                    dist = tfd.Independent(
                        tfd.OneHotCategorical(probs=action_probs),
                        reinterpreted_batch_ndims=0)
                    ent = dist.entropy()

                    policy_loss = objective + self.config.ent_scale * ent[..., None]
                    policy_loss *= -1

                    discounted_policy_loss = tf.reduce_mean(policy_loss * _weights)

            grads = tape2.gradient(discounted_policy_loss,
                                   self.policy.trainable_variables)
            self.policy_optimizer.apply_gradients(
                zip(grads, self.policy.trainable_variables))

        info = {
            "policy_loss": tf.reduce_mean(policy_loss),
            "objective": tf.reduce_mean(objective),
            "actor_entropy": tf.reduce_mean(ent),
            "value_loss": tf.reduce_mean(value_loss),
            "target_0": tf.reduce_mean(_targets),
        }

        return info

    def compute_target(self, states, rewards, next_states, discounts,
                       strategy="mixed-multistep"):

        T, B, F = states.shape

        v_next = self.target_value(next_states)

        _weights = tf.concat([tf.ones_like(discounts[:1]), discounts[:-1]], axis=0)
        weights = tf.math.cumprod(_weights, axis=0)

        if strategy == "gae":
            """
            HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION
            https://arxiv.org/pdf/1506.02438.pdf
            """
            raise NotImplementedError()
            #lambda_ = self.config.lambda_gae
            #deltas = rewards + discounts * v_next - v
            #_weights = tf.concat(
            #    [tf.ones_like(discounts[:1]), discounts[:-1] * lambda_],
            #    axis=0)
            #weights = tf.math.cumprod(_weights, axis=0)
            #advantage = tf.reduce_sum(weights * deltas, axis=0)
            #v_target = advantage + v[0]

        elif strategy == "mixed-multistep":
            targets = np.zeros_like(v_next)  #: (H, L*B, 1)
            last_value = v_next[-1]
            for i in reversed(range(targets.shape[0])):
                last_value = rewards[i] + discounts[i] * last_value
                targets[i] = last_value

        else:
            raise NotImplementedError()

        return targets, weights

    def testplay(self, test_id, video_dir: Path = None, weights=None):

        if weights:
            self.set_weights(weights)

        images = []

        env = gym.make(self.env_id)
        obs = self.preprocess_func(env.reset())

        episode_steps, episode_rewards = 0, 0
        r_pred_total = 0.

        prev_z, prev_h = self.world_model.get_initial_state(batch_size=1)
        prev_a = tf.convert_to_tensor([[0] * self.action_space], dtype=tf.float32)

        done = False

        while not done:

            (h, z_prior, z_prior_probs, z_post, z_post_probs,
             feat, img_out, r_pred, discount_logit) = self.world_model(
                 obs, prev_z, prev_h, prev_a)

            action = self.policy.sample_action(feat, 0)
            action_onehot = tf.one_hot([action], self.action_space)

            next_frame, reward, done, info = env.step(action)
            next_obs = self.preprocess_func(next_frame)

            #img_out = tfd.Independent(tfd.Bernoulli(logits=img_out), 3).mean()
            disc = tfd.Bernoulli(logits=discount_logit).mean()
            r_pred_total += float(r_pred)

            img = util.vizualize_vae(obs[0, :, :, 0], img_out.numpy()[0, :, :, 0],
                                     float(r_pred), float(disc), r_pred_total)
            images.append(img)

            #: Update states
            obs = next_obs
            prev_z, prev_h, prev_a = z_post, h, action_onehot

            episode_steps += 1
            episode_rewards += reward

            #: avoiding agent freeze
            if episode_steps > 300 and episode_rewards < 2:
                break
            elif episode_steps > 1000 and episode_rewards < 10:
                break
            elif episode_steps > 4000:
                break

        if video_dir is not None:
            images[0].save(f'{video_dir}/testplay_{test_id}.gif',
                           save_all=True, append_images=images[1:],
                           optimize=False, duration=120, loop=0)

        return episode_steps, episode_rewards

    def testplay_in_dream(self, test_id, outdir: Path, H, weights=None):

        if weights:
            self.set_weights(weights)

        img_outs = []

        prev_z, prev_h = self.world_model.get_initial_state(batch_size=1)
        prev_a = tf.convert_to_tensor([[0] * self.action_space], dtype=tf.float32)

        actions, rewards, discounts = [], [], []

        env = gym.make(self.env_id)
        obs = self.preprocess_func(env.reset())

        N = random.randint(2, 10)

        for i in range(N + H + 1):

            if i < N:
                (h, z_prior, z_prior_probs, z_post, z_post_probs,
                 feat, img_out, r_pred, disc_logit) = self.world_model(
                     obs, prev_z, prev_h, prev_a)
                discount_pred = tfd.Bernoulli(logits=disc_logit).mean()

                img_out = obs[0, :, :, 0]

                action = 1 if i == 0 else self.policy.sample_action(feat, 0)

                next_frame, reward, done, info = env.step(action)
                obs = self.preprocess_func(next_frame)

                z = z_post
            else:
                h = self.world_model.step_h(prev_z, prev_h, prev_a)
                z, _ = self.world_model.rssm.sample_z_prior(h)
                z = tf.reshape(z, [1, -1])
                feat = tf.concat([z, h], axis=-1)

                img_out = self.world_model.decoder(feat)
                #img_out = tfd.Independent(tfd.Bernoulli(logits=img_out), 3).mean()
                img_out = img_out.numpy()[0, :, :, 0]

                r_pred = self.world_model.reward_head(feat)
                disc_logit = self.world_model.discount_head(feat)
                discount_pred = tfd.Bernoulli(logits=disc_logit).mean()

                action = self.policy.sample_action(feat, 0)

            actions.append(int(action))
            rewards.append(float(r_pred))
            discounts.append(float(discount_pred))
            img_outs.append(img_out)

            action_onehot = tf.one_hot([action], self.action_space)
            prev_z, prev_h, prev_a = z, h, action_onehot

        img_outs, actions, rewards, discounts = (
            img_outs[:-1], actions[:-1], rewards[1:], discounts[1:])

        images = util.visualize_dream(img_outs, actions, rewards, discounts)

        images[0].save(f'{outdir}/test_in_dream_{test_id}.gif',
                       save_all=True, append_images=images[1:],
                       optimize=False, duration=1000, loop=0)
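
# Hedged sketch: util.get_preprocess_func is referenced above but not shown. The world
# model consumes (1, 64, 64, 1) grayscale frames, so for an Atari environment the
# returned function presumably does something close to the following; the exact
# resizing and scaling choices are assumptions.
import numpy as np
import tensorflow as tf

def example_atari_preprocess(frame: np.ndarray) -> tf.Tensor:
    # RGB uint8 frame (e.g. 210x160x3) -> grayscale, 64x64, float32 in [0, 1]
    img = tf.image.rgb_to_grayscale(frame)
    img = tf.image.resize(img, (64, 64))
    img = tf.cast(img, tf.float32) / 255.0
    return img[tf.newaxis, ...]   # add batch dimension -> (1, 64, 64, 1)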
def main(config):

    seq_length = 5
    total_rewards = []
    writer = SummaryWriter()

    is_entropy = config.entropy
    is_shuffle = config.shuffle

    ## TODO : remove config, use dict only
    network_conf = NetworkConfig(
        input_size=int(config.input_size),
        hidden_size=int(config.hidden_size),
        num_steps=int(config.num_steps),
        action_space=int(config.action_space),
        learning_rate=float(config.learning_rate),
        beta=float(config.beta)
    )

    trainset, valset, testset = prepare_train_test()
    X_train, y_train = sliding_window(trainset, seq_length)
    X_val, y_val = sliding_window(valset, seq_length)
    X_test, y_test = sliding_window(testset, seq_length)

    train_loader = DataLoader(X_train, y_train)
    val_loader = DataLoader(X_val, y_val)
    test_loader = DataLoader(X_test, y_test)

    episode = 0
    policy_network = PolicyNetwork.from_dict(dict(network_conf._asdict()))
    print('current policy network', policy_network)

    while episode < N_EPISODE:

        initial_state = [[3, 8, 16]]
        logit_list = torch.empty(size=(0, network_conf.action_space))
        weighted_log_prob_list = torch.empty(size=(0,), dtype=torch.float)

        action, log_prob, logits = policy_network.get_action(initial_state)
        child_network = ChildNetwork.from_dict(action)

        criterion = torch.nn.MSELoss()
        optimizer = optim.SGD(child_network.parameters(), lr=0.001, momentum=0.9)
        train_manager = TrainManager(
            model=child_network,
            criterion=criterion,
            optimizer=optimizer
        )

        start_time = time.time()
        train_manager.train(train_loader, val_loader, is_shuffle=is_shuffle)
        elapsed = time.time() - start_time

        signal = train_manager.avg_validation_loss
        reward = reward_func2(signal)

        weighted_log_prob = log_prob * reward
        total_weighted_log_prob = torch.sum(weighted_log_prob).unsqueeze(dim=0)
        weighted_log_prob_list = torch.cat(
            (weighted_log_prob_list, total_weighted_log_prob), dim=0
        )
        logit_list = torch.cat((logit_list, logits), dim=0)

        # update the controller network
        policy_network.update(logit_list, weighted_log_prob_list, is_entropy)

        total_rewards.append(reward)

        # prepare metrics
        current_action = map(str, list(action.values()))
        action_str = "/".join(current_action)
        ActionSelection.update_selection(action_str)
        ActionSelection.update_reward(action_str, reward)
        print('current name mapping', ActionSelection.name_mapper)

        with open('logs/run.log', 'a') as run_file:
            run_file.write(json.dumps(ActionSelection.name_mapper))
            run_file.write('\n')

        # reporting
        current_action = (f"Action selection (Hidden size:{action['n_hidden']}, "
                          f"#layers {action['n_layers']}, "
                          f"drop_prob {action['dropout_prob']}).")
        current_run = "Runs_{}".format(episode + 1) + " " + current_action

        counter = 0
        for train_loss, val_loss in zip(train_manager.train_losses,
                                        train_manager.val_losses):
            writer.add_scalars(
                current_run,
                {"train_loss": train_loss, "val_loss": val_loss},
                counter
            )
            counter += 1

        writer.add_scalar(
            tag="Average Return over {} episodes".format(N_EPISODE),
            scalar_value=np.mean(total_rewards),
            global_step=episode,
        )

        if is_entropy:
            writer.add_scalar(
                tag=f"Entropy over time - BETA:{config.beta}",
                scalar_value=policy_network.entropy_mean,
                global_step=episode
            )

        writer.add_scalar(
            tag="Episode runtime",
            scalar_value=elapsed,
            global_step=episode
        )

        # Prepare the plot
        plot_buf = prepare_figure(ActionSelection.action_selection,
                                  "Action Selection (n_hidden, n_layers, dropout)")
        #image = PIL.Image.open(plot_buf)
        #image = ToTensor()(image)  # .unsqueeze(0)
        #writer.add_image("Image 1", image, episode)
        writer.add_figure("Image 1", plot_buf, episode)

        # Prepare the plot
        plot_buf = prepare_figure(
            ActionSelection.reward_distribution(),
            "Reward Distribution per action selection (n_hidden, n_layers, dropout)"
        )
        #image = PIL.Image.open(plot_buf)
        #image = ToTensor()(image)  # .unsqueeze(0)
        #writer.add_image("Image 2", image, episode)
        writer.add_figure("Image 2", plot_buf, episode)

        print('\n\nEpisode {} completed \n\n'.format(episode + 1))
        episode += 1

    writer.close()
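
# Hedged sketch: PolicyNetwork.update is called above with the accumulated logits, the
# reward-weighted log-probabilities and the entropy flag, but its body is not shown.
# Given how main() builds those inputs (and the beta hyperparameter), the controller
# step is presumably a REINFORCE-style update of roughly this shape. The attribute
# names self.optimizer, self.beta and self.entropy_mean are assumptions based on the
# usage in main().
import torch

def example_controller_update(self, logits, weighted_log_probs, use_entropy):
    loss = -torch.mean(weighted_log_probs)                 # maximize reward-weighted log-prob
    if use_entropy:
        probs = torch.softmax(logits, dim=-1)
        entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()
        self.entropy_mean = entropy.item()                 # logged to TensorBoard in main()
        loss = loss - self.beta * entropy                  # entropy bonus encourages exploration
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()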