def graph_fn(mode, features, labels=None):
    """Build the Q-value head on top of the user-provided graph function.

    Calls ``self._graph_fn`` (forwarding ``labels`` only when that function
    declares a ``labels`` argument), projects its output with a
    ``Dense(self.num_actions)`` layer, and optionally adds a ``Dense(1)``
    value stream combined according to ``self.dueling``.

    Args:
        mode: forwarded to ``self._graph_fn``.
        features: forwarded to ``self._graph_fn``.
        labels: forwarded to ``self._graph_fn`` only if it accepts ``labels``.

    Returns:
        ``QModelSpec(graph_outputs, a, v, q)``; ``v`` is ``None`` when
        dueling is disabled (``self.dueling`` is ``None`` or ``False``).

    Raises:
        ValueError: if ``self.dueling`` is not one of ``None``, ``False``,
            ``True``, ``'mean'``, ``'max'`` or ``'naive'``.
    """
    kwargs = {}
    if 'labels' in get_arguments(self._graph_fn):
        kwargs['labels'] = labels

    graph_outputs = self._graph_fn(mode=mode, features=features, **kwargs)
    a = Dense(units=self.num_actions)(graph_outputs)

    v = None
    # `False` disables dueling exactly like `None`; previously it fell through
    # every branch below and was rejected as "unsupported".
    if self.dueling is None or self.dueling is False:
        q = tf.identity(a)
    else:
        # Q = V(s) + A(s, a)
        v = Dense(units=1)(graph_outputs)
        if self.dueling == 'mean':
            q = v + (a - tf.reduce_mean(a, axis=1, keep_dims=True))
        elif self.dueling == 'max':
            q = v + (a - tf.reduce_max(a, axis=1, keep_dims=True))
        elif self.dueling == 'naive':
            q = v + a
        elif self.dueling is True:
            # NOTE(review): `v` is built and returned but not mixed into `q`
            # here; confirm this is the intended meaning of `dueling=True`.
            q = tf.identity(a)
        else:
            raise ValueError("The value `{}` provided for "
                             "dueling is unsupported.".format(self.dueling))

    return QModelSpec(graph_outputs=graph_outputs, a=a, v=v, q=q)
def _call_model_fn(self, features, labels, mode, config=None):
    """Calls model function.

    Only the optional arguments that ``self._model_fn`` declares (``labels``,
    ``mode``, ``params``, ``config``) are passed to it.

    Args:
        features: features dict.
        labels: labels dict.
        mode: ModeKeys
        config: optional config to forward to model_fn; when ``None``
            (the default) ``self.config`` is forwarded instead.

    Returns:
        An `EstimatorSpec` object.

    Raises:
        ValueError: if model_fn does not take labels but labels are given,
            or if model_fn returns invalid objects.
    """
    model_fn_args = get_arguments(self._model_fn)
    kwargs = {}
    if 'labels' in model_fn_args:
        kwargs['labels'] = labels
    elif labels is not None:
        raise ValueError(
            'model_fn does not take labels, but input_fn returns labels.')
    if 'mode' in model_fn_args:
        kwargs['mode'] = mode
    if 'params' in model_fn_args:
        kwargs['params'] = self.params
    if 'config' in model_fn_args:
        # Honour an explicitly passed config; the argument used to be ignored.
        kwargs['config'] = config if config is not None else self.config

    model_fn_results = self._model_fn(features=features, **kwargs)

    if not isinstance(model_fn_results, EstimatorSpec):
        raise ValueError('model_fn should return an EstimatorSpec.')

    return model_fn_results
