def input_signature(self, *, function):
    """Return the input signature of the named function.

    Args:
        function: Name of the function whose inputs are described.

    Returns:
        A ``SignatureDict`` for the names handled here; every other name
        is delegated to the parent class implementation.
    """
    if function == 'reset':
        # No inputs.
        return SignatureDict()

    if function in ('observe', 'core_observe'):
        # Both observe variants share one signature; ``parallel`` is the
        # only unbatched entry.
        return SignatureDict(
            terminal=self.terminal_spec.signature(batched=True),
            reward=self.reward_spec.signature(batched=True),
            parallel=self.parallel_spec.signature(batched=False)
        )

    if function == 'act':
        return SignatureDict(
            states=self.states_spec.signature(batched=True),
            auxiliaries=self.auxiliaries_spec.signature(batched=True),
            parallel=self.parallel_spec.signature(batched=True)
        )

    if function == 'core_act':
        return SignatureDict(
            states=self.states_spec.signature(batched=True),
            internals=self.internals_spec.signature(batched=True),
            auxiliaries=self.auxiliaries_spec.signature(batched=True),
            parallel=self.parallel_spec.signature(batched=True),
            deterministic=self.deterministic_spec.signature(batched=False)
        )

    if function == 'independent_act':
        result = SignatureDict(
            states=self.states_spec.signature(batched=True)
        )
        # Internals and auxiliaries are only part of the signature when
        # their specs are non-empty.
        optional_entries = (
            ('internals', self.internals_spec),
            ('auxiliaries', self.auxiliaries_spec),
        )
        for key, spec in optional_entries:
            if len(spec) > 0:
                result[key] = spec.signature(batched=True)
        result['deterministic'] = self.deterministic_spec.signature(
            batched=False
        )
        return result

    return super().input_signature(function=function)
def input_signature(self, *, function):
    """Return the input signature of the named function.

    Args:
        function: Name of the function whose inputs are described.

    Returns:
        For ``'enqueue'`` the batched signature of ``self.values_spec``;
        for the other names handled here a ``SignatureDict`` of scalar
        integer entries; otherwise whatever the parent class returns.
    """
    def int_scalar(*, batched):
        # Signature of an integer tensor with shape ``()``.
        return TensorSpec(type='int', shape=()).signature(batched=batched)

    if function == 'enqueue':
        # Returned directly, not wrapped in a new SignatureDict.
        return self.values_spec.signature(batched=True)

    if function == 'reset':
        return SignatureDict()

    if function == 'retrieve':
        return SignatureDict(indices=int_scalar(batched=True))

    if function in ('predecessors', 'successors'):
        # Both directions take batched indices and one unbatched horizon.
        return SignatureDict(
            indices=int_scalar(batched=True),
            horizon=int_scalar(batched=False)
        )

    if function == 'retrieve_episodes':
        return SignatureDict(n=int_scalar(batched=False))

    if function == 'retrieve_timesteps':
        return SignatureDict(
            n=int_scalar(batched=False),
            past_horizon=int_scalar(batched=False),
            future_horizon=int_scalar(batched=False)
        )

    return super().input_signature(function=function)
