def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout_rate, is_training, name):
    """
    Chain optional batch normalization, activation and dropout after ``input_layer``.

    Each enabled stage consumes the output of the previous enabled stage. If no
    earlier stage is enabled, it consumes ``input_layer``.

    :param input_layer: tensor fed to the first enabled stage
    :param batchnorm: when truthy, add tf.layers.batch_normalization
    :param activation_function: callable invoked as ``f(tensor, name=...)``, or a string resolved
                                through utils.get_activation_function; falsy skips the stage
    :param dropout_rate: dropout is added only when this is greater than 0
    :param is_training: forwarded as ``training`` to batch normalization and dropout
    :param name: prefix for the names of the created layers
    :return: list of the created layers in application order, excluding ``input_layer``
    """
    created = []
    current = input_layer

    if batchnorm:
        current = tf.layers.batch_normalization(current, name="{}_batchnorm".format(name),
                                                training=is_training)
        created.append(current)

    if activation_function:
        # Strings are resolved to the actual activation callable.
        if isinstance(activation_function, str):
            activation_function = utils.get_activation_function(activation_function)
        current = activation_function(current, name="{}_activation".format(name))
        created.append(current)

    if dropout_rate > 0:
        created.append(tf.layers.dropout(current, dropout_rate, name="{}_dropout".format(name),
                                         training=is_training))

    return created
def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout_rate, is_training, name):
    """
    Optionally stack batch normalization, an activation and dropout on top of ``input_layer``.

    Each enabled stage is applied to the output of the previous enabled stage.
    If no earlier stage is enabled, it is applied to ``input_layer``.

    :param input_layer: the tensor the first enabled stage is applied to
    :param batchnorm: if truthy, apply tf.layers.batch_normalization
    :param activation_function: a callable invoked as ``f(tensor, name=...)``, or a string that is
                                resolved with utils.get_activation_function; falsy skips the stage
    :param dropout_rate: dropout rate; dropout is applied only when it is greater than 0
    :param is_training: value passed as ``training`` to batch normalization and dropout;
                        must not be a Python bool
    :param name: prefix used to name the created layers
    :return: list of the created layers in application order (``input_layer`` is not included);
             empty when every stage is disabled
    :raises TypeError: if ``is_training`` is a Python bool
    """
    # Rationale: passing a bool here will mean that batchnorm and or activation will never activate.
    # An explicit raise (instead of assert) keeps this check active under `python -O`.
    if isinstance(is_training, bool):
        raise TypeError("is_training must not be a Python bool, got {!r}".format(is_training))

    layers = []
    last = input_layer

    # batchnorm
    if batchnorm:
        last = tf.layers.batch_normalization(last, name="{}_batchnorm".format(name), training=is_training)
        layers.append(last)

    # activation
    if activation_function:
        if isinstance(activation_function, str):
            activation_function = utils.get_activation_function(activation_function)
        last = activation_function(last, name="{}_activation".format(name))
        layers.append(last)

    # dropout
    if dropout_rate > 0:
        layers.append(tf.layers.dropout(last, dropout_rate, name="{}_dropout".format(name),
                                        training=is_training))

    return layers
def get_input_embedder(self, input_name: str, embedder_params: InputEmbedderParameters):
    """
    Create the input embedder for a single named input.

    The name is looked up among the state sub-spaces plus two extra entries,
    "action" and "goal", taken from ``self.spaces``. The class of the matching
    space selects the embedder flavour: "tensor" for TensorObservationSpace,
    "image" for PlanarMapsObservationSpace, and "vector" otherwise. The flavour
    then picks the module path and the per-flavour rescaling/offset values.

    :param input_name: key of the input; a state sub-space name, "action" or "goal"
    :param embedder_params: embedder parameters; per-input values are set on a shallow
                            copy, so attributes of this object are not reassigned here
    :return: the instantiated embedder module
    :raises ValueError: if input_name is not one of the allowed keys
    """
    # Every space an embedder may consume, keyed by input name.
    spaces_by_name = copy.copy(self.spaces.state.sub_spaces)
    spaces_by_name["action"] = copy.copy(self.spaces.action)
    spaces_by_name["goal"] = copy.copy(self.spaces.goal)

    if input_name not in spaces_by_name.keys():
        raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}"
                         .format(input_name, spaces_by_name.keys()))

    input_space = spaces_by_name[input_name]
    if isinstance(input_space, TensorObservationSpace):
        flavour = "tensor"
    elif isinstance(input_space, PlanarMapsObservationSpace):
        flavour = "image"
    else:
        flavour = "vector"

    module_path = embedder_params.path(flavour)

    configured = copy.copy(embedder_params)
    configured.activation_function = utils.get_activation_function(embedder_params.activation_function)
    # input_rescaling / input_offset are indexed by flavour; keep only the matching value.
    configured.input_rescaling = configured.input_rescaling[flavour]
    configured.input_offset = configured.input_offset[flavour]
    configured.name = input_name

    return dynamic_import_and_instantiate_module_from_params(configured, path=module_path,
                                                             positional_args=[input_space.shape])
def get_middleware(self, middleware_params: MiddlewareParameters):
    """
    Given middleware parameters, creates the middleware and returns it

    The activation function is resolved with utils.get_activation_function on a
    shallow copy, so attributes of ``middleware_params`` itself are not reassigned.

    :param middleware_params: the parameters of the middleware class
    :return: the middleware instance
    """
    middleware_path = middleware_params.path
    middleware_params_copy = copy.copy(middleware_params)
    middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function)
    return dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path)
def get_output_head(self, head_params: HeadParameters, head_idx: int):
    """
    Given head parameters, creates the head and returns it

    The activation function is resolved with utils.get_activation_function on a
    shallow copy, so attributes of ``head_params`` itself are not reassigned. The
    head also receives this network's agent parameters, spaces, wrapper name,
    index and locality as keyword arguments.

    :param head_params: the parameters of the head to create
    :param head_idx: the head index
    :return: the head
    """
    head_path = head_params.path
    head_params_copy = copy.copy(head_params)
    head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function)
    return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={
        'agent_parameters': self.ap,
        'spaces': self.spaces,
        'network_name': self.network_wrapper_name,
        'head_idx': head_idx,
        'is_local': self.network_is_local})