예제 #1
0
파일: model.py 프로젝트: xkarlx/tensorforce
    def input_signature(self, *, function):
        """Return the input signature for the API function named `function`.

        Names handled here: 'act', 'core_act', 'core_observe',
        'independent_act', 'observe' and 'reset'. Any other name is passed on
        to the parent implementation.
        """
        if function == 'reset':
            # Reset is described by an empty signature.
            return SignatureDict()

        if function in ('core_observe', 'observe'):
            # Both observe variants use the same entries; note that `parallel`
            # is requested with batched=False here, unlike in the act variants.
            return SignatureDict(
                terminal=self.terminal_spec.signature(batched=True),
                reward=self.reward_spec.signature(batched=True),
                parallel=self.parallel_spec.signature(batched=False),
            )

        if function == 'act':
            return SignatureDict(
                states=self.states_spec.signature(batched=True),
                auxiliaries=self.auxiliaries_spec.signature(batched=True),
                parallel=self.parallel_spec.signature(batched=True),
            )

        if function == 'core_act':
            return SignatureDict(
                states=self.states_spec.signature(batched=True),
                internals=self.internals_spec.signature(batched=True),
                auxiliaries=self.auxiliaries_spec.signature(batched=True),
                parallel=self.parallel_spec.signature(batched=True),
                deterministic=self.deterministic_spec.signature(batched=False),
            )

        if function == 'independent_act':
            entries = SignatureDict(
                states=self.states_spec.signature(batched=True),
            )
            # Internals and auxiliaries are only added when their spec is
            # non-empty.
            internals_spec = self.internals_spec
            if len(internals_spec) > 0:
                entries['internals'] = internals_spec.signature(batched=True)
            auxiliaries_spec = self.auxiliaries_spec
            if len(auxiliaries_spec) > 0:
                entries['auxiliaries'] = auxiliaries_spec.signature(
                    batched=True)
            entries['deterministic'] = self.deterministic_spec.signature(
                batched=False)
            return entries

        return super().input_signature(function=function)
예제 #2
0
    def input_signature(self, *, function):
        """Return the input signature for the memory function `function`.

        Names handled here: 'enqueue', 'predecessors', 'reset', 'retrieve',
        'retrieve_episodes', 'retrieve_timesteps' and 'successors'. Any other
        name is passed on to the parent implementation.
        """

        def scalar_int(*, batched):
            # Signature of an 'int' tensor with shape ().
            return TensorSpec(type='int', shape=()).signature(batched=batched)

        if function == 'enqueue':
            # Enqueue returns the values signature directly, not wrapped in
            # a new SignatureDict.
            return self.values_spec.signature(batched=True)

        if function == 'reset':
            return SignatureDict()

        if function == 'retrieve':
            return SignatureDict(indices=scalar_int(batched=True))

        if function == 'retrieve_episodes':
            return SignatureDict(n=scalar_int(batched=False))

        if function == 'retrieve_timesteps':
            return SignatureDict(
                n=scalar_int(batched=False),
                past_horizon=scalar_int(batched=False),
                future_horizon=scalar_int(batched=False),
            )

        if function in ('predecessors', 'successors'):
            # Both directions take the same two inputs.
            return SignatureDict(
                indices=scalar_int(batched=True),
                horizon=scalar_int(batched=False),
            )

        return super().input_signature(function=function)