def _decay_fn(timestep, exploration_rate, decay_type='polynomial_decay', start_decay_at=0,
              stop_decay_at=1e9, decay_rate=0., staircase=False, decay_steps=100000,
              min_exploration_rate=0):
    """The computed decayed exploration rate.

    Args:
        timestep: the current timestep.
        exploration_rate: `float` or `list` of `float` or `function`.
            The initial value of the exploration rate. A `functools.partial`
            is called once to obtain the value.
        decay_type: A decay function name defined in `exploration_decay`
            possible Values: exponential_decay, inverse_time_decay, natural_exp_decay,
                             piecewise_constant, polynomial_decay.
        start_decay_at: `int`. When to start the decay.
        stop_decay_at: `int`. When to stop the decay.
        decay_rate: A Python number. The decay rate.
        staircase: Whether to apply decay in a discrete staircase,
            as opposed to continuous, fashion.
        decay_steps: How often to apply decay.
        min_exploration_rate: `float`. Don't decay below this number.

    Returns:
        A tensor selected by `tf.train.piecewise_constant` with boundary
        `start_decay_at`: the initial exploration rate up to the boundary and
        the decayed rate after it, floored at `min_exploration_rate` when that
        value is non-zero.
    """
    if isinstance(exploration_rate, partial):
        _exploration_rate = exploration_rate()
    else:
        _exploration_rate = exploration_rate

    timestep = tf.to_int32(timestep)
    decay_type_fn = getattr(exploration_decay, decay_type)
    # Decay is measured from `start_decay_at` and frozen once `stop_decay_at` is reached.
    decay_timestep = (tf.minimum(timestep, tf.to_int32(stop_decay_at)) -
                      tf.to_int32(start_decay_at))
    kwargs = dict(exploration_rate=_exploration_rate,
                  timestep=decay_timestep,
                  decay_steps=decay_steps,
                  name="decayed_exploration_rate")
    # Only forward optional knobs the selected decay function actually accepts.
    decay_fn_args = get_arguments(decay_type_fn)
    if 'decay_rate' in decay_fn_args:
        kwargs['decay_rate'] = decay_rate
    if 'staircase' in decay_fn_args:
        kwargs['staircase'] = staircase

    decayed_exploration_rate = decay_type_fn(**kwargs)

    final_exploration_rate = tf.train.piecewise_constant(
        x=timestep,
        boundaries=[start_decay_at],
        # Use the resolved rate: the raw argument may be a `partial`, which
        # TensorFlow cannot convert to a tensor.
        values=[_exploration_rate, decayed_exploration_rate])

    if min_exploration_rate:
        final_exploration_rate = tf.maximum(final_exploration_rate, min_exploration_rate)

    return final_exploration_rate
def _check_subgraph_fn(function, function_name): """Checks that the functions provided for constructing the graph has a valid signature.""" if function is not None: # Check number of arguments of the given function matches requirements. model_fn_args = get_arguments(function) if 'mode' not in model_fn_args or 'features' not in model_fn_args: raise ValueError( "Model's `{}` `{}` should have at least 2 args: " "`mode`, `features`, and possibly `features`.".format(function_name, function)) else: raise ValueError("`{}` must be provided to Model.".format(function_name))
def __init__(self, mode, build_fn, name=None): if not callable(build_fn): raise TypeError("`build_fn` must be callable.") build_fn_args = get_arguments(build_fn) if 'mode' not in build_fn_args: raise ValueError("`build_fn` must include `mode` argument.") self._build_fn = build_fn super(FunctionModule, self).__init__( mode=mode, name=name or get_function_name(build_fn), module_type=self.ModuleType.IMAGE_PROCESSOR)
def _call_graph_fn(self, features, labels=None):
    """Invoke the user graph function in the current mode.

    Sets the learning phase from ``Modes.is_train(self.mode)`` first, then
    calls ``self._graph_fn``; ``labels`` is passed only if that function
    declares a ``labels`` parameter.

    Args:
        features: `Tensor` or `dict` of tensors
        labels: `Tensor` or `dict` of tensors

    Returns:
        Whatever ``self._graph_fn`` returns.
    """
    set_learning_phase(Modes.is_train(self.mode))
    accepts_labels = 'labels' in get_arguments(self._graph_fn)
    extra_kwargs = {'labels': labels} if accepts_labels else {}
    return self._graph_fn(mode=self.mode, features=features, **extra_kwargs)
def encode(self, features, labels, encoder_fn, *args, **kwargs):
    """Run ``encoder_fn`` on the incoming features.

    ``labels`` is forwarded only when ``encoder_fn`` declares a ``labels``
    parameter. On the first call (while ``self.state_size`` is None) the
    output shape minus its first dimension is stored as ``self.state_size``.

    Args:
        features: `Tensor`.
        labels: `dict` or `Tensor`
        encoder_fn: `function`.
        *args: accepted but not forwarded to ``encoder_fn``.
        **kwargs: extra keyword arguments forwarded to ``encoder_fn``.

    Returns:
        The output of ``encoder_fn``.
    """
    call_kwargs = dict(kwargs)
    if 'labels' in get_arguments(encoder_fn):
        call_kwargs['labels'] = labels
    encoded = encoder_fn(mode=self.mode, features=features, **call_kwargs)

    # Only the first encoding fixes the state size; later calls leave it alone.
    if self.state_size is None:
        self.state_size = get_shape(encoded)[1:]
    return encoded