def output_signature(self, *, function):
    """Return the output signature of the named function.

    Args:
        function: Name of the function whose outputs are described.

    Returns:
        A ``SignatureDict`` for the names handled here; every other name
        is delegated to the parent class implementation. For ``'apply'``
        the result depends on ``self.temporal_processing``.
    """
    mode = self.temporal_processing

    if function == 'apply':
        if mode == 'cumulative':
            return SignatureDict(
                singleton=self.output_spec().signature(batched=True)
            )
        if mode == 'iterative':
            return SignatureDict(
                x=self.output_spec().signature(batched=True),
                internals=self.internals_spec.signature(batched=True)
            )
        # NOTE(review): any other temporal_processing value yields None
        # here rather than an error -- confirm this is intended.
        return None

    if function == 'cumulative_apply':
        assert mode == 'cumulative'
        return SignatureDict(
            singleton=self.output_spec().signature(batched=True)
        )

    if function == 'iterative_apply':
        assert mode == 'iterative'
        return SignatureDict(
            x=self.output_spec().signature(batched=True),
            internals=self.internals_spec.signature(batched=True)
        )

    if function == 'iterative_body':
        assert mode == 'iterative'
        return SignatureDict(
            x=self.input_spec.signature(batched=True),
            indices=TensorSpec(type='int', shape=()).signature(batched=True),
            remaining=TensorSpec(type='int', shape=()).signature(batched=True),
            current_x=self.output_spec().signature(batched=True),
            current_internals=self.internals_spec.signature(batched=True)
        )

    if function == 'past_horizon':
        return SignatureDict(
            singleton=TensorSpec(type='int', shape=()).signature(batched=False)
        )

    return super().output_signature(function=function)
def input_signature(self, *, function):
    """Return the input signature of the named function.

    All names handled here share the entries ``states``, ``horizons``,
    ``internals`` and ``auxiliaries``; some add one further entry
    (``reference``, ``actions`` or ``temperatures``).

    Args:
        function: Name of the function whose inputs are described.

    Returns:
        A ``SignatureDict`` for the names handled here; every other name
        is delegated to the parent class implementation.
    """
    base_only = ('entropy', 'entropies', 'kldiv_reference')
    with_reference = ('kl_divergence', 'kl_divergences')
    with_actions = ('log_probability', 'log_probabilities')
    with_temperatures = ('sample_actions',)

    handled = base_only + with_reference + with_actions + with_temperatures
    if function not in handled:
        return super().input_signature(function=function)

    # Entries common to every handled function, in the original key order.
    entries = dict(
        states=self.states_spec.signature(batched=True),
        horizons=TensorSpec(type='int', shape=(2,)).signature(batched=True),
        internals=self.internals_spec.signature(batched=True),
        auxiliaries=self.auxiliaries_spec.signature(batched=True)
    )

    # The function-specific entry, if any, always comes last.
    if function in with_reference:
        parameters_specs = self.distributions.fmap(
            function=(lambda x: x.parameters_spec), cls=TensorsSpec
        )
        entries['reference'] = parameters_specs.signature(batched=True)
    elif function in with_actions:
        entries['actions'] = self.actions_spec.signature(batched=True)
    elif function in with_temperatures:
        # One unbatched float scalar per entry of the actions spec.
        temperature_specs = self.actions_spec.fmap(
            function=(lambda _: TensorSpec(type='float', shape=()))
        )
        entries['temperatures'] = temperature_specs.signature(batched=False)

    return SignatureDict(**entries)
def input_signature(self, *, function):
    """Return the input signature of the named function.

    Args:
        function: Name of the function whose inputs are described.

    Returns:
        An empty ``SignatureDict`` for ``'regularize'``.

    Raises:
        NotImplementedError: For any other function name.
    """
    if function != 'regularize':
        raise NotImplementedError
    # 'regularize' takes no inputs.
    return SignatureDict()
def input_signature(self, *, function):
    """Return the input signature of the named function.

    Args:
        function: Name of the function whose inputs are described.

    Returns:
        For ``'apply'`` a ``SignatureDict`` with the single entry ``x``
        built from ``self.input_spec``; otherwise whatever the parent
        class returns.
    """
    if function != 'apply':
        return super().input_signature(function=function)
    x_signature = self.input_spec.signature(batched=True)
    return SignatureDict(x=x_signature)
def input_signature(self, *, function):
    """Return the input signature of the named function.

    Args:
        function: Name of the function whose inputs are described.

    Returns:
        An empty ``SignatureDict`` for ``'reset'``; otherwise whatever
        the parent class returns.
    """
    if function != 'reset':
        return super().input_signature(function=function)
    # 'reset' takes no inputs.
    return SignatureDict()