예제 #3
0
    def output_signature(self, *, function):
        """Return the output signature for the layer function `function`.

        Names handled here: 'apply', 'cumulative_apply', 'iterative_apply',
        'iterative_body' and 'past_horizon'. Any other name is passed on to
        the parent implementation.
        """

        def cumulative_result():
            # Single output entry built from the output spec.
            return SignatureDict(
                singleton=self.output_spec().signature(batched=True),
            )

        def iterative_result():
            # Output together with the internals.
            return SignatureDict(
                x=self.output_spec().signature(batched=True),
                internals=self.internals_spec.signature(batched=True),
            )

        if function == 'past_horizon':
            horizon_spec = TensorSpec(type='int', shape=())
            return SignatureDict(
                singleton=horizon_spec.signature(batched=False),
            )

        if function == 'apply':
            mode = self.temporal_processing
            if mode == 'cumulative':
                return cumulative_result()
            if mode == 'iterative':
                return iterative_result()
            # Any other mode produces no signature at all.
            return None

        if function == 'cumulative_apply':
            assert self.temporal_processing == 'cumulative'
            return cumulative_result()

        if function == 'iterative_apply':
            assert self.temporal_processing == 'iterative'
            return iterative_result()

        if function == 'iterative_body':
            assert self.temporal_processing == 'iterative'
            return SignatureDict(
                x=self.input_spec.signature(batched=True),
                indices=TensorSpec(type='int', shape=()).signature(
                    batched=True),
                remaining=TensorSpec(type='int', shape=()).signature(
                    batched=True),
                current_x=self.output_spec().signature(batched=True),
                current_internals=self.internals_spec.signature(batched=True),
            )

        return super().output_signature(function=function)
예제 #4
0
    def input_signature(self, *, function):
        """Return the input signature for the policy function `function`.

        Every function handled here takes states, horizons, internals and
        auxiliaries. The KL-divergence functions additionally take
        `reference`, the log-probability functions `actions`, and
        'sample_actions' takes `temperatures`. Any other name is passed on to
        the parent implementation.
        """
        base_only = ('entropy', 'entropies', 'kldiv_reference')
        with_reference = ('kl_divergence', 'kl_divergences')
        with_actions = ('log_probability', 'log_probabilities')
        with_temperatures = ('sample_actions',)

        handled = base_only + with_reference + with_actions + with_temperatures
        if function not in handled:
            return super().input_signature(function=function)

        # Entries shared by all functions handled here, in their original order.
        entries = dict(
            states=self.states_spec.signature(batched=True),
            horizons=TensorSpec(type='int', shape=(2,)).signature(batched=True),
            internals=self.internals_spec.signature(batched=True),
            auxiliaries=self.auxiliaries_spec.signature(batched=True),
        )

        if function in with_reference:
            # Reference is built by mapping each distribution to its
            # parameters_spec.
            reference_spec = self.distributions.fmap(
                function=(lambda x: x.parameters_spec), cls=TensorsSpec
            )
            entries['reference'] = reference_spec.signature(batched=True)
        elif function in with_actions:
            entries['actions'] = self.actions_spec.signature(batched=True)
        elif function in with_temperatures:
            # One shape-() 'float' spec per entry of the actions spec.
            temperatures_spec = self.actions_spec.fmap(
                function=(lambda _: TensorSpec(type='float', shape=()))
            )
            entries['temperatures'] = temperatures_spec.signature(batched=False)

        return SignatureDict(**entries)
예제 #5
0
    def input_signature(self, *, function):
        """Return the input signature for `function`.

        Only 'regularize' is supported, with an empty signature.

        Raises:
            NotImplementedError: for any other function name.
        """
        if function != 'regularize':
            raise NotImplementedError

        return SignatureDict()
예제 #6
0
    def input_signature(self, *, function):
        """Return the input signature for `function`.

        'apply' takes a single entry `x` built from the input spec; every
        other name is passed on to the parent implementation.
        """
        if function != 'apply':
            return super().input_signature(function=function)

        x_signature = self.input_spec.signature(batched=True)
        return SignatureDict(x=x_signature)
예제 #7
0
    def input_signature(self, *, function):
        """Return the input signature for `function`.

        'reset' has an empty signature; every other name is passed on to the
        parent implementation.
        """
        if function != 'reset':
            return super().input_signature(function=function)

        return SignatureDict()
예제 #8
0
    def output_signature(self, *, function):
        """Return the output signature for `function`.

        'value' yields a single entry built from `self.spec` with
        batched=False; every other name is passed on to the parent
        implementation.
        """
        if function != 'value':
            return super().output_signature(function=function)

        value_signature = self.spec.signature(batched=False)
        return SignatureDict(singleton=value_signature)
