def __init__(self,
             observation_spec,
             action_spec,
             actor_lr=3e-4,
             critic_lr=3e-4,
             discount=0.99,
             tau=0.005,
             num_augmentations=1):
  """Creates networks.

  Args:
    observation_spec: environment observation spec.
    action_spec: Action spec.
    actor_lr: Actor learning rate.
    critic_lr: Critic learning rate.
    discount: MDP discount.
    tau: Soft target update parameter.
    num_augmentations: Number of DrQ-style augmentations to perform on pixels.
  """
  self.num_augmentations = num_augmentations
  self.discrete_actions = False if len(action_spec.shape) else True
  actor_kwargs = {}
  critic_kwargs = {}

  if observation_spec.shape == (64, 64, 3):
    # IMPALA for Procgen
    def conv_stack():
      return make_impala_cnn_network(
          depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)

    state_dim = 256
  else:
    # Reduced architecture for DMC
    def conv_stack():
      return ConvStack(observation_spec.shape)

    state_dim = 50

  conv_stack_actor = conv_stack()
  conv_stack_critic = conv_stack()
  conv_target_stack_critic = conv_stack()

  if observation_spec.shape == (64, 64, 3):
    conv_stack_actor.output_size = state_dim
    conv_stack_critic.output_size = state_dim
    conv_target_stack_critic.output_size = state_dim

  # Combine and stop_grad some of the above conv stacks
  actor_kwargs['encoder'] = ImageEncoder(
      conv_stack_actor, feature_dim=state_dim, bprop_conv_stack=True)
  critic_kwargs['encoder'] = ImageEncoder(
      conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)
  # Note: the target critic does not share any weights.
  critic_kwargs['encoder_target'] = ImageEncoder(
      conv_target_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)

  if self.num_augmentations == 0:
    dummy_state = tf.constant(
        np.zeros(shape=[1] + list(observation_spec.shape)))
  else:  # account for padding of +4 everywhere and then cropping out 68
    dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))

  @tf.function
  def init_models():
    actor_kwargs['encoder'](dummy_state)
    critic_kwargs['encoder'](dummy_state)
    critic_kwargs['encoder_target'](dummy_state)

  init_models()

  if self.discrete_actions:
    action_dim = action_spec.maximum.item() + 1
    self.actor = policies.CVAEPolicyPixelsDiscrete(
        state_dim, action_spec, action_dim * 2,
        encoder=actor_kwargs['encoder'])
  else:
    action_dim = action_spec.shape[0]
    self.actor = policies.CVAEPolicyPixels(
        state_dim, action_spec, action_dim * 2,
        encoder=actor_kwargs['encoder'])

  self.action_dim = action_dim
  self.state_dim = state_dim

  if self.discrete_actions:
    self.action_encoder = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(state_dim, use_bias=True),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(action_dim)
        ],
        name='action_encoder')
    dummy_psi_act = tf.constant(np.zeros(shape=[1, state_dim]))
    self.action_encoder(dummy_psi_act)
  else:
    self.action_encoder = None

  self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

  self.critic_learner = critic.CriticLearner(
      state_dim,
      action_dim,
      critic_lr,
      discount,
      tau,
      encoder=critic_kwargs['encoder'],
      encoder_target=critic_kwargs['encoder_target'])

  self.bc = None
  self.threshold = 0.3

  self.model_dict = {
      'critic_learner': self.critic_learner,
      'action_encoder': self.action_encoder,
      'actor': self.actor,
      'actor_optimizer': self.actor_optimizer
  }
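# --- Illustrative sketch (not part of the original class) ---
# The 68x68 dummy state above matches the DrQ-style augmentation this agent
# assumes: each 64x64x3 Procgen frame is padded by +4 pixels on every side
# and a random 68x68 crop is taken. The repo's actual augmentation helper may
# differ; `random_pad_and_crop` below is a hypothetical name used only here.
import tensorflow as tf


def random_pad_and_crop(images, pad=4, out_size=68):
  """images: [B, 64, 64, 3] float tensor -> [B, out_size, out_size, 3]."""
  padded = tf.pad(
      images, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='SYMMETRIC')
  # Independent random crop per frame, DrQ-style.
  return tf.map_fn(
      lambda img: tf.image.random_crop(img, [out_size, out_size, 3]), padded)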
else:
  # Reduced architecture for DMC
  def conv_stack():
    return ConvStack(observation_spec.shape)

  state_dim = 50

conv_stack_critic = conv_stack()
conv_target_stack_critic = conv_stack()

if observation_spec.shape == (64, 64, 3):
  conv_stack_critic.output_size = state_dim
  conv_target_stack_critic.output_size = state_dim

# Combine and stop_grad some of the above conv stacks
critic_kwargs['encoder'] = ImageEncoder(
    conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)
# Note: the target critic does not share any weights.
critic_kwargs['encoder_target'] = ImageEncoder(
    conv_target_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)

if self.num_augmentations == 0:
  dummy_state = tf.constant(
      np.zeros(shape=[1] + list(observation_spec.shape)))
else:  # account for padding of +4 everywhere and then cropping out 68
  dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))

@tf.function
def init_models():
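# --- Illustrative sketch (not part of the original class) ---
# Why every __init__ runs the encoders on a dummy batch: Keras layers create
# their variables lazily on the first call, so each encoder must be called
# once before its weights can be copied into a target network or handed to an
# optimizer. Minimal self-contained example with a toy encoder (not the
# repo's ImageEncoder):
import numpy as np
import tensorflow as tf

toy_encoder = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, 3, activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(50),
])
dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3], dtype=np.float32))
toy_encoder(dummy_state)                 # Builds all variables.
assert toy_encoder.trainable_variables   # Non-empty after the dummy call.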
def __init__(self,
             observation_spec,
             action_spec,
             actor_lr=3e-4,
             critic_lr=3e-4,
             alpha_lr=3e-4,
             discount=0.99,
             tau=0.005,
             target_entropy=0.0,
             f_reg=1.0,
             reward_bonus=5.0,
             num_augmentations=1,
             rep_learn_keywords='outer',
             env_name='',
             batch_size=256,
             n_quantiles=5,
             temp=0.1,
             num_training_levels=200,
             latent_dim=256,
             n_levels_nce=5,
             popart_norm_beta=0.1):
  """Creates networks.

  Args:
    observation_spec: environment observation spec.
    action_spec: Action spec.
    actor_lr: Actor learning rate.
    critic_lr: Critic learning rate.
    alpha_lr: Temperature learning rate.
    discount: MDP discount.
    tau: Soft target update parameter.
    target_entropy: Target entropy.
    f_reg: Critic regularization weight.
    reward_bonus: Bonus added to the rewards.
    num_augmentations: Number of DrQ augmentations (random crops).
    rep_learn_keywords: Representation learning losses to add (see below).
    env_name: Environment name.
    batch_size: Batch size.
    n_quantiles: Number of GVF quantiles.
    temp: Temperature of the NCE softmax.
    num_training_levels: Number of training MDPs (200 for Procgen).
    latent_dim: Latent dimension of the auxiliary MLPs.
    n_levels_nce: Number of MDPs to apply the contrastive loss to.
    popart_norm_beta: PopArt normalization constant.

  For `rep_learn_keywords`, pick from:
    stop_grad_FQI: stop gradients through the TD/FQI critic updates.
    linear_Q: use a linear critic.
    successor_features: uses ||SF|| as the cumulant.
    gvf_termination: uses +1 if done else 0 as the cumulant.
    gvf_action_count: uses state-conditional action counts as the cumulant.
    nce: uses the multi-class dot-product InfoNCE objective.
    cce: uses the MoCo categorical cross-entropy objective.
    energy: uses SimCLR + pairwise GVF distance (not fully tested).

  If no cumulant is specified, the reward is used as the default.
  """
  del actor_lr, critic_lr, alpha_lr, target_entropy  # Unused by this agent.
  self.action_spec = action_spec
  self.num_augmentations = num_augmentations
  self.rep_learn_keywords = rep_learn_keywords.split('__')
  self.batch_size = batch_size
  self.env_name = env_name
  self.stop_grad_fqi = 'stop_grad_FQI' in self.rep_learn_keywords
  critic_kwargs = {'hidden_dims': (1024, 1024)}
  self.latent_dim = latent_dim
  self.n_levels_nce = n_levels_nce
  hidden_dims = hidden_dims_per_level = (self.latent_dim, self.latent_dim)
  self.num_training_levels = int(num_training_levels)
  self.n_quantiles = n_quantiles
  self.temp = temp

  # Make 2 sets of weights:
  # - Critic
  # - Critic (target)
  # Optionally, make a 3rd set for per-level critics
  if observation_spec.shape == (64, 64, 3):
    # IMPALA for Procgen
    def conv_stack():
      return make_impala_cnn_network(
          depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)

    state_dim = 256
  else:
    # Reduced architecture for DMC
    def conv_stack():
      return ConvStack(observation_spec.shape)

    state_dim = 50

  conv_stack_critic = conv_stack()
  conv_target_stack_critic = conv_stack()

  if observation_spec.shape == (64, 64, 3):
    conv_stack_critic.output_size = state_dim
    conv_target_stack_critic.output_size = state_dim

  critic_kwargs['encoder'] = ImageEncoder(
      conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)
  # Note: the target critic does not share any weights.
  critic_kwargs['encoder_target'] = ImageEncoder(
      conv_target_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)

  conv_stack_critic_per_level = conv_stack()
  conv_target_stack_critic_per_level = conv_stack()

  if observation_spec.shape == (64, 64, 3):
    conv_stack_critic_per_level.output_size = state_dim
    conv_target_stack_critic_per_level.output_size = state_dim

  self.encoder_per_level = ImageEncoder(
      conv_stack_critic_per_level, feature_dim=state_dim,
      bprop_conv_stack=True)
  self.encoder_per_level_target = ImageEncoder(
      conv_target_stack_critic_per_level, feature_dim=state_dim,
      bprop_conv_stack=True)

  criticCL.soft_update(
      self.encoder_per_level, self.encoder_per_level_target, tau=1.0)

  if self.num_augmentations == 0:
    dummy_state = tf.constant(np.zeros([1] + list(observation_spec.shape)))
  else:  # account for padding of +4 everywhere and then cropping out 68
    dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))

  dummy_enc = critic_kwargs['encoder'](dummy_state)

  @tf.function
  def init_models():
    """Initializes all auxiliary networks (state and action encoders).

    Uses dummy input (Procgen-specific: 68x68x3 observations, 15 actions).
    """
    critic_kwargs['encoder'](dummy_state)
    critic_kwargs['encoder_target'](dummy_state)
    self.encoder_per_level(dummy_state)
    self.encoder_per_level_target(dummy_state)

  init_models()

  action_dim = action_spec.maximum.item() + 1

  self.action_dim = action_dim
  self.discount = discount
  self.tau = tau
  self.reg = f_reg
  self.reward_bonus = reward_bonus

  self.critic = criticCL.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder'],
      discrete_actions=True,
      linear='linear_Q' in self.rep_learn_keywords)
  self.critic_target = criticCL.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder_target'],
      discrete_actions=True,
      linear='linear_Q' in self.rep_learn_keywords)

  self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
  self.task_critic_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
  self.br_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)

  if 'cce' in self.rep_learn_keywords:
    self.classifier = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(self.latent_dim, use_bias=True),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(self.n_quantiles, use_bias=True)
        ],
        name='classifier')
  elif 'nce' in self.rep_learn_keywords:
    self.embedding = tf.keras.Sequential(
        [
            tf.keras.layers.Dense(self.latent_dim, use_bias=True),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Dense(self.latent_dim, use_bias=True)
        ],
        name='embedding')

  # This snippet initializes all auxiliary networks (state and action
  # encoders) with dummy input (Procgen-specific: 68x68x3, 15 actions).
  dummy_state = tf.zeros((1, 68, 68, 3), dtype=tf.float32)
  phi_s = self.critic.encoder(dummy_state)
  phi_a = tf.eye(action_dim, dtype=tf.float32)
  if 'linear_Q' in self.rep_learn_keywords:
    _ = self.critic.critic1.state_encoder(phi_s)
    _ = self.critic.critic2.state_encoder(phi_s)
    _ = self.critic.critic1.action_encoder(phi_a)
    _ = self.critic.critic2.action_encoder(phi_a)
    _ = self.critic_target.critic1.state_encoder(phi_s)
    _ = self.critic_target.critic2.state_encoder(phi_s)
    _ = self.critic_target.critic1.action_encoder(phi_a)
    _ = self.critic_target.critic2.action_encoder(phi_a)
  if 'cce' in self.rep_learn_keywords:
    self.classifier(phi_s)
  elif 'nce' in self.rep_learn_keywords:
    self.embedding(phi_s)

  self.target_critic_to_use = self.critic_target
  self.critic_to_use = self.critic

  criticCL.soft_update(self.critic, self.critic_target, tau=1.0)

  self.cce = tf.keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.NONE, from_logits=True)

  self.bc = None

  if 'successor_features' in self.rep_learn_keywords:
    self.output_dim_level = self.latent_dim
  elif 'gvf_termination' in self.rep_learn_keywords:
    self.output_dim_level = 1
  elif 'gvf_action_count' in self.rep_learn_keywords:
    self.output_dim_level = action_dim
  else:
    self.output_dim_level = action_dim

  self.task_critic_one = criticCL.Critic(
      state_dim,
      self.output_dim_level * self.num_training_levels,
      hidden_dims=hidden_dims_per_level,
      encoder=None,
      discrete_actions=True,
      cross_norm=False)
  # Output dimension must match the online task critic for the soft update
  # below.
  self.task_critic_target_one = criticCL.Critic(
      state_dim,
      self.output_dim_level * self.num_training_levels,
      hidden_dims=hidden_dims_per_level,
      encoder=None,
      discrete_actions=True,
      cross_norm=False)
  self.task_critic_one(
      dummy_enc,
      actions=None,
      training=False,
      return_features=False,
      stop_grad_features=False)
  self.task_critic_target_one(
      dummy_enc,
      actions=None,
      training=False,
      return_features=False,
      stop_grad_features=False)
  criticCL.soft_update(
      self.task_critic_one, self.task_critic_target_one, tau=1.0)

  # Normalization constant beta, set to the default value from the PopArt
  # paper.
  self.reward_normalizer = popart.PopArt(
      running_statistics.EMAMeanStd(popart_norm_beta))
  self.reward_normalizer.init()

  if 'CLIP' in self.rep_learn_keywords or 'clip' in self.rep_learn_keywords:
    self.loss_temp = tf.Variable(
        tf.constant(0.0, dtype=tf.float32), name='loss_temp', trainable=True)

  self.model_dict = {
      'critic': self.critic,
      'critic_target': self.critic_target,
      'critic_optimizer': self.critic_optimizer,
      'br_optimizer': self.br_optimizer
  }
  self.model_dict['encoder_perLevel'] = self.encoder_per_level
  self.model_dict['encoder_perLevel_target'] = self.encoder_per_level_target
  self.model_dict['task_critic'] = self.task_critic_one
  self.model_dict['task_critic_target'] = self.task_critic_target_one
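# --- Illustrative sketch (not part of the original class) ---
# The 'nce' keyword above adds a multi-class dot-product InfoNCE objective
# with softmax temperature `self.temp`. How this agent pairs positives and
# negatives is defined elsewhere; the sketch below only shows the generic
# form of such a loss, assuming row i of `anchors` and `positives` form the
# positive pair.
import tensorflow as tf


def info_nce_loss(anchors, positives, temp=0.1):
  """anchors, positives: [B, D] embeddings; returns a scalar loss."""
  logits = tf.matmul(anchors, positives, transpose_b=True) / temp  # [B, B]
  labels = tf.range(tf.shape(anchors)[0])  # Diagonal entries are positives.
  return tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits))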
def __init__(self,
             observation_spec,
             action_spec,
             embedding_dim=256,
             hidden_dims=(256, 256),
             sequence_length=2,
             learning_rate=None,
             discount=0.95,
             target_update_period=1000,
             num_augmentations=0,
             rep_learn_keywords='outer',
             batch_size=256):
  """Creates networks.

  Args:
    observation_spec: State spec.
    action_spec: Action spec.
    embedding_dim: Embedding size.
    hidden_dims: List of hidden dimensions.
    sequence_length: Expected length of sequences provided as input.
    learning_rate: Learning rate.
    discount: Discount factor.
    target_update_period: How frequently to update the target network.
    num_augmentations: Number of DrQ random crops.
    rep_learn_keywords: Representation learning losses to add.
    batch_size: Batch size.
  """
  super().__init__()
  action_dim = action_spec.maximum.item() + 1
  self.observation_spec = observation_spec
  self.action_dim = action_dim
  self.action_spec = action_spec
  self.embedding_dim = embedding_dim
  self.sequence_length = sequence_length
  self.discount = discount
  self.tau = 0.005
  self.discount = 0.99  # Note: overrides the `discount` argument above.
  self.target_update_period = target_update_period
  self.num_augmentations = num_augmentations
  self.rep_learn_keywords = rep_learn_keywords.split('__')
  self.batch_size = batch_size

  critic_kwargs = {}

  if observation_spec.shape == (64, 64, 3):
    # IMPALA for Procgen
    def conv_stack():
      return make_impala_cnn_network(
          depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)

    state_dim = 256
  else:
    # Reduced architecture for DMC
    def conv_stack():
      return ConvStack(observation_spec.shape)

    state_dim = 50

  conv_stack_critic = conv_stack()
  conv_target_stack_critic = conv_stack()

  if observation_spec.shape == (64, 64, 3):
    conv_stack_critic.output_size = state_dim
    conv_target_stack_critic.output_size = state_dim

  critic_kwargs['encoder'] = ImageEncoder(
      conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)
  critic_kwargs['encoder_target'] = ImageEncoder(
      conv_target_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)

  self.embedder = tf_utils.EmbedNet(
      state_dim, embedding_dim=self.embedding_dim, hidden_dims=hidden_dims)

  self.f_value = tf_utils.create_mlp(
      self.embedding_dim, 1, hidden_dims=hidden_dims, activation=tf.nn.swish)
  self.f_value_target = tf_utils.create_mlp(
      self.embedding_dim, 1, hidden_dims=hidden_dims, activation=tf.nn.swish)
  self.f_trans = tf_utils.create_mlp(
      self.embedding_dim + self.embedding_dim,
      self.embedding_dim,
      hidden_dims=hidden_dims,
      activation=tf.nn.swish)
  self.f_out = tf_utils.create_mlp(
      self.embedding_dim + self.embedding_dim,
      2,
      hidden_dims=hidden_dims,
      activation=tf.nn.swish)

  self.action_encoder = tf.keras.Sequential(
      [
          tf.keras.layers.Dense(self.embedding_dim, use_bias=True),
          tf.keras.layers.ReLU(),
          tf.keras.layers.Dense(self.embedding_dim)
      ],
      name='action_encoder')

  if self.num_augmentations == 0:
    dummy_state = tf.constant(
        np.zeros(shape=[1] + list(observation_spec.shape)))
    self.obs_spec = list(observation_spec.shape)
  else:  # account for padding of +4 everywhere and then cropping out 68
    dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))
    self.obs_spec = [68, 68, 3]

  @tf.function
  def init_models():
    critic_kwargs['encoder'](dummy_state)
    critic_kwargs['encoder_target'](dummy_state)
    self.action_encoder(
        tf.cast(tf.one_hot([1], depth=action_dim), tf.float32))

  init_models()

  self.critic = critic.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder'],
      discrete_actions=True,
      linear='linear_Q' in self.rep_learn_keywords)
  self.critic_target = critic.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder_target'],
      discrete_actions=True,
      linear='linear_Q' in self.rep_learn_keywords)

  @tf.function
  def init_models2():
    dummy_state = tf.zeros((1, 68, 68, 3), dtype=tf.float32)
    phi_s = self.critic.encoder(dummy_state)
    phi_a = tf.eye(15, dtype=tf.float32)
    if 'linear_Q' in self.rep_learn_keywords:
      _ = self.critic.critic1.state_encoder(phi_s)
      _ = self.critic.critic2.state_encoder(phi_s)
      _ = self.critic.critic1.action_encoder(phi_a)
      _ = self.critic.critic2.action_encoder(phi_a)
      _ = self.critic_target.critic1.state_encoder(phi_s)
      _ = self.critic_target.critic2.state_encoder(phi_s)
      _ = self.critic_target.critic1.action_encoder(phi_a)
      _ = self.critic_target.critic2.action_encoder(phi_a)

  init_models2()

  critic.soft_update(self.critic, self.critic_target, tau=1.0)
  critic.soft_update(self.f_value, self.f_value_target, tau=1.0)

  learning_rate = learning_rate or 1e-4

  self.optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
  self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

  self.all_variables = (
      self.embedder.trainable_variables + self.f_value.trainable_variables +
      self.f_value_target.trainable_variables +
      self.f_trans.trainable_variables + self.f_out.trainable_variables +
      self.critic.trainable_variables +
      self.critic_target.trainable_variables)

  self.model_dict = {
      'action_encoder': self.action_encoder,
      'f_out': self.f_out,
      'f_trans': self.f_trans,
      'f_value_target': self.f_value_target,
      'f_value': self.f_value,
      'embedder': self.embedder,
      'critic': self.critic,
      'critic_target': self.critic_target,
      'critic_optimizer': self.critic_optimizer,
      'optimizer': self.optimizer
  }
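# --- Illustrative sketch (not part of the original class) ---
# `soft_update(net, target_net, tau)` is assumed to be the usual Polyak
# update, target <- tau * online + (1 - tau) * target; with tau=1.0 (as used
# above) it reduces to a hard copy of the online weights.
def soft_update_sketch(net, target_net, tau=0.005):
  for var, target_var in zip(net.variables, target_net.variables):
    target_var.assign(tau * var + (1.0 - tau) * target_var)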
def __init__(self,
             observation_spec,
             action_spec,
             actor_lr=3e-4,
             critic_lr=3e-4,
             alpha_lr=3e-4,
             discount=0.99,
             tau=0.005,
             target_update_period=1,
             target_entropy=0.0,
             cross_norm=False,
             pcl_actor_update=False):
  """Creates networks.

  Args:
    observation_spec: environment observation spec.
    action_spec: Action spec.
    actor_lr: Actor learning rate.
    critic_lr: Critic learning rate.
    alpha_lr: Temperature learning rate.
    discount: MDP discount.
    tau: Soft target update parameter.
    target_update_period: Target network update period.
    target_entropy: Target entropy.
    cross_norm: Whether to fit cross norm critic.
    pcl_actor_update: Whether to use PCL actor update.
  """
  actor_kwargs = {}
  critic_kwargs = {}

  if len(observation_spec.shape) == 3:  # Image observations.
    # DRQ encoder params.
    # https://github.com/denisyarats/drq/blob/master/config.yaml#L73
    state_dim = 50

    # Actor and critic encoders share conv weights only.
    conv_stack = ConvStack(observation_spec.shape)

    actor_kwargs['encoder'] = ImageEncoder(
        conv_stack, state_dim, bprop_conv_stack=False)
    actor_kwargs['hidden_dims'] = (1024, 1024)
    critic_kwargs['encoder'] = ImageEncoder(
        conv_stack, state_dim, bprop_conv_stack=True)
    critic_kwargs['hidden_dims'] = (1024, 1024)
    if not cross_norm:
      # Note: the target critic does not share any weights.
      critic_kwargs['encoder_target'] = ImageEncoder(
          ConvStack(observation_spec.shape), state_dim,
          bprop_conv_stack=True)
  else:  # 1D state observations.
    assert len(observation_spec.shape) == 1
    state_dim = observation_spec.shape[0]

  if cross_norm:
    beta_1 = 0.0
  else:
    beta_1 = 0.9

  self.actor = policies.DiagGuassianPolicy(state_dim, action_spec,
                                           **actor_kwargs)
  self.actor_optimizer = tf.keras.optimizers.Adam(
      learning_rate=actor_lr, beta_1=beta_1)

  self.log_alpha = tf.Variable(tf.math.log(0.1), trainable=True)
  self.alpha_optimizer = tf.keras.optimizers.Adam(
      learning_rate=alpha_lr, beta_1=beta_1)

  if cross_norm:
    assert 'encoder_target' not in critic_kwargs
    self.critic_learner = critic.CrossNormCriticLearner(
        state_dim, action_spec.shape[0], critic_lr, discount, tau,
        **critic_kwargs)
  else:
    self.critic_learner = critic.CriticLearner(
        state_dim, action_spec.shape[0], critic_lr, discount, tau,
        target_update_period, **critic_kwargs)

  self.target_entropy = target_entropy
  self.discount = discount

  self.pcl_actor_update = pcl_actor_update
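# --- Illustrative sketch (not part of the original class) ---
# `log_alpha` and `target_entropy` are created above; in standard SAC the
# temperature is trained so that policy entropy tracks `target_entropy`.
# The agent's real update lives outside __init__; `log_probs` here stands in
# for the policy's log-probabilities of sampled actions.
import tensorflow as tf


def alpha_loss_sketch(log_alpha, log_probs, target_entropy):
  alpha = tf.exp(log_alpha)
  return tf.reduce_mean(alpha * tf.stop_gradient(-log_probs - target_entropy))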
def __init__(self,
             observation_spec,
             action_spec,
             actor_lr=3e-4,
             critic_lr=3e-4,
             alpha_lr=3e-4,
             discount=0.99,
             tau=0.005,
             target_entropy=0.0,
             f_reg=1.0,
             reward_bonus=5.0,
             num_augmentations=1,
             env_name='',
             batch_size=256):
  """Creates networks.

  Args:
    observation_spec: environment observation spec.
    action_spec: Action spec.
    actor_lr: Actor learning rate.
    critic_lr: Critic learning rate.
    alpha_lr: Temperature learning rate.
    discount: MDP discount.
    tau: Soft target update parameter.
    target_entropy: Target entropy.
    f_reg: Critic regularization weight.
    reward_bonus: Bonus added to the rewards.
    num_augmentations: Number of DrQ augmentations (crops).
    env_name: Environment name.
    batch_size: Batch size.
  """
  self.num_augmentations = num_augmentations
  self.discrete_actions = False if len(action_spec.shape) else True
  self.batch_size = batch_size

  actor_kwargs = {'hidden_dims': (1024, 1024)}
  critic_kwargs = {'hidden_dims': (1024, 1024)}

  # DRQ encoder params.
  # https://github.com/denisyarats/drq/blob/master/config.yaml#L73

  # Make 4 sets of weights:
  # - BC
  # - Actor
  # - Critic
  # - Critic (target)
  if observation_spec.shape == (64, 64, 3):
    # IMPALA for Procgen
    def conv_stack():
      return make_impala_cnn_network(
          depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)

    state_dim = 256
  else:
    # Reduced architecture for DMC
    def conv_stack():
      return ConvStack(observation_spec.shape)

    state_dim = 50

  conv_stack_bc = conv_stack()
  conv_stack_actor = conv_stack()
  conv_stack_critic = conv_stack()
  conv_target_stack_critic = conv_stack()

  if observation_spec.shape == (64, 64, 3):
    conv_stack_bc.output_size = state_dim
    conv_stack_actor.output_size = state_dim
    conv_stack_critic.output_size = state_dim
    conv_target_stack_critic.output_size = state_dim

  # Combine and stop_grad some of the above conv stacks
  actor_kwargs['encoder_bc'] = ImageEncoder(
      conv_stack_bc, feature_dim=state_dim, bprop_conv_stack=True)
  # The actor shares the critic's conv stack but does not backprop through it.
  actor_kwargs['encoder'] = ImageEncoder(
      conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=False)
  critic_kwargs['encoder'] = ImageEncoder(
      conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)
  # Note: the target critic does not share any weights.
  critic_kwargs['encoder_target'] = ImageEncoder(
      conv_target_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)

  if self.num_augmentations == 0:
    dummy_state = tf.constant(
        np.zeros(shape=[1] + list(observation_spec.shape)))
  else:  # account for padding of +4 everywhere and then cropping out 68
    dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))

  @tf.function
  def init_models():
    actor_kwargs['encoder_bc'](dummy_state)
    actor_kwargs['encoder'](dummy_state)
    critic_kwargs['encoder'](dummy_state)
    critic_kwargs['encoder_target'](dummy_state)

  init_models()

  if self.discrete_actions:
    hidden_dims = ()
    self.actor = policies.CategoricalPolicy(
        state_dim, action_spec, hidden_dims=hidden_dims,
        encoder=actor_kwargs['encoder'])
    action_dim = action_spec.maximum.item() + 1
  else:
    hidden_dims = (256, 256, 256)
    self.actor = policies.DiagGuassianPolicy(
        state_dim, action_spec, hidden_dims=hidden_dims,
        encoder=actor_kwargs['encoder'])
    action_dim = action_spec.shape[0]

  self.action_dim = action_dim

  self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)

  self.log_alpha = tf.Variable(tf.math.log(1.0), trainable=True)
  self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=alpha_lr)

  self.target_entropy = target_entropy
  self.discount = discount
  self.tau = tau

  self.bc = behavioral_cloning.BehavioralCloning(
      observation_spec,
      action_spec,
      mixture=True,
      encoder=actor_kwargs['encoder_bc'],
      num_augmentations=self.num_augmentations,
      env_name=env_name,
      batch_size=batch_size)

  self.critic = critic.Critic(
      state_dim, action_dim, hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder'])
  self.critic_target = critic.Critic(
      state_dim, action_dim, hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder_target'])

  critic.soft_update(self.critic, self.critic_target, tau=1.0)
  self.critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)

  self.f_reg = f_reg
  self.reward_bonus = reward_bonus

  self.model_dict = {
      'critic': self.critic,
      'critic_target': self.critic_target,
      'actor': self.actor,
      'bc': self.bc,
      'critic_optimizer': self.critic_optimizer,
      'alpha_optimizer': self.alpha_optimizer,
      'actor_optimizer': self.actor_optimizer
  }
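# --- Illustrative sketch (not part of the original class) ---
# For discrete specs, `action_dim = action_spec.maximum.item() + 1` (Procgen
# exposes actions 0..14, so action_dim == 15). Elsewhere in this section,
# dummy initializations use tf.one_hot / tf.eye with this depth; the snippet
# below only illustrates that integer-to-one-hot convention.
import tensorflow as tf

actions = tf.constant([3, 0, 14])        # Raw integer actions from a batch.
action_dim = 15                          # action_spec.maximum.item() + 1.
one_hot_actions = tf.one_hot(actions, depth=action_dim, dtype=tf.float32)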
def __init__(self,
             observation_spec,
             action_spec,
             mixture=False,
             encoder=None,
             num_augmentations=1,
             env_name='',
             rep_learn_keywords='outer',
             batch_size=256):
  if observation_spec.shape == (64, 64, 3):
    state_dim = 256
  else:
    state_dim = 50

  self.batch_size = batch_size
  self.num_augmentations = num_augmentations
  self.rep_learn_keywords = rep_learn_keywords.split('__')
  self.discrete_actions = False if len(action_spec.shape) else True
  self.action_spec = action_spec

  if encoder is None:
    if observation_spec.shape == (64, 64, 3):
      # IMPALA for Procgen
      def conv_stack():
        return make_impala_cnn_network(
            depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)

      state_dim = 256
    else:
      # Reduced architecture for DMC
      def conv_stack():
        return ConvStack(observation_spec.shape)

      state_dim = 50

    conv_stack_bc = conv_stack()

    if observation_spec.shape == (64, 64, 3):
      conv_stack_bc.output_size = state_dim
    # Combine and stop_grad some of the above conv stacks
    encoder = ImageEncoder(
        conv_stack_bc, feature_dim=state_dim, bprop_conv_stack=True)

    if self.num_augmentations == 0:
      dummy_state = tf.constant(
          np.zeros(shape=[1] + list(observation_spec.shape)))
    else:  # account for padding of +4 everywhere and then cropping out 68
      dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))

    @tf.function
    def init_models():
      encoder(dummy_state)

    init_models()

  if self.discrete_actions:
    if 'linear_Q' in self.rep_learn_keywords:
      hidden_dims = ()
    else:
      hidden_dims = (256, 256)
    self.policy = policies.CategoricalPolicy(
        state_dim, action_spec, hidden_dims=hidden_dims, encoder=encoder)
    action_dim = action_spec.maximum.item() + 1
  else:
    action_dim = action_spec.shape[0]
    if mixture:
      self.policy = policies.MixtureGuassianPolicy(
          state_dim, action_spec, encoder=encoder)
    else:
      self.policy = policies.DiagGuassianPolicy(
          state_dim, action_spec, encoder=encoder)

  self.optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)

  self.log_alpha = tf.Variable(tf.math.log(1.0), trainable=True)
  self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)

  self.target_entropy = -action_dim

  if env_name and env_name.startswith('procgen'):
    self.procgen_action_mat = PROCGEN_ACTION_MAT[env_name.split('-')[1]]

  self.bc = None

  self.model_dict = {
      'policy': self.policy,
      'optimizer': self.optimizer
  }
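# --- Illustrative sketch (not part of the original class) ---
# This class sets up a policy and optimizer for behavioral cloning. A typical
# BC step maximizes the policy's log-likelihood of dataset actions; the real
# update lives outside __init__, and `policy_log_prob` is a hypothetical
# stand-in for however the policy exposes log-probabilities.
import tensorflow as tf


def bc_step_sketch(policy_log_prob, variables, optimizer, states, actions):
  with tf.GradientTape() as tape:
    loss = -tf.reduce_mean(policy_log_prob(states, actions))
  grads = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(grads, variables))
  return loss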
def __init__(self,
             observation_spec,
             action_spec,
             embedding_dim=256,
             num_distributions=None,
             hidden_dims=(256, 256),
             sequence_length=2,
             learning_rate=None,
             latent_dim=256,
             reward_weight=1.0,
             forward_weight=1.0,  # Predict last state given prev actions/states.
             inverse_weight=1.0,  # Predict last action given states.
             state_prediction_mode='energy',
             num_augmentations=0,
             rep_learn_keywords='outer',
             batch_size=256):
  """Creates networks.

  Args:
    observation_spec: State spec.
    action_spec: Action spec.
    embedding_dim: Embedding size.
    num_distributions: Number of categorical distributions for discrete
      embedding.
    hidden_dims: List of hidden dimensions.
    sequence_length: Expected length of sequences provided as input.
    learning_rate: Learning rate.
    latent_dim: Dimension of the latent variable.
    reward_weight: Weight on the reward loss.
    forward_weight: Weight on the forward loss.
    inverse_weight: Weight on the inverse loss.
    state_prediction_mode: One of ['latent', 'energy'].
    num_augmentations: Number of random crops.
    rep_learn_keywords: Representation learning losses to add.
    batch_size: Batch size.
  """
  super().__init__()
  action_dim = action_spec.maximum.item() + 1
  self.observation_spec = observation_spec
  self.action_dim = action_dim
  self.action_spec = action_spec
  self.embedding_dim = embedding_dim
  self.num_distributions = num_distributions
  self.sequence_length = sequence_length
  self.latent_dim = latent_dim
  self.reward_weight = reward_weight
  self.forward_weight = forward_weight
  self.inverse_weight = inverse_weight
  self.state_prediction_mode = state_prediction_mode
  self.num_augmentations = num_augmentations
  self.rep_learn_keywords = rep_learn_keywords.split('__')
  self.batch_size = batch_size
  self.tau = 0.005
  self.discount = 0.99

  critic_kwargs = {}

  if observation_spec.shape == (64, 64, 3):
    # IMPALA for Procgen
    def conv_stack():
      return make_impala_cnn_network(
          depths=[16, 32, 32], use_batch_norm=False, dropout_rate=0.)

    state_dim = 256
  else:
    # Reduced architecture for DMC
    def conv_stack():
      return ConvStack(observation_spec.shape)

    state_dim = 50

  conv_stack_critic = conv_stack()
  conv_target_stack_critic = conv_stack()

  if observation_spec.shape == (64, 64, 3):
    conv_stack_critic.output_size = state_dim
    conv_target_stack_critic.output_size = state_dim

  critic_kwargs['encoder'] = ImageEncoder(
      conv_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)
  critic_kwargs['encoder_target'] = ImageEncoder(
      conv_target_stack_critic, feature_dim=state_dim, bprop_conv_stack=True)

  self.embedder = tf_utils.EmbedNet(
      state_dim,
      embedding_dim=self.embedding_dim,
      num_distributions=self.num_distributions,
      hidden_dims=hidden_dims)

  if self.sequence_length > 2:
    self.latent_embedder = tf_utils.RNNEmbedNet(
        [self.sequence_length - 2, self.embedding_dim + self.embedding_dim],
        embedding_dim=self.latent_dim)

  self.reward_decoder = tf_utils.EmbedNet(
      self.latent_dim + self.embedding_dim + self.embedding_dim,
      embedding_dim=1,
      hidden_dims=hidden_dims)

  forward_decoder_out = (
      self.embedding_dim
      if (self.state_prediction_mode in ['latent', 'energy'])
      else self.input_dim)
  forward_decoder_dists = (
      self.num_distributions
      if (self.state_prediction_mode in ['latent', 'energy'])
      else None)
  self.forward_decoder = tf_utils.StochasticEmbedNet(
      self.latent_dim + self.embedding_dim + self.embedding_dim,
      embedding_dim=forward_decoder_out,
      num_distributions=forward_decoder_dists,
      hidden_dims=hidden_dims)

  self.weight = tf.Variable(tf.eye(self.embedding_dim))

  self.action_encoder = tf.keras.Sequential(
      [
          tf.keras.layers.Dense(self.embedding_dim, use_bias=True),
          tf.keras.layers.ReLU(),
          tf.keras.layers.Dense(self.embedding_dim)
      ],
      name='action_encoder')

  if self.num_augmentations == 0:
    dummy_state = tf.constant(
        np.zeros(shape=[1] + list(observation_spec.shape)))
    self.obs_spec = list(observation_spec.shape)
  else:  # account for padding of +4 everywhere and then cropping out 68
    dummy_state = tf.constant(np.zeros(shape=[1, 68, 68, 3]))
    self.obs_spec = [68, 68, 3]

  @tf.function
  def init_models():
    critic_kwargs['encoder'](dummy_state)
    critic_kwargs['encoder_target'](dummy_state)
    self.action_encoder(
        tf.cast(tf.one_hot([1], depth=action_dim), tf.float32))

  init_models()

  hidden_dims = (256, 256)

  self.critic = critic.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder'],
      discrete_actions=True,
      linear='linear_Q' in self.rep_learn_keywords)
  self.critic_target = critic.Critic(
      state_dim,
      action_dim,
      hidden_dims=hidden_dims,
      encoder=critic_kwargs['encoder_target'],
      discrete_actions=True,
      linear='linear_Q' in self.rep_learn_keywords)

  @tf.function
  def init_models2():
    dummy_state = tf.zeros((1, 68, 68, 3), dtype=tf.float32)
    phi_s = self.critic.encoder(dummy_state)
    phi_a = tf.eye(15, dtype=tf.float32)
    if 'linear_Q' in self.rep_learn_keywords:
      _ = self.critic.critic1.state_encoder(phi_s)
      _ = self.critic.critic2.state_encoder(phi_s)
      _ = self.critic.critic1.action_encoder(phi_a)
      _ = self.critic.critic2.action_encoder(phi_a)
      _ = self.critic_target.critic1.state_encoder(phi_s)
      _ = self.critic_target.critic2.state_encoder(phi_s)
      _ = self.critic_target.critic1.action_encoder(phi_a)
      _ = self.critic_target.critic2.action_encoder(phi_a)

  init_models2()

  critic.soft_update(self.critic, self.critic_target, tau=1.0)

  learning_rate = learning_rate or 1e-4

  self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
  self.critic_optimizer = tf.keras.optimizers.Adam(
      learning_rate=learning_rate)

  self.all_variables = (
      self.embedder.trainable_variables +
      self.reward_decoder.trainable_variables +
      self.forward_decoder.trainable_variables +
      self.action_encoder.trainable_variables +
      self.critic.trainable_variables +
      self.critic_target.trainable_variables)

  self.model_dict = {
      'action_encoder': self.action_encoder,
      'weight': self.weight,
      'forward_decoder': self.forward_decoder,
      'reward_decoder': self.reward_decoder,
      'embedder': self.embedder,
      'critic': self.critic,
      'critic_target': self.critic_target,
      'critic_optimizer': self.critic_optimizer,
      'optimizer': self.optimizer
  }
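# --- Illustrative sketch (not part of the original class) ---
# The reward/forward/inverse weights stored above scale per-term losses that
# are computed elsewhere in the class; this only sketches how such weights
# are typically combined into a single auxiliary objective.
def total_aux_loss_sketch(reward_loss, forward_loss, inverse_loss,
                          reward_weight=1.0, forward_weight=1.0,
                          inverse_weight=1.0):
  return (reward_weight * reward_loss +
          forward_weight * forward_loss +
          inverse_weight * inverse_loss)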