def _check_bridge_fn(function): if function is not None: # Check number of arguments of the given function matches requirements. model_fn_args = get_arguments(function) cond = ('mode' not in model_fn_args or 'loss' not in model_fn_args or 'features' not in model_fn_args or 'labels' not in model_fn_args or 'encoder_fn' not in model_fn_args or 'encoder_fn' not in model_fn_args) if cond: raise ValueError( "Model's `bridge` `{}` should have these args: " "`mode`, `features`, `labels`, `encoder_fn`, " "and `decoder_fn`.".format(function)) else: raise ValueError("`bridge_fn` must be provided to Model.")
def graph_fn(mode, features, labels=None):
    """Build the policy-gradient head on top of the user graph function.

    Calls ``self._graph_fn`` (passing ``labels`` only if it declares that
    parameter), projects the output with ``Dense(self.num_actions)`` and
    builds the action distribution via ``self._build_distribution``.

    Returns:
        ``PGModelSpec(graph_outputs, a, distribution, dist_values)``.
    """
    extra_kwargs = {}
    if 'labels' in get_arguments(self._graph_fn):
        extra_kwargs['labels'] = labels
    graph_outputs = self._graph_fn(mode=mode, features=features, **extra_kwargs)
    a = Dense(units=self.num_actions)(graph_outputs)

    if not self.is_continuous:
        values = tf.identity(a)
        distribution = self._build_distribution(values=a)
    else:
        # NOTE(review): `a` and `exp(a) + 1` are concatenated along axis 0;
        # confirm this is the layout `_build_distribution` expects.
        values = tf.concat(values=[a, tf.exp(a) + 1], axis=0)
        distribution = self._build_distribution(values=values)

    return PGModelSpec(graph_outputs=graph_outputs,
                       a=a,
                       distribution=distribution,
                       dist_values=values)
def decode(self, features, labels, decoder_fn, *args, **kwargs):
    """Run ``decoder_fn`` on the incoming features.

    ``labels`` is forwarded only when ``decoder_fn`` declares a ``labels``
    parameter.

    Args:
        features: `Tensor`
        labels: `dict` or `Tensor`
        decoder_fn: `function`.
        *args: accepted but not forwarded to ``decoder_fn``.
        **kwargs: extra keyword arguments forwarded to ``decoder_fn``.

    Returns:
        The output of ``decoder_fn``.
    """
    # TODO: make decode capable of generating values directly,
    # TODO: basically accepting None incoming values. Should also specify a distribution.
    if 'labels' in get_arguments(decoder_fn):
        call_kwargs = dict(kwargs, labels=labels)
    else:
        call_kwargs = kwargs
    return decoder_fn(mode=self.mode, features=features, **call_kwargs)
def _verify_model_fn_args(model_fn, params): """Verifies model fn arguments.""" valid_model_fn_args = { 'features', 'labels', 'mode', 'params', 'config' } if model_fn is not None: # Check number of arguments of the given function matches requirements. model_fn_args = get_arguments(model_fn) if 'features' not in model_fn_args: raise ValueError( 'model_fn `{}` must include features argument.'.format( model_fn)) if 'labels' not in model_fn_args: raise ValueError( 'model_fn `{}` must include labels argument.'.format( model_fn)) if params is not None and 'params' not in model_fn_args: raise ValueError( "Estimator's model_fn `{}` does not include params argument, " "but params `{}` are passed.".format(model_fn, params)) if params is None and 'params' in model_fn_args: logging.warning( "Estimator's model_fn (%s) includes params " "argument, but params are not passed to Estimator.", model_fn) else: raise ValueError("`model_fn` must be provided to Estimator.") if 'self' in model_fn_args: model_fn_args.remove('self') non_valid_args = set(model_fn_args) - valid_model_fn_args if non_valid_args: raise ValueError( "model_fn `{}` has following not expected args: {}".format( model_fn, non_valid_args))