예제 #9
0
    def output_signature(self, *, function):
        """Return the output signature for the memory function `function`.

        For 'enqueue', 'reset', 'retrieve_episodes' and 'retrieve_timesteps'
        a SignatureDict is returned directly. For 'predecessors', 'retrieve'
        and 'successors' a callable is returned instead, which builds the
        SignatureDict from the value names it is given. Any other name is
        passed on to the parent implementation.
        """
        if function in ('enqueue', 'reset'):
            flag_spec = TensorSpec(type='bool', shape=())
            return SignatureDict(singleton=flag_spec.signature(batched=False))

        if function in ('retrieve_episodes', 'retrieve_timesteps'):
            index_spec = TensorSpec(type='int', shape=())
            return SignatureDict(singleton=index_spec.signature(batched=True))

        if function == 'retrieve':
            def get_output_signature(values):
                selected_spec = self.values_spec[values]
                return SignatureDict(singleton=selected_spec.signature(batched=True))

            return get_output_signature

        if function == 'predecessors':
            def get_output_signature(sequence_values, initial_values):
                no_sequence = len(sequence_values) == 0
                no_initial = len(initial_values) == 0

                if no_sequence and no_initial:
                    # Nothing requested: a single int entry.
                    return SignatureDict(
                        singleton=TensorSpec(type='int', shape=()).signature(batched=True)
                    )

                if no_sequence:
                    # Only initial values: lengths plus those values.
                    return SignatureDict(
                        lengths=TensorSpec(type='int', shape=()).signature(batched=True),
                        initial_values=self.values_spec[initial_values].signature(batched=True)
                    )

                # Sequence values requested: starts/lengths pairs, the
                # sequence values, and the initial values if any were named.
                entries = dict(
                    starts_lengths=TensorSpec(type='int', shape=(2,)).signature(batched=True),
                    sequence_values=self.values_spec[sequence_values].signature(batched=True)
                )
                if not no_initial:
                    entries['initial_values'] = self.values_spec[initial_values].signature(
                        batched=True
                    )
                return SignatureDict(**entries)

            return get_output_signature

        if function == 'successors':
            def get_output_signature(sequence_values, final_values):
                no_sequence = len(sequence_values) == 0
                no_final = len(final_values) == 0

                if no_sequence and no_final:
                    # Nothing requested: a single int entry.
                    return SignatureDict(
                        singleton=TensorSpec(type='int', shape=()).signature(batched=True)
                    )

                if no_sequence:
                    # Only final values: lengths plus those values.
                    return SignatureDict(
                        lengths=TensorSpec(type='int', shape=()).signature(batched=True),
                        final_values=self.values_spec[final_values].signature(batched=True)
                    )

                # Sequence values requested: starts/lengths pairs, the
                # sequence values, and the final values if any were named.
                entries = dict(
                    starts_lengths=TensorSpec(type='int', shape=(2,)).signature(batched=True),
                    sequence_values=self.values_spec[sequence_values].signature(batched=True)
                )
                if not no_final:
                    entries['final_values'] = self.values_spec[final_values].signature(
                        batched=True
                    )
                return SignatureDict(**entries)

            return get_output_signature

        return super().output_signature(function=function)
예제 #10
0
 def get_output_signature(values):
     """Build the single-entry output signature for the selected values.

     Relies on `self` from the enclosing scope, which this snippet does
     not show.
     """
     selected_spec = self.values_spec[values]
     return SignatureDict(singleton=selected_spec.signature(batched=True))
예제 #11
0
    def input_signature(self, *, function):
        """Return the input signature for `function`.

        'step' and 'update' both take a single entry `arguments` built from
        the arguments spec; every other name is passed on to the parent
        implementation.
        """
        if function not in ('step', 'update'):
            return super().input_signature(function=function)

        arguments_signature = self.arguments_spec.signature(batched=True)
        return SignatureDict(arguments=arguments_signature)