def output_signature(self, *, function):
    """Return the output signature of the named function.

    Args:
        function: Name of the function whose outputs are described.

    Returns:
        For ``'value'`` a ``SignatureDict`` with the single entry
        ``singleton``, the unbatched signature of ``self.spec``;
        otherwise whatever the parent class returns.
    """
    if function != 'value':
        return super().output_signature(function=function)
    value_signature = self.spec.signature(batched=False)
    return SignatureDict(singleton=value_signature)
def output_signature(self, *, function):
    """Return the output signature of the named function.

    Args:
        function: Name of the function whose outputs are described.

    Returns:
        A ``SignatureDict`` for ``'enqueue'``, ``'reset'``,
        ``'retrieve_episodes'`` and ``'retrieve_timesteps'``. For
        ``'predecessors'``, ``'successors'`` and ``'retrieve'`` a callable
        is returned instead, which builds the ``SignatureDict`` from the
        value names passed to it. Every other name is delegated to the
        parent class implementation.
    """
    def batched_int(shape=()):
        # Batched signature of an integer tensor with the given shape.
        return TensorSpec(type='int', shape=shape).signature(batched=True)

    def sequence_signature(sequence_values, boundary_values, boundary_key):
        # Shared builder for 'predecessors' and 'successors', which differ
        # only in the key used for the boundary values.
        has_sequence = len(sequence_values) != 0
        has_boundary = len(boundary_values) != 0

        if not (has_sequence or has_boundary):
            return SignatureDict(singleton=batched_int())

        entries = {}
        if has_sequence:
            entries['starts_lengths'] = batched_int(shape=(2,))
            entries['sequence_values'] = (
                self.values_spec[sequence_values].signature(batched=True)
            )
        else:
            entries['lengths'] = batched_int()
        if has_boundary:
            entries[boundary_key] = (
                self.values_spec[boundary_values].signature(batched=True)
            )
        return SignatureDict(**entries)

    if function in ('enqueue', 'reset'):
        # Both return a single unbatched boolean.
        return SignatureDict(
            singleton=TensorSpec(type='bool', shape=()).signature(batched=False)
        )

    if function in ('retrieve_episodes', 'retrieve_timesteps'):
        return SignatureDict(singleton=batched_int())

    if function == 'retrieve':
        def get_output_signature(values):
            selected = self.values_spec[values]
            return SignatureDict(singleton=selected.signature(batched=True))

        return get_output_signature

    if function == 'predecessors':
        def get_output_signature(sequence_values, initial_values):
            return sequence_signature(
                sequence_values, initial_values, 'initial_values'
            )

        return get_output_signature

    if function == 'successors':
        def get_output_signature(sequence_values, final_values):
            return sequence_signature(
                sequence_values, final_values, 'final_values'
            )

        return get_output_signature

    return super().output_signature(function=function)
def get_output_signature(values):
    """Return a singleton ``SignatureDict`` for the selected values.

    ``self`` is not a parameter of this function; it is resolved from an
    enclosing scope that is not visible in this excerpt.

    Args:
        values: Key used to index ``self.values_spec``.

    Returns:
        A ``SignatureDict`` whose only entry, ``singleton``, is the
        batched signature of ``self.values_spec[values]``.
    """
    selected_spec = self.values_spec[values]
    return SignatureDict(singleton=selected_spec.signature(batched=True))
def input_signature(self, *, function):
    """Return the input signature of the named function.

    Args:
        function: Name of the function whose inputs are described.

    Returns:
        For ``'step'`` and ``'update'`` a ``SignatureDict`` with the
        single entry ``arguments`` built from ``self.arguments_spec``;
        otherwise whatever the parent class returns.
    """
    if function not in ('step', 'update'):
        return super().input_signature(function=function)
    arguments_signature = self.arguments_spec.signature(batched=True)
    return SignatureDict(arguments=arguments_signature)