def fit_value(self, states):
  """Updates value network parameters.

  Args:
    states: Batch of states.

  Returns:
    Dictionary with information to track.
  """
  actions, log_probs = self.actor(states, sample=True, with_log_probs=True)
  q1, q2 = self.critic(states, actions)
  q = tf.minimum(q1, q2) - self.alpha * log_probs

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(self.value.trainable_variables)
    v = self.value(states)
    value_loss = tf.losses.mean_squared_error(q, v)

  grads = tape.gradient(value_loss, self.value.trainable_variables)
  self.value_optimizer.apply_gradients(
      zip(grads, self.value.trainable_variables))

  if self.value_optimizer.iterations % self.target_update_period == 0:
    critic.soft_update(self.value, self.value_target, tau=self.tau)

  return {'v': tf.reduce_mean(v), 'value_loss': value_loss}
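# Hedged sketch (not part of the class above): `critic.soft_update` is assumed
# here to implement a Polyak average, target <- tau * source + (1 - tau) * target,
# applied variable by variable. A minimal standalone version under that
# assumption:
import tensorflow as tf


def polyak_update(source_vars, target_vars, tau=0.005):
  """Soft-updates each target variable toward its source counterpart."""
  for src, tgt in zip(source_vars, target_vars):
    tgt.assign(tau * src + (1.0 - tau) * tgt)


# Usage on toy "networks" represented by single variables:
source = [tf.Variable(1.0)]
target = [tf.Variable(0.0)]
polyak_update(source, target, tau=0.5)  # target becomes 0.5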
def fit_critic(self, states, actions, next_states, rewards, discounts):
  """Updates critic parameters.

  Args:
    states: Batch of states.
    actions: Batch of actions.
    next_states: Batch of next states.
    rewards: Batch of rewards.
    discounts: Batch of masks indicating the end of the episodes.

  Returns:
    Dictionary with information to track.
  """
  next_actions = self.actor(next_states, sample=True)
  next_q1, next_q2 = self.critic_target(next_states, next_actions)
  target_q = rewards + self.discount * discounts * tf.minimum(
      next_q1, next_q2)

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(self.critic.trainable_variables)

    q1, q2 = self.critic(states, actions)

    policy_actions = self.actor(states, sample=True)
    q_pi1, q_pi2 = self.critic(states, policy_actions)

    def discriminator_loss(real_output, fake_output):
      diff = target_q - real_output

      def my_exp(x):
        return 1.0 + x + 0.5 * tf.square(x)

      alpha = 10.0
      real_loss = tf.reduce_mean(my_exp(diff / alpha) * alpha)
      total_loss = real_loss + tf.reduce_mean(fake_output)
      return total_loss

    critic_loss1 = discriminator_loss(q1, q_pi1)
    critic_loss2 = discriminator_loss(q2, q_pi2)
    critic_loss = (critic_loss1 + critic_loss2)

  critic_grads = tape.gradient(critic_loss, self.critic.trainable_variables)

  self.critic_optimizer.apply_gradients(
      zip(critic_grads, self.critic.trainable_variables))

  critic.soft_update(self.critic, self.critic_target, tau=self.tau)

  return {
      'q1': tf.reduce_mean(q1),
      'q2': tf.reduce_mean(q2),
      'critic_loss': critic_loss,
      'q_pi1': tf.reduce_mean(q_pi1),
      'q_pi2': tf.reduce_mean(q_pi2)
  }
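# Hedged sketch: `my_exp` above is the second-order Taylor expansion of exp(x)
# around 0, i.e. exp(x) ~ 1 + x + x^2 / 2, so `my_exp(diff / alpha) * alpha`
# acts as a softened penalty on the Bellman error `diff`. A quick standalone
# check of the approximation near zero:
import tensorflow as tf

x = tf.constant([-0.5, -0.1, 0.0, 0.1, 0.5])
taylor = 1.0 + x + 0.5 * tf.square(x)
print(tf.exp(x).numpy())  # ~[0.607 0.905 1.    1.105 1.649]
print(taylor.numpy())     # ~[0.625 0.905 1.    1.105 1.625]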
def __init__(self,
             observation_spec,
             action_spec,
             actor_lr=3e-4,
             critic_lr=3e-4,
             alpha_lr=3e-4,
             discount=0.99,
             tau=0.005,
             target_update_period=1,
             target_entropy=0.0,
             use_soft_critic=False):
  """Creates networks.

  Args:
    observation_spec: environment observation spec.
    action_spec: Action spec.
    actor_lr: Actor learning rate.
    critic_lr: Critic learning rate.
    alpha_lr: Temperature learning rate.
    discount: MDP discount.
    tau: Soft target update parameter.
    target_update_period: Target network update period.
    target_entropy: Target entropy.
    use_soft_critic: Whether to use soft critic representation.
  """
  assert len(observation_spec.shape) == 1
  state_dim = observation_spec.shape[0]

  self.actor = policies.DiagGuassianPolicy(state_dim, action_spec)
  self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

  self.log_alpha = tf.Variable(tf.math.log(0.1), trainable=True)
  self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=alpha_lr)

  self.target_entropy = target_entropy
  self.discount = discount
  self.tau = tau
  self.target_update_period = target_update_period

  self.value = critic.CriticNet(state_dim)
  self.value_target = critic.CriticNet(state_dim)
  critic.soft_update(self.value, self.value_target, tau=1.0)
  self.value_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)

  if use_soft_critic:
    self.critic = critic.SoftCritic(state_dim, action_spec)
  else:
    action_dim = action_spec.shape[0]
    self.critic = critic.Critic(state_dim, action_dim)
  self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)
def fit(self, states, actions, next_states, rewards, discounts):
  """Updates actor and value network parameters.

  Args:
    states: Batch of states.
    actions: Batch of actions.
    next_states: Batch of next states.
    rewards: Batch of rewards.
    discounts: Batch of masks indicating the end of the episodes.

  Returns:
    Dictionary with information to track.
  """
  next_v = self.value_target(next_states)
  target_q = rewards + self.discount * discounts * next_v

  all_vars = (list(self.actor.trainable_variables) +
              list(self.value.trainable_variables))
  with tf.GradientTape(
      watch_accessed_variables=False, persistent=True) as tape:
    tape.watch(all_vars)

    actor_log_probs = self.actor.log_probs(states, actions)
    q = self.value(states) + self.alpha * actor_log_probs

    adv = tf.stop_gradient(target_q - q)
    actor_loss = -tf.reduce_mean(actor_log_probs * adv)
    critic_loss = tf.losses.mean_squared_error(target_q, q)

  actor_grads = tape.gradient(actor_loss, self.actor.trainable_variables)
  critic_grads = tape.gradient(critic_loss, self.value.trainable_variables)

  self.actor_optimizer.apply_gradients(
      zip(actor_grads, self.actor.trainable_variables))
  self.critic_optimizer.apply_gradients(
      zip(critic_grads, self.value.trainable_variables))
  del tape

  if self.critic_optimizer.iterations % self.target_update_period == 0:
    critic.soft_update(self.value, self.value_target, tau=self.tau)

  return {
      'q': tf.reduce_mean(q),
      'critic_loss': critic_loss,
      'actor_log_probs': tf.reduce_mean(actor_log_probs),
      'adv': tf.reduce_mean(adv)
  }
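# Hedged sketch: the actor term above is an advantage-weighted log-likelihood
# update; the advantage is treated as a constant via tf.stop_gradient so only
# log pi(a|s) receives gradients. A toy illustration of that pattern:
import tensorflow as tf

log_probs = tf.Variable([-1.0, -2.0])
target_q = tf.constant([1.5, 0.5])
q = tf.constant([1.0, 1.0])

with tf.GradientTape() as tape:
  adv = tf.stop_gradient(target_q - q)          # [0.5, -0.5], no gradient into q
  actor_loss = -tf.reduce_mean(log_probs * adv)

grads = tape.gradient(actor_loss, [log_probs])  # equals -adv / batch_size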
def __init__(self,
             observation_spec,
             action_spec,
             actor_lr=1e-4,
             critic_lr=3e-4,
             alpha_lr=1e-4,
             discount=0.99,
             tau=0.005,
             target_entropy=0.0):
  """Creates networks.

  Args:
    observation_spec: environment observation spec.
    action_spec: Action spec.
    actor_lr: Actor learning rate.
    critic_lr: Critic learning rate.
    alpha_lr: Temperature learning rate.
    discount: MDP discount.
    tau: Soft target update parameter.
    target_entropy: Target entropy.
  """
  assert len(observation_spec.shape) == 1
  state_dim = observation_spec.shape[0]

  beta_1 = 0.0
  self.actor = policies.DiagGuassianPolicy(state_dim, action_spec)
  self.actor_optimizer = tf.keras.optimizers.Adam(
      learning_rate=actor_lr, beta_1=beta_1)

  self.log_alpha = tf.Variable(tf.math.log(0.1), trainable=True)
  self.alpha_optimizer = tf.keras.optimizers.Adam(
      learning_rate=alpha_lr, beta_1=beta_1)

  self.target_entropy = target_entropy
  self.discount = discount
  self.tau = tau

  action_dim = action_spec.shape[0]
  self.critic = critic.Critic(state_dim, action_dim)
  self.critic_target = critic.Critic(state_dim, action_dim)
  critic.soft_update(self.critic, self.critic_target, tau=1.0)
  self.critic_optimizer = tf.keras.optimizers.Adam(
      learning_rate=critic_lr, beta_1=beta_1)
def fit_critic(self, states, actions, next_states, rewards, discounts):
  """Updates critic parameters.

  Args:
    states: Batch of states.
    actions: Batch of actions.
    next_states: Batch of next states.
    rewards: Batch of rewards.
    discounts: Batch of masks indicating the end of the episodes.

  Returns:
    Dictionary with information to track.
  """
  next_actions = self.actor(next_states, sample=True)
  bc_log_probs = self.bc.policy.log_probs(next_states, next_actions)
  next_target_q1, next_target_q2 = self.critic_target(
      next_states, next_actions)
  target_q = rewards + self.discount * discounts * (tf.minimum(
      next_target_q1, next_target_q2) + self.alpha * bc_log_probs)

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(self.critic.trainable_variables)

    q1, q2 = self.critic(states, actions)
    critic_loss = (tf.losses.mean_squared_error(target_q, q1) +
                   tf.losses.mean_squared_error(target_q, q2))

  critic_grads = tape.gradient(critic_loss, self.critic.trainable_variables)

  self.critic_optimizer.apply_gradients(
      zip(critic_grads, self.critic.trainable_variables))

  critic.soft_update(self.critic, self.critic_target, tau=self.tau)

  return {
      'q1': tf.reduce_mean(q1),
      'q2': tf.reduce_mean(q2),
      'critic_loss': critic_loss
  }
def __init__(self,
             observation_spec,
             action_spec,
             actor_lr=3e-4,
             critic_lr=3e-4,
             alpha_lr=3e-4,
             discount=0.99,
             tau=0.005,
             target_entropy=0.0,
             f_reg=1.0,
             reward_bonus=5.0,
             num_augmentations=1,
             env_name='',
             batch_size=256):
  """Creates networks.

  Args:
    observation_spec: environment observation spec.
    action_spec: Action spec.
    actor_lr: Actor learning rate.
    critic_lr: Critic learning rate.
    alpha_lr: Temperature learning rate.
    discount: MDP discount.
    tau: Soft target update parameter.
    target_entropy: Target entropy.
    f_reg: Critic regularization weight.
    reward_bonus: Bonus added to the rewards.
    num_augmentations: Number of random crops.
    env_name: Env name.
    batch_size: Batch size.
  """
  del num_augmentations, env_name
  assert len(observation_spec.shape) == 1
  state_dim = observation_spec.shape[0]

  self.batch_size = batch_size

  hidden_dims = (256, 256, 256)
  self.actor = policies.DiagGuassianPolicy(
      state_dim, action_spec, hidden_dims=hidden_dims)
  self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

  self.log_alpha = tf.Variable(tf.math.log(1.0), trainable=True)
  self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=alpha_lr)

  self.target_entropy = target_entropy
  self.discount = discount
  self.tau = tau

  self.bc = behavioral_cloning.BehavioralCloning(
      observation_spec, action_spec, mixture=True)

  action_dim = action_spec.shape[0]
  self.critic = critic.Critic(state_dim, action_dim, hidden_dims=hidden_dims)
  self.critic_target = critic.Critic(
      state_dim, action_dim, hidden_dims=hidden_dims)
  critic.soft_update(self.critic, self.critic_target, tau=1.0)
  self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)

  self.f_reg = f_reg
  self.reward_bonus = reward_bonus

  self.model_dict = {
      'critic': self.critic,
      'actor': self.actor,
      'critic_target': self.critic_target,
      'actor_optimizer': self.actor_optimizer,
      'critic_optimizer': self.critic_optimizer,
      'alpha_optimizer': self.alpha_optimizer
  }
def fit_critic(self, states, actions, next_states, rewards, discounts):
  """Updates critic parameters.

  Args:
    states: Batch of states.
    actions: Batch of actions.
    next_states: Batch of next states.
    rewards: Batch of rewards.
    discounts: Batch of masks indicating the end of the episodes.

  Returns:
    Dictionary with information to track.
  """
  next_actions = self.actor(next_states, sample=True)
  policy_actions = self.actor(states, sample=True)

  next_target_q1, next_target_q2 = self.dist_critic(
      next_states, next_actions, target=True)
  target_q = rewards + self.discount * discounts * tf.minimum(
      next_target_q1, next_target_q2)

  critic_variables = self.critic.trainable_variables

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(critic_variables)

    q1, q2 = self.dist_critic(states, actions, stop_gradient=True)

    with tf.GradientTape(
        watch_accessed_variables=False, persistent=True) as tape2:
      tape2.watch([policy_actions])

      q1_reg, q2_reg = self.critic(states, policy_actions)

    q1_grads = tape2.gradient(q1_reg, policy_actions)
    q2_grads = tape2.gradient(q2_reg, policy_actions)

    q1_grad_norm = tf.reduce_sum(tf.square(q1_grads), axis=-1)
    q2_grad_norm = tf.reduce_sum(tf.square(q2_grads), axis=-1)

    del tape2

    q_reg = tf.reduce_mean(q1_grad_norm + q2_grad_norm)

    critic_loss = (tf.losses.mean_squared_error(target_q, q1) +
                   tf.losses.mean_squared_error(target_q, q2) +
                   self.f_reg * q_reg)

  critic_grads = tape.gradient(critic_loss, critic_variables)

  self.critic_optimizer.apply_gradients(zip(critic_grads, critic_variables))

  critic.soft_update(self.critic, self.critic_target, tau=self.tau)

  return {
      'q1': tf.reduce_mean(q1),
      'q2': tf.reduce_mean(q2),
      'critic_loss': critic_loss,
      'q1_grad': tf.reduce_mean(q1_grad_norm),
      'q2_grad': tf.reduce_mean(q2_grad_norm)
  }
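# Hedged sketch: the `q_reg` term above penalizes the squared norm of
# dQ(s, a)/da evaluated at policy actions (a gradient penalty in action
# space). A self-contained illustration of the same GradientTape pattern on a
# toy quadratic critic:
import tensorflow as tf

actions = tf.Variable([[0.5, -0.5], [1.0, 0.0]])  # batch of 2 toy actions

with tf.GradientTape() as tape:
  tape.watch(actions)
  q_values = tf.reduce_sum(tf.square(actions), axis=-1)  # toy Q(s, a) = ||a||^2

action_grads = tape.gradient(q_values, actions)            # dQ/da = 2a
grad_norm = tf.reduce_sum(tf.square(action_grads), axis=-1)
q_reg = tf.reduce_mean(grad_norm)                           # penalty added to the loss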
def __init__(self,
             observation_spec,
             action_spec,
             actor_lr=3e-4,
             critic_lr=3e-4,
             alpha_lr=3e-4,
             discount=0.99,
             tau=0.005,
             target_entropy=0.0,
             f_reg=1.0,
             reward_bonus=5.0,
             num_augmentations=1,
             rep_learn_keywords='outer',
             env_name='',
             batch_size=256,
             n_quantiles=5,
             temp=0.1,
             num_training_levels=200,
             latent_dim=256,
             n_levels_nce=5,
             popart_norm_beta=0.1):
  """Creates networks.

  Args:
    observation_spec: environment observation spec.
    action_spec: Action spec.
    actor_lr: Actor learning rate.
    critic_lr: Critic learning rate.
    alpha_lr: Temperature learning rate.
    discount: MDP discount.
    tau: Soft target update parameter.
    target_entropy: Target entropy.
    f_reg: Critic regularization weight.
    reward_bonus: Bonus added to the rewards.
    num_augmentations: Number of DrQ augmentations (crops).
    rep_learn_keywords: Representation learning loss to add (see below).
    env_name: Env name.
    batch_size: Batch size.
    n_quantiles: Number of GVF quantiles.
    temp: Temperature of NCE softmax.
    num_training_levels: Number of training MDPs (Procgen=200).
    latent_dim: Latent dimensions of auxiliary MLPs.
    n_levels_nce: Number of MDPs to use contrastive loss on.
    popart_norm_beta: PopArt normalization constant.

  For `rep_learn_keywords`, pick from:
    stop_grad_FQI: whether to stop_grad TD/FQI critic updates?
    linear_Q: use a linear critic?
    successor_features: uses ||SF|| as cumulant
    gvf_termination: uses +1 if done else 0 as cumulant
    gvf_action_count: uses state-cond. action counts as cumulant
    nce: uses the multi-class dot-product InfoNCE objective
    cce: uses MoCo Categorical CrossEntropy objective
    energy: uses SimCLR + pairwise GVF distance (not fully tested)

  If no cumulant is specified, the reward will be taken as the default one.
  """
  del actor_lr, critic_lr, alpha_lr, target_entropy

  self.action_spec = action_spec
  self.num_augmentations = num_augmentations
  self.rep_learn_keywords = rep_learn_keywords.split('__')
  self.batch_size = batch_size
  self.env_name = env_name
  self.stop_grad_fqi = 'stop_grad_FQI' in self.rep_learn_keywords

  critic_kwargs = {'hidden_dims': (1024, 1024)}

  self.latent_dim = latent_dim
  self.n_levels_nce = n_levels_nce
  hidden_dims = hidden_dims_per_level = (self.latent_dim, self.latent_dim)
  self.num_training_levels = int(num_training_levels)
  self.n_quantiles = n_quantiles
  self.temp = temp

  # Make 2 sets of weights:
  # - Critic
  # - Critic (target)
  # Optionally, make a 3rd set for per-level critics.
  if observation_spec.shape == (64, 64, 3):
    # IMPALA for Procgen
    def conv_stack():
      return make_impala_cnn_network(
          depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)

    state_dim = 256
  else:
    # Reduced architecture for DMC
    def conv_stack():
      return ConvStack(observation_spec.shape)

    state_dim = 50

  conv_stack_critic = conv_stack()
  conv_target_stack_critic = conv_stack()

  if observation_spec.shape == (64, 64, 3):
    conv_stack_critic.output_size = state_dim
    conv_target_stack_critic.output_size = state_dim

  critic_kwargs['encoder'] = ImageEncoder(
      conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)
  # Note: the target critic does not share any weights.
  critic_kwargs['encoder_target'] = ImageEncoder(
      conv_target_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)

  conv_stack_critic_per_level = conv_stack()
  conv_target_stack_critic_per_level = conv_stack()

  if observation_spec.shape == (64, 64, 3):
    conv_stack_critic_per_level.output_size = state_dim
    conv_target_stack_critic_per_level.output_size = state_dim

  self.encoder_per_level = ImageEncoder(
      conv_stack_critic_per_level,
      feature_dim=state_dim,
      bprop_conv_stack=True)
  self.encoder_per_level_target = ImageEncoder(
      conv_target_stack_critic_per_level,
      feature_dim=state_dim,
      bprop_conv_stack=True)

  criticCL.soft_update(
      self.encoder_per_level, self.encoder_per_level_target, tau=1.0)

  if self.num_augmentations == 0:
    dummy_state = tf.constant(np.zeros([1] + list(observation_spec.shape)))
  else:  # account for padding of +4 everywhere and then cropping out 68
    dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))

  dummy_enc = critic_kwargs['encoder'](dummy_state)

  @tf.function
  def init_models():
    """Initializes all auxiliary networks (state and action encoders) with
    dummy input (Procgen-specific, 68x68x3, 15 actions).
    """
    critic_kwargs['encoder'](dummy_state)
    critic_kwargs['encoder_target'](dummy_state)
    self.encoder_per_level(dummy_state)
    self.encoder_per_level_target(dummy_state)

  init_models()

  action_dim = action_spec.maximum.item() + 1

  self.action_dim = action_dim
  self.discount = discount
  self.tau = tau
  self.reg = f_reg
  self.reward_bonus = reward_bonus

  self.critic = criticCL.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder'],
      discrete_actions=True,
      linear='linear_Q' in self.rep_learn_keywords)
  self.critic_target = criticCL.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder_target'],
      discrete_actions=True,
      linear='linear_Q' in self.rep_learn_keywords)

  self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
  self.task_critic_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
  self.br_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)

  if 'cce' in self.rep_learn_keywords:
    self.classifier = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(self.latent_dim, use_bias=True),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(self.n_quantiles, use_bias=True)
        ],
        name='classifier')
  elif 'nce' in self.rep_learn_keywords:
    self.embedding = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(self.latent_dim, use_bias=True),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(self.latent_dim, use_bias=True)
        ],
        name='embedding')

  # This snippet initializes all auxiliary networks (state and action
  # encoders) with dummy input (Procgen-specific, 68x68x3, 15 actions).
  dummy_state = tf.zeros((1, 68, 68, 3), dtype=tf.float32)
  phi_s = self.critic.encoder(dummy_state)
  phi_a = tf.eye(action_dim, dtype=tf.float32)
  if 'linear_Q' in self.rep_learn_keywords:
    _ = self.critic.critic1.state_encoder(phi_s)
    _ = self.critic.critic2.state_encoder(phi_s)
    _ = self.critic.critic1.action_encoder(phi_a)
    _ = self.critic.critic2.action_encoder(phi_a)
    _ = self.critic_target.critic1.state_encoder(phi_s)
    _ = self.critic_target.critic2.state_encoder(phi_s)
    _ = self.critic_target.critic1.action_encoder(phi_a)
    _ = self.critic_target.critic2.action_encoder(phi_a)
  if 'cce' in self.rep_learn_keywords:
    self.classifier(phi_s)
  elif 'nce' in self.rep_learn_keywords:
    self.embedding(phi_s)

  self.target_critic_to_use = self.critic_target
  self.critic_to_use = self.critic

  criticCL.soft_update(self.critic, self.critic_target, tau=1.0)

  self.cce = tf.keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.NONE, from_logits=True)

  self.bc = None

  if 'successor_features' in self.rep_learn_keywords:
    self.output_dim_level = self.latent_dim
  elif 'gvf_termination' in self.rep_learn_keywords:
    self.output_dim_level = 1
  elif 'gvf_action_count' in self.rep_learn_keywords:
    self.output_dim_level = action_dim
  else:
    self.output_dim_level = action_dim

  self.task_critic_one = criticCL.Critic(
      state_dim,
      self.output_dim_level * self.num_training_levels,
      hidden_dims=hidden_dims_per_level,
      encoder=None,  # critic_kwargs['encoder'],
      discrete_actions=True,
      cross_norm=False)
  self.task_critic_target_one = criticCL.Critic(
      state_dim,
      self.output_dim_level * 200,
      hidden_dims=hidden_dims_per_level,
      encoder=None,  # critic_kwargs['encoder'],
      discrete_actions=True,
      cross_norm=False)
  self.task_critic_one(
      dummy_enc,
      actions=None,
      training=False,
      return_features=False,
      stop_grad_features=False)
  self.task_critic_target_one(
      dummy_enc,
      actions=None,
      training=False,
      return_features=False,
      stop_grad_features=False)
  criticCL.soft_update(
      self.task_critic_one, self.task_critic_target_one, tau=1.0)

  # Normalization constant beta, set to best default value as per PopArt paper
  self.reward_normalizer = popart.PopArt(
      running_statistics.EMAMeanStd(popart_norm_beta))
  self.reward_normalizer.init()

  if 'CLIP' in self.rep_learn_keywords or 'clip' in self.rep_learn_keywords:
    self.loss_temp = tf.Variable(
        tf.constant(0.0, dtype=tf.float32), name='loss_temp', trainable=True)

  self.model_dict = {
      'critic': self.critic,
      'critic_target': self.critic_target,
      'critic_optimizer': self.critic_optimizer,
      'br_optimizer': self.br_optimizer
  }

  self.model_dict['encoder_perLevel'] = self.encoder_per_level
  self.model_dict['encoder_perLevel_target'] = self.encoder_per_level_target
  self.model_dict['task_critic'] = self.task_critic_one
  self.model_dict['task_critic_target'] = self.task_critic_target_one
def fit_task_critics(self, mb_states, mb_actions, mb_next_states,
                     mb_next_actions, mb_rewards, mb_discounts, level_ids):
  """Updates per-level critic parameters.

  Args:
    mb_states: Batch of states.
    mb_actions: Batch of actions.
    mb_next_states: Batch of next states.
    mb_next_actions: Batch of next actions from training policy.
    mb_rewards: Batch of rewards.
    mb_discounts: Batch of masks indicating the end of the episodes.
    level_ids: Batch of level ids.

  Returns:
    Dictionary with information to track.
  """
  if 'popart' in self.rep_learn_keywords:
    # The PopArt normalization normalizes the GVF's cumulant signal so that
    # it's not affected by the difference in scales across MDPs.
    mb_rewards = self.reward_normalizer.normalize_target(mb_rewards)

  trainable_variables = (
      self.encoder_per_level.trainable_variables +
      self.task_critic_one.trainable_variables)

  next_action_indices = tf.stack(
      [
          tf.range(tf.shape(mb_next_actions)[0], dtype=tf.int32),
          level_ids * self.output_dim_level +
          tf.cast(mb_next_actions, dtype=tf.int32)
      ],
      axis=-1)
  action_indices = tf.stack(
      [
          tf.range(tf.shape(mb_actions)[0], dtype=tf.int32),
          level_ids * self.output_dim_level +
          tf.cast(mb_actions, dtype=tf.int32)
      ],
      axis=-1)
  level_ids = tf.stack(
      [
          tf.range(tf.shape(mb_next_actions)[0], dtype=tf.int32),
          tf.cast(level_ids, dtype=tf.int32)
      ],
      axis=-1)

  next_states = [self.encoder_per_level_target(mb_next_states[0])]
  next_q1, next_q2 = self.task_critic_target_one(next_states[0], actions=None)

  # Learn d-dimensional successor features
  if 'successor_features' in self.rep_learn_keywords:
    target_q = tf.concat(
        [next_states[0]] * 200, 1) + self.discount * tf.expand_dims(
            mb_discounts, 1) * tf.minimum(next_q1, next_q2)
  # Learn discounted episode termination
  elif 'gvf_termination' in self.rep_learn_keywords:
    target_q = tf.expand_dims(
        mb_discounts, 1) + self.discount * tf.expand_dims(
            mb_discounts, 1) * tf.minimum(next_q1, next_q2)
  # Learn discounted future action counts
  elif 'gvf_action_count' in self.rep_learn_keywords:
    target_q = tf.concat(
        [tf.one_hot(mb_actions, depth=self.action_dim)] * 200,
        1) + self.discount * tf.expand_dims(
            mb_discounts, 1) * tf.minimum(next_q1, next_q2)
  else:
    target_q = tf.expand_dims(
        mb_rewards, 1) + self.discount * tf.expand_dims(
            mb_discounts, 1) * tf.minimum(next_q1, next_q2)

  if ('successor_features' in self.rep_learn_keywords or
      'gvf_termination' in self.rep_learn_keywords or
      'gvf_action_count' in self.rep_learn_keywords):
    target_q = tf.reshape(target_q, (-1, 200, self.output_dim_level))
    target_q = tf.gather_nd(target_q, indices=level_ids)
  else:
    target_q = tf.gather_nd(target_q, indices=next_action_indices)

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(trainable_variables)

    states = [self.encoder_per_level(mb_states[0])]
    q1_all, q2_all = self.task_critic_one(states[0], actions=None)
    q = tf.minimum(q1_all, q2_all)

    if ('successor_features' in self.rep_learn_keywords or
        'gvf_termination' in self.rep_learn_keywords or
        'gvf_action_count' in self.rep_learn_keywords):
      q1_all = tf.reshape(q1_all, (-1, 200, self.output_dim_level))
      q2_all = tf.reshape(q2_all, (-1, 200, self.output_dim_level))
      critic_loss = (
          tf.losses.mean_squared_error(
              target_q, tf.gather_nd(q1_all, indices=level_ids)) +
          tf.losses.mean_squared_error(
              target_q, tf.gather_nd(q2_all, indices=level_ids)))
    else:
      critic_loss = (
          tf.losses.mean_squared_error(
              target_q, tf.gather_nd(q1_all, indices=action_indices)) +
          tf.losses.mean_squared_error(
              target_q, tf.gather_nd(q2_all, indices=action_indices)))

  critic_grads = tape.gradient(critic_loss, trainable_variables)

  self.task_critic_optimizer.apply_gradients(
      zip(critic_grads, trainable_variables))

  criticCL.soft_update(
      self.encoder_per_level, self.encoder_per_level_target, tau=self.tau)
  criticCL.soft_update(
      self.task_critic_one, self.task_critic_target_one, tau=self.tau)

  gn = tf.reduce_mean(
      [tf.linalg.norm(v) for v in critic_grads if v is not None])

  return {
      'avg_level_critic_loss': tf.reduce_mean(critic_loss),
      'avg_q': tf.reduce_mean(q),
      'level_critic_grad_norm': gn
  }
def fit_critic(self, states, actions, next_states, next_actions, rewards,
               discounts):
  """Updates critic parameters.

  Args:
    states: Batch of states.
    actions: Batch of actions.
    next_states: Batch of next states.
    next_actions: Batch of next actions from training policy.
    rewards: Batch of rewards.
    discounts: Batch of masks indicating the end of the episodes.

  Returns:
    Dictionary with information to track.
  """
  action_indices = tf.stack(
      [tf.range(tf.shape(actions)[0], dtype=tf.int64), actions], axis=-1)
  next_action_indices = tf.stack(
      [tf.range(tf.shape(next_actions)[0], dtype=tf.int64), next_actions],
      axis=-1)

  if self.num_augmentations > 1:
    target_q = 0.
    for i in range(self.num_augmentations):
      next_q1_i, next_q2_i = self.critic_target(
          next_states[i],
          actions=None,
          stop_grad_features=self.stop_grad_fqi)
      target_q_i = tf.expand_dims(
          rewards, 1) + self.discount * tf.expand_dims(
              discounts, 1) * tf.minimum(next_q1_i, next_q2_i)
      target_q += target_q_i
    target_q /= self.num_augmentations
  elif self.num_augmentations == 1:
    next_q1, next_q2 = self.critic_target(
        next_states[0], actions=None, stop_grad_features=self.stop_grad_fqi)
    target_q = tf.expand_dims(
        rewards, 1) + self.discount * tf.expand_dims(
            discounts, 1) * tf.minimum(next_q1, next_q2)
  else:
    next_q1, next_q2 = self.target_critic_to_use(
        next_states, actions=None, stop_grad_features=self.stop_grad_fqi)
    target_q = tf.expand_dims(
        rewards, 1) + self.discount * tf.expand_dims(
            discounts, 1) * tf.minimum(next_q1, next_q2)

  target_q = tf.gather_nd(target_q, indices=next_action_indices)

  trainable_variables = self.critic.trainable_variables

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(trainable_variables)

    if self.num_augmentations > 1:
      critic_loss = 0.
      q1 = 0.
      q2 = 0.
      for i in range(self.num_augmentations):
        q1_i, q2_i = self.critic_to_use(
            states[i], actions=None, stop_grad_features=self.stop_grad_fqi)
        critic_loss_i = (
            tf.losses.mean_squared_error(
                target_q, tf.gather_nd(q1_i, indices=action_indices)) +
            tf.losses.mean_squared_error(
                target_q, tf.gather_nd(q2_i, indices=action_indices)))
        q1 += q1_i
        q2 += q2_i
        critic_loss += critic_loss_i
      q1 /= self.num_augmentations
      q2 /= self.num_augmentations
      critic_loss /= self.num_augmentations
      q = tf.minimum(q1, q2)  # Needed for the CQL term below.
    elif self.num_augmentations == 1:
      q1, q2 = self.critic_to_use(
          states[0], actions=None, stop_grad_features=self.stop_grad_fqi)
      q = tf.minimum(q1, q2)
      critic_loss = (
          tf.losses.mean_squared_error(
              target_q, tf.gather_nd(q1, indices=action_indices)) +
          tf.losses.mean_squared_error(
              target_q, tf.gather_nd(q2, indices=action_indices)))
    else:
      q1, q2 = self.critic_to_use(
          states, actions=None, stop_grad_features=self.stop_grad_fqi)
      q = tf.minimum(q1, q2)
      critic_loss = (
          tf.losses.mean_squared_error(
              target_q, tf.gather_nd(q1, indices=action_indices)) +
          tf.losses.mean_squared_error(
              target_q, tf.gather_nd(q2, indices=action_indices)))

    # LSE from CQL
    cql_logsumexp = tf.reduce_logsumexp(q, 1)
    cql_loss = tf.reduce_mean(cql_logsumexp -
                              tf.gather_nd(q, indices=action_indices))

    # Jointly optimize both losses
    critic_loss = critic_loss + cql_loss

  critic_grads = tape.gradient(critic_loss, trainable_variables)

  self.critic_optimizer.apply_gradients(
      zip(critic_grads, trainable_variables))

  criticCL.soft_update(self.critic, self.critic_target, tau=self.tau)

  gn = tf.reduce_mean(
      [tf.linalg.norm(v) for v in critic_grads if v is not None])

  return {
      'q1': tf.reduce_mean(q1),
      'q2': tf.reduce_mean(q2),
      'critic_loss': critic_loss,
      'cql_loss': cql_loss,
      'critic_grad_norm': gn
  }
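# Hedged sketch: the CQL term above is logsumexp over the action dimension of
# Q(s, .) minus Q(s, a_data), which pushes Q-values for out-of-dataset actions
# down while keeping the dataset action's value up. Standalone illustration:
import tensorflow as tf

q = tf.constant([[1.0, 2.0, 0.5],
                 [0.0, 0.1, 3.0]])            # [batch, num_actions]
data_actions = tf.constant([1, 2])            # actions taken in the dataset
action_indices = tf.stack([tf.range(2), data_actions], axis=-1)

cql_logsumexp = tf.reduce_logsumexp(q, 1)     # soft maximum over actions
q_data = tf.gather_nd(q, indices=action_indices)
cql_loss = tf.reduce_mean(cql_logsumexp - q_data)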
self.log_alpha = tf.Variable(tf.math.log(1.0), trainable=True)
self.log_cql_alpha = self.log_alpha
self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

self.critic = critic.Critic(
    state_dim,
    action_dim,
    hidden_dims=hidden_dims,
    encoder=critic_kwargs['encoder'],
    discrete_actions=True)
self.critic_target = critic.Critic(
    state_dim,
    action_dim,
    hidden_dims=hidden_dims,
    encoder=critic_kwargs['encoder_target'],
    discrete_actions=True)
critic.soft_update(self.critic, self.critic_target, tau=1.0)
self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)

self.tau = tau
self.reg = reg
self.target_entropy = target_entropy
self.discount = discount

self.num_cql_actions = num_cql_actions
self.bc_pretraining_steps = bc_pretraining_steps
self.min_q_weight = min_q_weight

self.bc = None

self.model_dict = {
def __init__(self,
             observation_spec,
             action_spec,
             embedding_dim=256,
             hidden_dims=(256, 256),
             sequence_length=2,
             learning_rate=None,
             discount=0.95,
             target_update_period=1000,
             num_augmentations=0,
             rep_learn_keywords='outer',
             batch_size=256):
  """Creates networks.

  Args:
    observation_spec: State spec.
    action_spec: Action spec.
    embedding_dim: Embedding size.
    hidden_dims: List of hidden dimensions.
    sequence_length: Expected length of sequences provided as input.
    learning_rate: Learning rate.
    discount: Discount factor.
    target_update_period: How frequently to update the target network.
    num_augmentations: Number of DrQ random crops.
    rep_learn_keywords: Representation learning loss to add.
    batch_size: Batch size.
  """
  super().__init__()
  action_dim = action_spec.maximum.item() + 1

  self.observation_spec = observation_spec
  self.action_dim = action_dim
  self.action_spec = action_spec
  self.embedding_dim = embedding_dim
  self.sequence_length = sequence_length
  self.discount = discount
  self.tau = 0.005
  self.discount = 0.99
  self.target_update_period = target_update_period
  self.num_augmentations = num_augmentations
  self.rep_learn_keywords = rep_learn_keywords.split('__')
  self.batch_size = batch_size

  critic_kwargs = {}

  if observation_spec.shape == (64, 64, 3):
    # IMPALA for Procgen
    def conv_stack():
      return make_impala_cnn_network(
          depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)

    state_dim = 256
  else:
    # Reduced architecture for DMC
    def conv_stack():
      return ConvStack(observation_spec.shape)

    state_dim = 50

  conv_stack_critic = conv_stack()
  conv_target_stack_critic = conv_stack()

  if observation_spec.shape == (64, 64, 3):
    conv_stack_critic.output_size = state_dim
    conv_target_stack_critic.output_size = state_dim

  critic_kwargs['encoder'] = ImageEncoder(
      conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)
  critic_kwargs['encoder_target'] = ImageEncoder(
      conv_target_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)

  self.embedder = tf_utils.EmbedNet(
      state_dim, embedding_dim=self.embedding_dim, hidden_dims=hidden_dims)

  self.f_value = tf_utils.create_mlp(
      self.embedding_dim, 1, hidden_dims=hidden_dims, activation=tf.nn.swish)
  self.f_value_target = tf_utils.create_mlp(
      self.embedding_dim, 1, hidden_dims=hidden_dims, activation=tf.nn.swish)
  self.f_trans = tf_utils.create_mlp(
      self.embedding_dim + self.embedding_dim,
      self.embedding_dim,
      hidden_dims=hidden_dims,
      activation=tf.nn.swish)
  self.f_out = tf_utils.create_mlp(
      self.embedding_dim + self.embedding_dim,
      2,
      hidden_dims=hidden_dims,
      activation=tf.nn.swish)

  self.action_encoder = tf.keras.Sequential(
      [
          tf.keras.layers.Dense(self.embedding_dim, use_bias=True),
          # kernel_regularizer=tf.keras.regularizers.l2(WEIGHT_DECAY)
          tf.keras.layers.ReLU(),
          tf.keras.layers.Dense(self.embedding_dim)
      ],
      name='action_encoder')

  if self.num_augmentations == 0:
    dummy_state = tf.constant(
        np.zeros(shape=[1] + list(observation_spec.shape)))
    self.obs_spec = list(observation_spec.shape)
  else:  # account for padding of +4 everywhere and then cropping out 68
    dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))
    self.obs_spec = [68, 68, 3]

  @tf.function
  def init_models():
    critic_kwargs['encoder'](dummy_state)
    critic_kwargs['encoder_target'](dummy_state)
    self.action_encoder(
        tf.cast(tf.one_hot([1], depth=action_dim), tf.float32))

  init_models()

  self.critic = critic.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder'],
      discrete_actions=True,
      linear='linear_Q' in self.rep_learn_keywords)
  self.critic_target = critic.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder_target'],
      discrete_actions=True,
      linear='linear_Q' in self.rep_learn_keywords)

  @tf.function
  def init_models2():
    dummy_state = tf.zeros((1, 68, 68, 3), dtype=tf.float32)
    phi_s = self.critic.encoder(dummy_state)
    phi_a = tf.eye(15, dtype=tf.float32)
    if 'linear_Q' in self.rep_learn_keywords:
      _ = self.critic.critic1.state_encoder(phi_s)
      _ = self.critic.critic2.state_encoder(phi_s)
      _ = self.critic.critic1.action_encoder(phi_a)
      _ = self.critic.critic2.action_encoder(phi_a)
      _ = self.critic_target.critic1.state_encoder(phi_s)
      _ = self.critic_target.critic2.state_encoder(phi_s)
      _ = self.critic_target.critic1.action_encoder(phi_a)
      _ = self.critic_target.critic2.action_encoder(phi_a)

  init_models2()

  critic.soft_update(self.critic, self.critic_target, tau=1.0)
  critic.soft_update(self.f_value, self.f_value_target, tau=1.0)

  learning_rate = learning_rate or 1e-4
  self.optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
  self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

  self.all_variables = (
      self.embedder.trainable_variables + self.f_value.trainable_variables +
      self.f_value_target.trainable_variables +
      self.f_trans.trainable_variables + self.f_out.trainable_variables +
      self.critic.trainable_variables +
      self.critic_target.trainable_variables)

  self.model_dict = {
      'action_encoder': self.action_encoder,
      'f_out': self.f_out,
      'f_trans': self.f_trans,
      'f_value_target': self.f_value_target,
      'f_value': self.f_value,
      'embedder': self.embedder,
      'critic': self.critic,
      'critic_target': self.critic_target,
      'critic_optimizer': self.critic_optimizer,
      'optimizer': self.optimizer
  }
def fit_embedding(self, states, actions, next_states, next_actions, rewards,
                  discounts):
  """Updates embedding and latent model parameters.

  Args:
    states: Batch of states.
    actions: Batch of actions.
    next_states: Batch of next states.
    next_actions: Batch of next actions.
    rewards: Batch of rewards.
    discounts: Batch of masks indicating the end of the episodes.

  Returns:
    Dictionary with information to track.
  """
  states = tf.transpose(
      tf.stack([states, next_states])[:, 0], (1, 0, 2, 3, 4))
  batch_size = tf.shape(states)[0]

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(self.all_variables)

    actions = tf.transpose(
        tf.one_hot(tf.stack([actions, next_actions]), depth=self.action_dim),
        (1, 0, 2))
    actions = tf.reshape(
        actions, [batch_size * self.sequence_length, self.action_dim])
    actions = self.action_encoder(actions)
    actions = tf.reshape(
        actions, [batch_size, self.sequence_length, self.embedding_dim])

    all_states = tf.reshape(
        states, [batch_size * self.sequence_length] + self.obs_spec)
    all_features = self.critic.encoder(all_states)
    all_embeddings = self.embedder(all_features, stop_gradient=False)
    embeddings = tf.reshape(
        all_embeddings,
        [batch_size, self.sequence_length, self.embedding_dim])[:, 0, :]

    all_pred_values = []
    all_pred_rewards = []
    all_pred_discounts = []
    for idx in range(self.sequence_length):
      pred_value = self.f_value(embeddings)[Ellipsis, 0]
      pred_reward, pred_discount = tf.unstack(
          self.f_out(tf.concat([embeddings, actions[:, idx, :]], -1)),
          axis=-1)
      pred_embeddings = embeddings + self.f_trans(
          tf.concat([embeddings, actions[:, idx, :]], -1))

      all_pred_values.append(pred_value)
      all_pred_rewards.append(pred_reward)
      all_pred_discounts.append(pred_discount)

      embeddings = pred_embeddings

    last_value = tf.stop_gradient(
        self.f_value_target(embeddings)[Ellipsis, 0]) / (1 - self.discount)
    all_true_values = []
    # for idx in range(self.sequence_length - 1, -1, -1):
    value = self.discount * discounts * last_value + rewards  # [:, idx]
    all_true_values.append(value)
    last_value = value
    all_true_values = all_true_values[::-1]

    reward_error = tf.stack(all_pred_rewards, -1)[:, 0] - rewards
    value_error = tf.stack(all_pred_values, -1) - (
        1 - self.discount) * tf.stack(all_true_values, -1)

    reward_loss = tf.reduce_sum(tf.math.square(reward_error), -1)
    value_loss = tf.reduce_sum(tf.math.square(value_error), -1)

    loss = tf.reduce_mean(reward_loss + value_loss)

  grads = tape.gradient(loss, self.all_variables)

  self.optimizer.apply_gradients(zip(grads, self.all_variables))

  if self.optimizer.iterations % self.target_update_period == 0:
    critic.soft_update(self.f_value, self.f_value_target, tau=self.tau)

  return {
      'embed_loss': loss,
      'reward_loss': tf.reduce_mean(reward_loss),
      'value_loss': tf.reduce_mean(value_loss),
  }
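# Hedged sketch: the (1 - discount) factors above appear to rescale discounted
# returns to a per-step scale. For a constant reward r the discounted return is
# r / (1 - discount), so multiplying by (1 - discount) recovers r, which keeps
# the regression targets for f_value in a bounded range. Tiny worked example:
discount = 0.99
r = 2.0
discounted_return = r / (1.0 - discount)                 # 200.0
per_step_scale = (1.0 - discount) * discounted_return    # 2.0, back to reward scale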
def fit_critic(self, states, actions, next_states, next_actions, rewards,
               discounts):
  """Updates critic parameters.

  Args:
    states: Batch of states.
    actions: Batch of actions.
    next_states: Batch of next states.
    next_actions: Batch of next actions from training policy.
    rewards: Batch of rewards.
    discounts: Batch of masks indicating the end of the episodes.

  Returns:
    Dictionary with information to track.
  """
  next_q1, next_q2 = self.critic_target(next_states, next_actions)
  target_q = rewards + self.discount * discounts * tf.minimum(
      next_q1, next_q2)

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(self.critic.trainable_variables)

    q1, q2 = self.critic(states, actions)
    critic_loss = (tf.losses.mean_squared_error(target_q, q1) +
                   tf.losses.mean_squared_error(target_q, q2))

    n_states = tf.repeat(states[tf.newaxis, :, :], self.num_cql_actions, 0)
    n_states = tf.reshape(n_states, [-1, n_states.get_shape()[-1]])
    n_rand_actions = tf.random.uniform(
        [tf.shape(n_states)[0], actions.get_shape()[-1]],
        self.actor.action_spec.minimum, self.actor.action_spec.maximum)
    n_actions, n_log_probs = self.actor(
        n_states, sample=True, with_log_probs=True)

    q1_rand, q2_rand = self.critic(n_states, n_rand_actions)
    q1_curr_actions, q2_curr_actions = self.critic(n_states, n_actions)

    log_u = -tf.reduce_mean(
        tf.repeat(
            (tf.math.log(2.0 * self.actor.action_scale) *
             n_rand_actions.shape[-1])[tf.newaxis, :],
            tf.shape(n_states)[0], 0), 1)
    log_probs_all = tf.concat([n_log_probs, log_u], 0)
    q1_all = tf.concat([q1_curr_actions, q1_rand], 0)
    q2_all = tf.concat([q2_curr_actions, q2_rand], 0)

    def get_qf_loss(q, log_probs):
      q -= log_probs
      q = tf.reshape(q, [-1, tf.shape(states)[0]])
      return tf.math.reduce_logsumexp(q, axis=0)

    min_qf1_loss = get_qf_loss(q1_all, log_probs_all)
    min_qf2_loss = get_qf_loss(q2_all, log_probs_all)

    cql_loss = tf.reduce_mean((min_qf1_loss - q1) + (min_qf2_loss - q2))
    critic_loss += self.min_q_weight * cql_loss

  critic_grads = tape.gradient(critic_loss, self.critic.trainable_variables)

  self.critic_optimizer.apply_gradients(
      zip(critic_grads, self.critic.trainable_variables))

  critic.soft_update(self.critic, self.critic_target, tau=self.tau)

  return {
      'q1': tf.reduce_mean(q1),
      'q2': tf.reduce_mean(q2),
      'critic_loss': critic_loss,
      'cql_loss': cql_loss
  }
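# Hedged sketch: `log_u` above appears to be the log-density of the uniform
# proposal over the action box [-action_scale, action_scale]^D, i.e.
# -sum_d log(2 * action_scale_d), used as the importance weight for the random
# actions inside the logsumexp. Standalone illustration for a 2-D action space
# with a (hypothetical) action_scale of 1.0 per dimension:
import numpy as np

action_scale = 1.0
action_dim = 2
log_uniform_density = -action_dim * np.log(2.0 * action_scale)  # log(1/4) ~ -1.386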
def __init__(self,
             observation_spec,
             action_spec,
             actor_lr=3e-4,
             critic_lr=3e-4,
             alpha_lr=3e-4,
             discount=0.99,
             tau=0.005,
             target_entropy=0.0,
             f_reg=1.0,
             reward_bonus=5.0,
             num_augmentations=1,
             env_name='',
             batch_size=256):
  """Creates networks.

  Args:
    observation_spec: environment observation spec.
    action_spec: Action spec.
    actor_lr: Actor learning rate.
    critic_lr: Critic learning rate.
    alpha_lr: Temperature learning rate.
    discount: MDP discount.
    tau: Soft target update parameter.
    target_entropy: Target entropy.
    f_reg: Critic regularization weight.
    reward_bonus: Bonus added to the rewards.
    num_augmentations: Number of DrQ augmentations (crops).
    env_name: Env name.
    batch_size: Batch size.
  """
  self.num_augmentations = num_augmentations
  self.discrete_actions = False if len(action_spec.shape) else True
  self.batch_size = batch_size

  actor_kwargs = {'hidden_dims': (1024, 1024)}
  critic_kwargs = {'hidden_dims': (1024, 1024)}

  # DRQ encoder params.
  # https://github.com/denisyarats/drq/blob/master/config.yaml#L73

  # Make 4 sets of weights:
  # - BC
  # - Actor
  # - Critic
  # - Critic (target)
  if observation_spec.shape == (64, 64, 3):
    # IMPALA for Procgen
    def conv_stack():
      return make_impala_cnn_network(
          depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)

    state_dim = 256
  else:
    # Reduced architecture for DMC
    def conv_stack():
      return ConvStack(observation_spec.shape)

    state_dim = 50

  conv_stack_bc = conv_stack()
  conv_stack_actor = conv_stack()
  conv_stack_critic = conv_stack()
  conv_target_stack_critic = conv_stack()

  if observation_spec.shape == (64, 64, 3):
    conv_stack_bc.output_size = state_dim
    conv_stack_actor.output_size = state_dim
    conv_stack_critic.output_size = state_dim
    conv_target_stack_critic.output_size = state_dim

  # Combine and stop_grad some of the above conv stacks.
  actor_kwargs['encoder_bc'] = ImageEncoder(
      conv_stack_bc, feature_dim=state_dim, bprop_conv_stack=True)
  actor_kwargs['encoder'] = ImageEncoder(
      conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=False)
  critic_kwargs['encoder'] = ImageEncoder(
      conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)
  # Note: the target critic does not share any weights.
  critic_kwargs['encoder_target'] = ImageEncoder(
      conv_target_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)

  if self.num_augmentations == 0:
    dummy_state = tf.constant(
        np.zeros(shape=[1] + list(observation_spec.shape)))
  else:  # account for padding of +4 everywhere and then cropping out 68
    dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))

  @tf.function
  def init_models():
    actor_kwargs['encoder_bc'](dummy_state)
    actor_kwargs['encoder'](dummy_state)
    critic_kwargs['encoder'](dummy_state)
    critic_kwargs['encoder_target'](dummy_state)

  init_models()

  if self.discrete_actions:
    hidden_dims = ()
    self.actor = policies.CategoricalPolicy(
        state_dim,
        action_spec,
        hidden_dims=hidden_dims,
        encoder=actor_kwargs['encoder'])
    action_dim = action_spec.maximum.item() + 1
  else:
    hidden_dims = (256, 256, 256)
    self.actor = policies.DiagGuassianPolicy(
        state_dim,
        action_spec,
        hidden_dims=hidden_dims,
        encoder=actor_kwargs['encoder'])
    action_dim = action_spec.shape[0]

  self.action_dim = action_dim

  self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

  self.log_alpha = tf.Variable(tf.math.log(1.0), trainable=True)
  self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=alpha_lr)

  self.target_entropy = target_entropy
  self.discount = discount
  self.tau = tau

  self.bc = behavioral_cloning.BehavioralCloning(
      observation_spec,
      action_spec,
      mixture=True,
      encoder=actor_kwargs['encoder_bc'],
      num_augmentations=self.num_augmentations,
      env_name=env_name,
      batch_size=batch_size)

  self.critic = critic.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder'])
  self.critic_target = critic.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder_target'])
  critic.soft_update(self.critic, self.critic_target, tau=1.0)
  self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)

  self.f_reg = f_reg
  self.reward_bonus = reward_bonus

  self.model_dict = {
      'critic': self.critic,
      'critic_target': self.critic_target,
      'actor': self.actor,
      'bc': self.bc,
      'critic_optimizer': self.critic_optimizer,
      'alpha_optimizer': self.alpha_optimizer,
      'actor_optimizer': self.actor_optimizer
  }
def __init__(self,
             observation_spec,
             action_spec,
             embedding_dim=256,
             num_distributions=None,
             hidden_dims=(256, 256),
             sequence_length=2,
             learning_rate=None,
             latent_dim=256,
             reward_weight=1.0,
             forward_weight=1.0,  # Predict last state given prev actions/states.
             inverse_weight=1.0,  # Predict last action given states.
             state_prediction_mode='energy',
             num_augmentations=0,
             rep_learn_keywords='outer',
             batch_size=256):
  """Creates networks.

  Args:
    observation_spec: State spec.
    action_spec: Action spec.
    embedding_dim: Embedding size.
    num_distributions: Number of categorical distributions for discrete
      embedding.
    hidden_dims: List of hidden dimensions.
    sequence_length: Expected length of sequences provided as input.
    learning_rate: Learning rate.
    latent_dim: Dimension of the latent variable.
    reward_weight: Weight on the reward loss.
    forward_weight: Weight on the forward loss.
    inverse_weight: Weight on the inverse loss.
    state_prediction_mode: One of ['latent', 'energy'].
    num_augmentations: Num of random crops.
    rep_learn_keywords: Representation learning loss to add.
    batch_size: Batch size.
  """
  super().__init__()
  action_dim = action_spec.maximum.item() + 1

  self.observation_spec = observation_spec
  self.action_dim = action_dim
  self.action_spec = action_spec
  self.embedding_dim = embedding_dim
  self.num_distributions = num_distributions
  self.sequence_length = sequence_length
  self.latent_dim = latent_dim
  self.reward_weight = reward_weight
  self.forward_weight = forward_weight
  self.inverse_weight = inverse_weight
  self.state_prediction_mode = state_prediction_mode
  self.num_augmentations = num_augmentations
  self.rep_learn_keywords = rep_learn_keywords.split('__')
  self.batch_size = batch_size
  self.tau = 0.005
  self.discount = 0.99

  critic_kwargs = {}

  if observation_spec.shape == (64, 64, 3):
    # IMPALA for Procgen
    def conv_stack():
      return make_impala_cnn_network(
          depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)

    state_dim = 256
  else:
    # Reduced architecture for DMC
    def conv_stack():
      return ConvStack(observation_spec.shape)

    state_dim = 50

  conv_stack_critic = conv_stack()
  conv_target_stack_critic = conv_stack()

  if observation_spec.shape == (64, 64, 3):
    conv_stack_critic.output_size = state_dim
    conv_target_stack_critic.output_size = state_dim

  critic_kwargs['encoder'] = ImageEncoder(
      conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)
  critic_kwargs['encoder_target'] = ImageEncoder(
      conv_target_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)

  self.embedder = tf_utils.EmbedNet(
      state_dim,
      embedding_dim=self.embedding_dim,
      num_distributions=self.num_distributions,
      hidden_dims=hidden_dims)

  if self.sequence_length > 2:
    self.latent_embedder = tf_utils.RNNEmbedNet(
        [self.sequence_length - 2, self.embedding_dim + self.embedding_dim],
        embedding_dim=self.latent_dim)

  self.reward_decoder = tf_utils.EmbedNet(
      self.latent_dim + self.embedding_dim + self.embedding_dim,
      embedding_dim=1,
      hidden_dims=hidden_dims)

  forward_decoder_out = (
      self.embedding_dim
      if (self.state_prediction_mode in ['latent', 'energy']) else
      self.input_dim)
  forward_decoder_dists = (
      self.num_distributions
      if (self.state_prediction_mode in ['latent', 'energy']) else None)
  self.forward_decoder = tf_utils.StochasticEmbedNet(
      self.latent_dim + self.embedding_dim + self.embedding_dim,
      embedding_dim=forward_decoder_out,
      num_distributions=forward_decoder_dists,
      hidden_dims=hidden_dims)

  self.weight = tf.Variable(tf.eye(self.embedding_dim))

  self.action_encoder = tf.keras.Sequential(
      [
          tf.keras.layers.Dense(self.embedding_dim, use_bias=True),
          # kernel_regularizer=tf.keras.regularizers.l2(WEIGHT_DECAY)
          tf.keras.layers.ReLU(),
          tf.keras.layers.Dense(self.embedding_dim)
      ],
      name='action_encoder')

  if self.num_augmentations == 0:
    dummy_state = tf.constant(
        np.zeros(shape=[1] + list(observation_spec.shape)))
    self.obs_spec = list(observation_spec.shape)
  else:  # account for padding of +4 everywhere and then cropping out 68
    dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))
    self.obs_spec = [68, 68, 3]

  @tf.function
  def init_models():
    critic_kwargs['encoder'](dummy_state)
    critic_kwargs['encoder_target'](dummy_state)
    self.action_encoder(
        tf.cast(tf.one_hot([1], depth=action_dim), tf.float32))

  init_models()

  hidden_dims = (256, 256)
  # self.actor = policies.CategoricalPolicy(
  #     state_dim, action_spec, hidden_dims=hidden_dims,
  #     encoder=actor_kwargs['encoder'])

  self.critic = critic.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder'],
      discrete_actions=True,
      linear='linear_Q' in self.rep_learn_keywords)
  self.critic_target = critic.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder_target'],
      discrete_actions=True,
      linear='linear_Q' in self.rep_learn_keywords)

  @tf.function
  def init_models2():
    dummy_state = tf.zeros((1, 68, 68, 3), dtype=tf.float32)
    phi_s = self.critic.encoder(dummy_state)
    phi_a = tf.eye(15, dtype=tf.float32)
    if 'linear_Q' in self.rep_learn_keywords:
      _ = self.critic.critic1.state_encoder(phi_s)
      _ = self.critic.critic2.state_encoder(phi_s)
      _ = self.critic.critic1.action_encoder(phi_a)
      _ = self.critic.critic2.action_encoder(phi_a)
      _ = self.critic_target.critic1.state_encoder(phi_s)
      _ = self.critic_target.critic2.state_encoder(phi_s)
      _ = self.critic_target.critic1.action_encoder(phi_a)
      _ = self.critic_target.critic2.action_encoder(phi_a)

  init_models2()

  critic.soft_update(self.critic, self.critic_target, tau=1.0)

  learning_rate = learning_rate or 1e-4
  self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
  self.critic_optimizer = tf.keras.optimizers.Adam(
      learning_rate=learning_rate)

  self.all_variables = (
      self.embedder.trainable_variables +
      self.reward_decoder.trainable_variables +
      self.forward_decoder.trainable_variables +
      self.action_encoder.trainable_variables +
      self.critic.trainable_variables +
      self.critic_target.trainable_variables)

  self.model_dict = {
      'action_encoder': self.action_encoder,
      'weight': self.weight,
      'forward_decoder': self.forward_decoder,
      'reward_decoder': self.reward_decoder,
      'embedder': self.embedder,
      'critic': self.critic,
      'critic_target': self.critic_target,
      'critic_optimizer': self.critic_optimizer,
      'optimizer': self.optimizer
  }