Example #1
def create_feedforward_Q_function(input_shapes,
                                  *args,
                                  preprocessors=None,
                                  observation_keys=None,
                                  goal_keys=None,
                                  name='feedforward_Q',
                                  **kwargs):
    inputs_flat = create_inputs(input_shapes)
    preprocessors_flat = (flatten_input_structure(preprocessors)
                          if preprocessors is not None else tuple(
                              None for _ in inputs_flat))

    assert len(inputs_flat) == len(preprocessors_flat), (inputs_flat,
                                                         preprocessors_flat)

    preprocessed_inputs = [
        preprocessor(input_) if preprocessor is not None else input_
        for preprocessor, input_ in zip(preprocessors_flat, inputs_flat)
    ]

    Q_function = feedforward_model(*args, output_size=1, name=name, **kwargs)

    Q_function = PicklableModel(inputs_flat, Q_function(preprocessed_inputs))
    preprocessed_inputs_fn = PicklableModel(inputs_flat, preprocessed_inputs)

    Q_function.observation_keys = observation_keys or ()
    Q_function.goal_keys = goal_keys or ()
    Q_function.all_keys = Q_function.observation_keys + Q_function.goal_keys

    # `preprocessors` may be None (as handled above), so guard these lookups.
    Q_function.actions_preprocessors = (
        preprocessors['actions'] if preprocessors is not None else None)
    Q_function.observations_preprocessors = (
        preprocessors['observations'] if preprocessors is not None else None)

    Q_function.preprocessed_inputs_fn = preprocessed_inputs_fn
    return Q_function
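
A minimal usage sketch for the factory above. It assumes `create_inputs` accepts a nested dict of plain shape tuples, that `hidden_layer_sizes` is the keyword `feedforward_model` expects, and that the flattened input order is observations first, then actions; all sizes are illustrative.

from collections import OrderedDict

import numpy as np

input_shapes = OrderedDict((
    ('observations', OrderedDict((('state', (8, )), ))),
    ('actions', (2, )),
))
Q_function = create_feedforward_Q_function(
    input_shapes,
    hidden_layer_sizes=(256, 256),
    observation_keys=('state', ),
    goal_keys=())

# The wrapped model consumes the flattened inputs: [state, actions].
q_values = Q_function.predict(
    [np.zeros((32, 8), np.float32), np.zeros((32, 2), np.float32)])
q_values.shape  # (32, 1)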
Example #2
    def __init__(self,
                 input_shapes,
                 output_shape,
                 action_range,
                 *args,
                 preprocessors=None,
                 **kwargs):
        self._Serializable__initialize(locals())

        self._output_shape = output_shape
        self._action_range = action_range

        super(UniformPolicy, self).__init__(*args, **kwargs)

        self.inputs = create_inputs(input_shapes)

        batch_size = tf.keras.layers.Lambda(lambda x: tf.shape(x)[0])(
            tree.flatten(self.inputs)[0])

        actions = tf.keras.layers.Lambda(self._actions_fn)(batch_size)

        self.actions_model = tf.keras.Model(self.inputs, actions)

        self.actions_input = tf.keras.Input(shape=output_shape, name='actions')

        log_pis = tf.keras.layers.Lambda(self._log_pis_fn)(self.actions_input)

        self.log_pis_model = tf.keras.Model((self.inputs, self.actions_input),
                                            log_pis)
Example #3
def create_feedforward_Q_function(input_shapes,
                                  *args,
                                  preprocessors=None,
                                  observation_keys=None,
                                  name='feedforward_Q',
                                  **kwargs):
    inputs_flat = create_inputs(input_shapes)
    preprocessors_flat = (flatten_input_structure(preprocessors)
                          if preprocessors is not None else tuple(
                              None for _ in inputs_flat))

    assert len(inputs_flat) == len(preprocessors_flat), (inputs_flat,
                                                         preprocessors_flat)

    preprocessed_inputs = [
        tf.cast(preprocessor(input_), dtype=tf.float32)
        if preprocessor is not None else tf.cast(input_, dtype=tf.float32)
        for preprocessor, input_ in zip(preprocessors_flat, inputs_flat)
    ]

    Q_function = feedforward_model(*args, output_size=1, name=name, **kwargs)

    Q_function = PicklableModel(inputs_flat, Q_function(preprocessed_inputs))
    Q_function.observation_keys = observation_keys

    return Q_function
Example #4
    def __init__(self,
                 input_shapes,
                 output_shape,
                 *args,
                 preprocessors=None,
                 **kwargs):
        self._Serializable__initialize(locals())

        super(UniformPolicy, self).__init__(*args, **kwargs)

        inputs_flat = create_inputs(input_shapes)

        self.inputs = inputs_flat

        batch_size = tf.keras.layers.Lambda(lambda x: tf.shape(x)[0])(
            inputs_flat[0])

        actions = tf.keras.layers.Lambda(lambda batch_size: tf.random.uniform(
            (batch_size, output_shape[0])))(batch_size)

        # Wrap the discretization in a Lambda layer so the output keeps its
        # Keras history and can be used as a model output below.
        actions = tf.keras.layers.Lambda(
            lambda actions: tf.one_hot(
                tf.argmax(actions, axis=-1, output_type=tf.int32),
                output_shape[0]))(actions)

        self.actions_model = tf.keras.Model(self.inputs, actions)

        self.actions_input = tf.keras.Input(shape=output_shape, name='actions')
Example #5
def create_embedding_fn(input_shapes,
                        embedding_dim,
                        *args,
                        preprocessors=None,
                        observation_keys=None,
                        goal_keys=None,
                        name='embedding_fn',
                        **kwargs):
    inputs_flat = create_inputs(input_shapes)
    preprocessors_flat = (flatten_input_structure(preprocessors)
                          if preprocessors is not None else tuple(
                              None for _ in inputs_flat))

    assert len(inputs_flat) == len(preprocessors_flat), (inputs_flat,
                                                         preprocessors_flat)

    preprocessed_inputs = [
        preprocessor(input_) if preprocessor is not None else input_
        for preprocessor, input_ in zip(preprocessors_flat, inputs_flat)
    ]

    embedding_fn = feedforward_model(*args,
                                     output_size=embedding_dim,
                                     name=f'feedforward_{name}',
                                     **kwargs)

    embedding_fn = PicklableModel(inputs_flat,
                                  embedding_fn(preprocessed_inputs),
                                  name=name)

    embedding_fn.observation_keys = observation_keys or tuple()
    embedding_fn.goal_keys = goal_keys or tuple()
    embedding_fn.all_keys = embedding_fn.observation_keys + embedding_fn.goal_keys

    return embedding_fn
Example #6
    def __init__(self,
                 input_shapes,
                 output_shape,
                 *args,
                 action_range=np.array(((-1.0, ), (1.0, ))),
                 preprocessors=None,
                 **kwargs):
        self._Serializable__initialize(locals())

        super(UniformPolicy, self).__init__(*args, **kwargs)

        inputs_flat = create_inputs(input_shapes)

        self.inputs = inputs_flat

        self._action_range = action_range

        batch_size = tf.keras.layers.Lambda(lambda x: tf.shape(x)[0])(
            inputs_flat[0])

        actions = tf.keras.layers.Lambda(lambda batch_size: tf.random.uniform(
            (batch_size, output_shape[0]), *action_range))(batch_size)

        self.actions_model = tf.keras.Model(self.inputs, actions)

        self.actions_input = tf.keras.Input(shape=output_shape, name='actions')

        log_pis = tf.keras.layers.Lambda(lambda x: tf.tile(
            tf.math.log((action_range[1] - action_range[0]) / 2.0)[None],
            (tf.shape(input=x)[0], 1)))(self.actions_input)

        self.log_pis_model = tf.keras.Model((*self.inputs, self.actions_input),
                                            log_pis)
Example #7
def feedforward_Q_function(input_shapes,
                           *args,
                           preprocessors=None,
                           observation_keys=None,
                           name='feedforward_Q',
                           **kwargs):
    inputs = create_inputs(input_shapes)

    if preprocessors is None:
        preprocessors = tree.map_structure(lambda _: None, inputs)

    preprocessors = tree.map_structure_up_to(inputs,
                                             preprocessors_lib.deserialize,
                                             preprocessors)

    preprocessed_inputs = apply_preprocessors(preprocessors, inputs)

    # NOTE(hartikainen): `feedforward_model` would do the `cast_and_concat`
    # step for us, but tf2.2 broke the sequential multi-input handling. See:
    # https://github.com/tensorflow/tensorflow/issues/37061.
    out = tf.keras.layers.Lambda(cast_and_concat)(preprocessed_inputs)
    Q_model_body = feedforward_model(*args,
                                     output_shape=[1],
                                     name=name,
                                     **kwargs)

    Q_model = tf.keras.Model(inputs, Q_model_body(out), name=name)

    Q_function = StateActionValueFunction(model=Q_model,
                                          observation_keys=observation_keys,
                                          name=name)

    return Q_function
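
`cast_and_concat` above is imported from elsewhere in the repository; a minimal stand-in consistent with the local helper defined in Example #17 would cast every input to float32 and concatenate along the last axis (`tf.nest` here is an assumption; the original may use `tree`).

import tensorflow as tf

def cast_and_concat(inputs):
    # Cast every (possibly nested) input tensor to float32, then flatten and
    # concatenate them along the feature axis.
    inputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), inputs)
    return tf.concat(tf.nest.flatten(inputs), axis=-1)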
Example #8
def create_dynamics_model(input_shapes,
                          dynamics_latent_dim,
                          *args,
                          preprocessors=None,
                          observation_keys=None,
                          goal_keys=None,
                          name='dynamics_model',
                          encoder_kwargs=None,
                          decoder_kwargs=None,
                          **kwargs):
    # `encoder_kwargs`/`decoder_kwargs` default to None; normalize them so the
    # `**` expansions below do not fail.
    encoder_kwargs = encoder_kwargs or {}
    decoder_kwargs = decoder_kwargs or {}

    inputs_flat = create_inputs(input_shapes)
    preprocessors_flat = (
        flatten_input_structure(preprocessors)
        if preprocessors is not None
        else tuple(None for _ in inputs_flat))

    assert len(inputs_flat) == len(preprocessors_flat), (
        inputs_flat, preprocessors_flat)

    preprocessed_inputs = [
        preprocessor(input_) if preprocessor is not None else input_
        for preprocessor, input_
        in zip(preprocessors_flat, inputs_flat)
    ]
    encoder = feedforward_model(
        *args,
        output_size=dynamics_latent_dim,
        name=f'{name}_encoder',
        **encoder_kwargs)

    output_size = sum([
        shape.as_list()[0]
        for shape in input_shapes['observations'].values()
    ])
    decoder = feedforward_model(
        *args,
        output_size=output_size,
        name=f'{name}_decoder',
        **decoder_kwargs)

    latent = encoder(preprocessed_inputs)
    dynamics_pred = decoder(latent)

    dynamics_model = PicklableModel(inputs_flat, dynamics_pred, name=name)

    dynamics_model.observation_keys = observation_keys or tuple()
    dynamics_model.goal_keys = goal_keys or tuple()
    dynamics_model.all_keys = dynamics_model.observation_keys + dynamics_model.goal_keys

    dynamics_model.encoder = PicklableModel(inputs_flat, latent, name=f'{name}_encoder_model')

    return dynamics_model
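
Since the returned object is a Keras model, it can be fitted as a plain regression onto the next observation. The sketch below assumes `tf.TensorShape` entries (the factory calls `.as_list()` on the observation shapes) and that `hidden_layer_sizes` is the keyword `feedforward_model` expects; shapes and data are synthetic.

from collections import OrderedDict

import numpy as np
import tensorflow as tf

input_shapes = OrderedDict((
    ('observations', OrderedDict((('state', tf.TensorShape((8, ))), ))),
    ('actions', tf.TensorShape((2, ))),
))
dynamics_model = create_dynamics_model(
    input_shapes,
    dynamics_latent_dim=32,
    encoder_kwargs={'hidden_layer_sizes': (256, 256)},
    decoder_kwargs={'hidden_layer_sizes': (256, 256)})

# Regress the decoder output onto the next observation (width == output_size).
dynamics_model.compile(optimizer=tf.keras.optimizers.Adam(3e-4), loss='mse')
observations = np.zeros((128, 8), np.float32)
actions = np.zeros((128, 2), np.float32)
next_observations = np.zeros((128, 8), np.float32)
dynamics_model.fit([observations, actions], next_observations, epochs=1)

# The latent bottleneck is exposed separately through the attached encoder.
latents = dynamics_model.encoder.predict([observations, actions])  # (128, 32)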
Example #9
def get_rnd_networks_from_variant(variant, env):
    rnd_params = variant['algorithm_params']['rnd_params']

    observation_keys = variant['policy_params']['kwargs']['observation_keys']
    if not observation_keys:
        observation_keys = env.observation_keys
    observation_shapes = OrderedDict(
        ((key, value) for key, value in env.observation_shape.items()
         if key in observation_keys))

    inputs_flat = create_inputs(observation_shapes)

    target_network, predictor_network = [], []
    for input_tensor in inputs_flat:
        if 'pixels' in input_tensor.name:  # check logic
            from softlearning.preprocessors.utils import get_convnet_preprocessor
            target_network.append(
                get_convnet_preprocessor(
                    'rnd_target_conv',
                    **rnd_params['convnet_params'])(input_tensor))
            predictor_network.append(
                get_convnet_preprocessor(
                    'rnd_predictor_conv',
                    **rnd_params['convnet_params'])(input_tensor))
        else:
            target_network.append(input_tensor)
            predictor_network.append(input_tensor)

    target_network = tf.keras.layers.Lambda(
        lambda inputs: tf.concat(training_utils.cast_if_floating_dtype(inputs),
                                 axis=-1))(target_network)

    predictor_network = tf.keras.layers.Lambda(
        lambda inputs: tf.concat(training_utils.cast_if_floating_dtype(inputs),
                                 axis=-1))(predictor_network)

    target_network = get_feedforward_preprocessor(
        'rnd_target_fc', **rnd_params['fc_params'])(target_network)

    predictor_network = get_feedforward_preprocessor(
        'rnd_predictor_fc', **rnd_params['fc_params'])(predictor_network)

    # Randomly initialize the RND target network's weights.
    target_network = PicklableModel(inputs_flat, target_network)
    target_network.set_weights([
        np.random.normal(0, 0.1, size=weight.shape)
        for weight in target_network.get_weights()
    ])
    predictor_network = PicklableModel(inputs_flat, predictor_network)
    return target_network, predictor_network
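
For context, the two networks are typically combined into the standard RND exploration bonus: the per-sample prediction error of the predictor against the frozen, randomly initialized target. This is generic RND usage rather than code from this repository, and `observation_batch` is an assumed batch matching `inputs_flat`.

import tensorflow as tf

def rnd_intrinsic_reward(target_network, predictor_network, observation_batch):
    # Frozen random features; gradients are stopped so the target never trains.
    target_features = tf.stop_gradient(target_network(observation_batch))
    # Features the predictor is trained to match.
    predicted_features = predictor_network(observation_batch)
    # Mean squared prediction error per sample serves as the exploration bonus.
    return tf.reduce_mean(
        tf.square(predicted_features - target_features), axis=-1)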
Example #10
def create_feedforward_Q_function(input_shapes,
                                  *args,
                                  preprocessors=None,
                                  observation_keys=None,
                                  name='feedforward_Q',
                                  **kwargs):
    inputs = create_inputs(input_shapes)
    if preprocessors is None:
        preprocessors = tree.map_structure(lambda _: None, inputs)

    preprocessed_inputs = apply_preprocessors(preprocessors, inputs)

    Q_function = feedforward_model(*args, output_size=1, name=name, **kwargs)

    Q_function = PicklableModel(inputs, Q_function(preprocessed_inputs))
    Q_function.observation_keys = observation_keys

    return Q_function
Example #11
    def __init__(self,
                 input_shapes,
                 output_shape,
                 observation_keys=None,
                 preprocessors=None,
                 name='policy'):
        self._input_shapes = input_shapes
        self._output_shape = output_shape
        self._observation_keys = observation_keys
        self._inputs = create_inputs(input_shapes)

        if preprocessors is None:
            preprocessors = tree.map_structure(lambda x: None, input_shapes)

        preprocessors = tree.map_structure_up_to(
            input_shapes, preprocessors_lib.deserialize, preprocessors)

        self._preprocessors = preprocessors

        self._name = name
Example #12
def create_distance_estimator(input_shapes,
                              *args,
                              preprocessors=None,
                              observation_keys=None,
                              goal_keys=None,
                              name='distance_estimator',
                              classifier_params=None,
                              **kwargs):
    inputs_flat = create_inputs(input_shapes)
    preprocessors_flat = (flatten_input_structure(preprocessors)
                          if preprocessors is not None else tuple(
                              None for _ in inputs_flat))

    assert len(inputs_flat) == len(preprocessors_flat), (inputs_flat,
                                                         preprocessors_flat)

    preprocessed_inputs = [
        preprocessor(input_) if preprocessor is not None else input_
        for preprocessor, input_ in zip(preprocessors_flat, inputs_flat)
    ]

    output_size = 1 if not classifier_params else int(
        classifier_params.get('bins', 1) + 1)

    distance_fn = feedforward_model(*args,
                                    output_size=output_size,
                                    name=name,
                                    **kwargs)

    distance_fn = PicklableModel(inputs_flat, distance_fn(preprocessed_inputs))

    distance_fn.observation_keys = observation_keys or tuple()
    distance_fn.goal_keys = goal_keys or tuple()
    distance_fn.all_keys = distance_fn.observation_keys + distance_fn.goal_keys
    distance_fn.classifier_params = classifier_params

    return distance_fn
Example #13
def create_feedforward_reward_classifier_function(
        input_shapes,
        *args,
        preprocessors=None,
        observation_keys=None,
        name='feedforward_reward_classifier',
        kernel_regularizer_lambda=1e-3,
        # output_activation=tf.math.log_sigmoid,
        **kwargs):
    inputs_flat = create_inputs(input_shapes)
    preprocessors_flat = (flatten_input_structure(preprocessors)
                          if preprocessors is not None else tuple(
                              None for _ in inputs_flat))

    assert len(inputs_flat) == len(preprocessors_flat), (inputs_flat,
                                                         preprocessors_flat)

    preprocessed_inputs = [
        preprocessor(input_) if preprocessor is not None else input_
        for preprocessor, input_ in zip(preprocessors_flat, inputs_flat)
    ]

    reward_classifier_function = feedforward_model(
        *args,
        output_size=1,
        kernel_regularizer=tf.keras.regularizers.l2(kernel_regularizer_lambda)
        if kernel_regularizer_lambda else None,
        name=name,
        # output_activation=output_activation,
        **kwargs)

    reward_classifier_function = PicklableModel(
        inputs_flat, reward_classifier_function(preprocessed_inputs))
    reward_classifier_function.observation_keys = observation_keys
    reward_classifier_function.observations_preprocessors = preprocessors

    return reward_classifier_function
Example #14
def halite_Q_function(input_shapes,
                      *args,
                      observation_keys=None,
                      name='halite_Q',
                      **kwargs):
    """
    Args:
        input_shapes:(
            map_shape: [x, y, number of features layers],
            scalar_features_length: number of scalar features,
            actions_length: it should be 1
            )
        observation_keys: "compute values given observations"
        name: a name

    Returns:
        keras model which predicts q values - estimated reward
    """
    # it prepares the input layers
    inputs = create_inputs(input_shapes, dtypes=tf.float32)

    obs, input_actions = inputs
    input_map = obs["feature_maps"]
    input_scalar = obs["scalar_features"]

    # conv_net_output = custom_resnet_model(input_map)
    conv_net_output = tf.keras.layers.Flatten()(input_map)
    concat = tf.keras.layers.concatenate([conv_net_output, input_scalar, input_actions])
    dense1 = tf.keras.layers.Dense(1024, activation="relu")(concat)
    dense2 = tf.keras.layers.Dense(1024, activation="relu")(dense1)
    output = tf.keras.layers.Dense(1, activation="linear", name="output")(dense2)

    Q_model = tf.keras.Model(inputs=inputs, outputs=output, name=name)
    Q_function = StateActionValueFunction(
        model=Q_model, observation_keys=observation_keys, name=name)

    return Q_function
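
A hedged sketch of the `(observations, actions)` structure the docstring describes; the concrete sizes are illustrative, and whether `create_inputs` expects `tf.TensorShape`s or plain tuples here is an assumption.

from collections import OrderedDict

import tensorflow as tf

input_shapes = (
    OrderedDict((
        ('feature_maps', tf.TensorShape((21, 21, 8))),  # [x, y, feature layers]
        ('scalar_features', tf.TensorShape((3, ))),     # scalar_features_length
    )),
    tf.TensorShape((1, )),                              # actions_length: 1
)
Q_function = halite_Q_function(
    input_shapes,
    observation_keys=('feature_maps', 'scalar_features'),
    name='halite_Q')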
Example #15
    def __init__(self,
                 input_shapes,
                 output_shape,
                 action_range,
                 *args,
                 squash=True,
                 preprocessors=None,
                 hidden_layer_sizes=(128, 128),
                 num_coupling_layers=2,
                 name=None,
                 **kwargs):

        raise NotImplementedError(
            "TODO(hartikainen): RealNVPPolicy is currently broken. The keras"
            " models together with the tfp distributions somehow count the"
            "variables multiple times. This needs to be fixed before usage.")
        assert (np.all(action_range == np.array([[-1], [1]]))), (
            "The action space should be scaled to (-1, 1)."
            " TODO(hartikainen): We should support non-scaled actions spaces.")

        self._Serializable__initialize(locals())

        self._action_range = action_range
        self._input_shapes = input_shapes
        self._output_shape = output_shape
        self._squash = squash
        self._name = name

        super(RealNVPPolicy, self).__init__(*args, **kwargs)

        inputs = create_inputs(input_shapes)
        if preprocessors is None:
            preprocessors = tree.map_structure(lambda _: None, inputs)

        preprocessed_inputs = apply_preprocessors(preprocessors, inputs)

        conditions = tf.keras.layers.Lambda(cast_and_concat)(
            preprocessed_inputs)

        self.condition_inputs = inputs

        batch_size = tf.keras.layers.Lambda(lambda x: tf.shape(input=x)[0])(
            conditions)

        base_distribution = tfp.distributions.MultivariateNormalDiag(
            loc=tf.zeros(output_shape), scale_diag=tf.ones(output_shape))

        flow_model = RealNVPFlow(num_coupling_layers=num_coupling_layers,
                                 hidden_layer_sizes=hidden_layer_sizes)

        flow_distribution = flow_model(base_distribution)

        latents = base_distribution.sample(batch_size)

        self.latents_model = tf.keras.Model(self.condition_inputs, latents)
        self.latents_input = tf.keras.layers.Input(shape=output_shape,
                                                   name='latents')

        raw_actions = flow_distribution.bijector.forward(latents,
                                                         conditions=conditions)

        raw_actions_for_fixed_latents = flow_distribution.bijector.forward(
            self.latents_input, conditions=conditions)

        squash_bijector = (tfp.bijectors.Tanh()
                           if self._squash else tfp.bijectors.Identity())

        actions = squash_bijector(raw_actions)
        self.actions_model = tf.keras.Model(self.condition_inputs, actions)

        actions_for_fixed_latents = squash_bijector(
            raw_actions_for_fixed_latents)
        self.actions_model_for_fixed_latents = tf.keras.Model(
            (self.condition_inputs, self.latents_input),
            actions_for_fixed_latents)

        self.deterministic_actions_model = self.actions_model

        self.actions_input = tf.keras.layers.Input(shape=output_shape,
                                                   name='actions')

        log_pis = flow_distribution.log_prob(actions)[..., tf.newaxis]
        log_pis_for_action_input = flow_distribution.log_prob(
            self.actions_input)[..., tf.newaxis]

        self.log_pis_model = tf.keras.Model(
            (self.condition_inputs, self.actions_input),
            log_pis_for_action_input)

        self.diagnostics_model = tf.keras.Model(
            self.condition_inputs, (log_pis, raw_actions, actions))
Example #16
    def __init__(self,
                 input_shapes,
                 output_shape,
                 *args,
                 squash=True,
                 preprocessors=None,
                 name=None,
                 **kwargs):
        self._Serializable__initialize(locals())

        self._input_shapes = input_shapes
        self._output_shape = output_shape
        self._squash = squash
        self._name = name

        super(GaussianPolicy, self).__init__(*args, **kwargs)

        inputs_flat = create_inputs(input_shapes)
        preprocessors_flat = (flatten_input_structure(preprocessors)
                              if preprocessors is not None else tuple(
                                  None for _ in inputs_flat))

        assert len(inputs_flat) == len(preprocessors_flat), (
            inputs_flat, preprocessors_flat)

        preprocessed_inputs = [
            preprocessor(input_) if preprocessor is not None else input_
            for preprocessor, input_ in zip(preprocessors_flat, inputs_flat)
        ]

        float_inputs = tf.keras.layers.Lambda(
            lambda inputs: training_utils.cast_if_floating_dtype(inputs))(
                preprocessed_inputs)

        conditions = tf.keras.layers.Lambda(
            lambda inputs: tf.concat(inputs, axis=-1))(float_inputs)

        self.condition_inputs = inputs_flat

        shift_and_log_scale_diag = self._shift_and_log_scale_diag_net(
            output_size=output_shape[0] * 2, )(conditions)

        shift, log_scale_diag = tf.keras.layers.Lambda(
            lambda shift_and_log_scale_diag: tf.split(
                shift_and_log_scale_diag, num_or_size_splits=2, axis=-1))(
                    shift_and_log_scale_diag)

        log_scale_diag = tf.keras.layers.Lambda(
            lambda log_scale_diag: tf.clip_by_value(
                log_scale_diag, *SCALE_DIAG_MIN_MAX))(log_scale_diag)

        batch_size = tf.keras.layers.Lambda(lambda x: tf.shape(input=x)[0])(
            conditions)

        base_distribution = tfp.distributions.MultivariateNormalDiag(
            loc=tf.zeros(output_shape), scale_diag=tf.ones(output_shape))

        latents = tf.keras.layers.Lambda(lambda batch_size: base_distribution.
                                         sample(batch_size))(batch_size)

        self.latents_model = tf.keras.Model(self.condition_inputs, latents)
        self.latents_input = tf.keras.layers.Input(shape=output_shape,
                                                   name='latents')

        def raw_actions_fn(inputs):
            shift, log_scale_diag, latents = inputs
            bijector = tfp.bijectors.Affine(shift=shift,
                                            scale_diag=tf.exp(log_scale_diag))
            actions = bijector.forward(latents)
            return actions

        raw_actions = tf.keras.layers.Lambda(raw_actions_fn)(
            (shift, log_scale_diag, latents))

        raw_actions_for_fixed_latents = tf.keras.layers.Lambda(raw_actions_fn)(
            (shift, log_scale_diag, self.latents_input))

        squash_bijector = (SquashBijector()
                           if self._squash else tfp.bijectors.Identity())

        actions = tf.keras.layers.Lambda(lambda raw_actions: squash_bijector.
                                         forward(raw_actions))(raw_actions)
        self.actions_model = tf.keras.Model(self.condition_inputs, actions)

        actions_for_fixed_latents = tf.keras.layers.Lambda(
            lambda raw_actions: squash_bijector.forward(raw_actions))(
                raw_actions_for_fixed_latents)
        self.actions_model_for_fixed_latents = tf.keras.Model(
            (*self.condition_inputs, self.latents_input),
            actions_for_fixed_latents)

        deterministic_actions = tf.keras.layers.Lambda(
            lambda shift: squash_bijector.forward(shift))(shift)

        self.deterministic_actions_model = tf.keras.Model(
            self.condition_inputs, deterministic_actions)

        def log_pis_fn(inputs):
            shift, log_scale_diag, actions = inputs
            base_distribution = tfp.distributions.MultivariateNormalDiag(
                loc=tf.zeros(output_shape), scale_diag=tf.ones(output_shape))
            bijector = tfp.bijectors.Chain((
                squash_bijector,
                tfp.bijectors.Affine(shift=shift,
                                     scale_diag=tf.exp(log_scale_diag)),
            ))
            distribution = (
                tfp.distributions.ConditionalTransformedDistribution(
                    distribution=base_distribution, bijector=bijector))

            log_pis = distribution.log_prob(actions)[:, None]
            return log_pis

        self.actions_input = tf.keras.layers.Input(shape=output_shape,
                                                   name='actions')

        log_pis = tf.keras.layers.Lambda(log_pis_fn)(
            [shift, log_scale_diag, actions])

        log_pis_for_action_input = tf.keras.layers.Lambda(log_pis_fn)(
            [shift, log_scale_diag, self.actions_input])

        self.log_pis_model = tf.keras.Model(
            (*self.condition_inputs, self.actions_input),
            log_pis_for_action_input)

        self.diagnostics_model = tf.keras.Model(
            self.condition_inputs,
            (shift, log_scale_diag, log_pis, raw_actions, actions))
Example #17
    def __init__(self,
                 input_shapes,
                 output_shape,
                 action_range,
                 *args,
                 squash=True,
                 preprocessors=None,
                 name=None,
                 **kwargs):

        assert (np.all(action_range == np.array([[-1], [1]]))), (
            "The action space should be scaled to (-1, 1)."
            " TODO(hartikainen): We should support non-scaled actions spaces.")

        self._Serializable__initialize(locals())

        self._action_range = action_range
        self._input_shapes = input_shapes
        self._output_shape = output_shape
        self._squash = squash
        self._name = name

        super(GaussianPolicy, self).__init__(*args, **kwargs)

        inputs_flat = create_inputs(input_shapes)
        preprocessors_flat = (flatten_input_structure(preprocessors)
                              if preprocessors is not None else tuple(
                                  None for _ in inputs_flat))

        assert len(inputs_flat) == len(preprocessors_flat), (
            inputs_flat, preprocessors_flat)

        preprocessed_inputs = [
            preprocessor(input_) if preprocessor is not None else input_
            for preprocessor, input_ in zip(preprocessors_flat, inputs_flat)
        ]

        def cast_and_concat(x):
            x = nest.map_structure(
                lambda element: tf.cast(element, tf.float32), x)
            x = nest.flatten(x)
            x = tf.concat(x, axis=-1)
            return x

        conditions = tf.keras.layers.Lambda(cast_and_concat)(
            preprocessed_inputs)

        self.condition_inputs = inputs_flat

        shift_and_log_scale_diag = self._shift_and_log_scale_diag_net(
            output_size=np.prod(output_shape) * 2, )(conditions)

        shift, log_scale_diag = tf.keras.layers.Lambda(
            lambda shift_and_log_scale_diag: tf.split(
                shift_and_log_scale_diag, num_or_size_splits=2, axis=-1))(
                    shift_and_log_scale_diag)

        batch_size = tf.keras.layers.Lambda(lambda x: tf.shape(input=x)[0])(
            conditions)

        base_distribution = tfp.distributions.MultivariateNormalDiag(
            loc=tf.zeros(output_shape), scale_diag=tf.ones(output_shape))

        latents = tf.keras.layers.Lambda(lambda batch_size: base_distribution.
                                         sample(batch_size))(batch_size)

        self.latents_model = tf.keras.Model(self.condition_inputs, latents)
        self.latents_input = tf.keras.layers.Input(shape=output_shape,
                                                   name='latents')

        def raw_actions_fn(inputs):
            shift, log_scale_diag, latents = inputs
            bijector = tfp.bijectors.Affine(shift=shift,
                                            scale_diag=tf.exp(log_scale_diag))
            actions = bijector.forward(latents)
            return actions

        raw_actions = tf.keras.layers.Lambda(raw_actions_fn)(
            (shift, log_scale_diag, latents))

        raw_actions_for_fixed_latents = tf.keras.layers.Lambda(raw_actions_fn)(
            (shift, log_scale_diag, self.latents_input))

        squash_bijector = (tfp.bijectors.Tanh()
                           if self._squash else tfp.bijectors.Identity())

        actions = tf.keras.layers.Lambda(lambda raw_actions: squash_bijector.
                                         forward(raw_actions))(raw_actions)
        self.actions_model = tf.keras.Model(self.condition_inputs, actions)

        actions_for_fixed_latents = tf.keras.layers.Lambda(
            lambda raw_actions: squash_bijector.forward(raw_actions))(
                raw_actions_for_fixed_latents)
        self.actions_model_for_fixed_latents = tf.keras.Model(
            (*self.condition_inputs, self.latents_input),
            actions_for_fixed_latents)

        deterministic_actions = tf.keras.layers.Lambda(
            lambda shift: squash_bijector.forward(shift))(shift)

        self.deterministic_actions_model = tf.keras.Model(
            self.condition_inputs, deterministic_actions)

        def log_pis_fn(inputs):
            shift, log_scale_diag, actions = inputs
            base_distribution = tfp.distributions.MultivariateNormalDiag(
                loc=tf.zeros(output_shape), scale_diag=tf.ones(output_shape))
            bijector = tfp.bijectors.Chain((
                squash_bijector,
                tfp.bijectors.Affine(shift=shift,
                                     scale_diag=tf.exp(log_scale_diag)),
            ))
            distribution = (tfp.distributions.TransformedDistribution(
                distribution=base_distribution, bijector=bijector))

            log_pis = distribution.log_prob(actions)[:, None]
            return log_pis

        self.actions_input = tf.keras.layers.Input(shape=output_shape,
                                                   name='actions')

        log_pis = tf.keras.layers.Lambda(log_pis_fn)(
            [shift, log_scale_diag, actions])

        log_pis_for_action_input = tf.keras.layers.Lambda(log_pis_fn)(
            [shift, log_scale_diag, self.actions_input])

        self.log_pis_model = tf.keras.Model(
            (*self.condition_inputs, self.actions_input),
            log_pis_for_action_input)

        self.diagnostics_model = tf.keras.Model(
            self.condition_inputs,
            (shift, log_scale_diag, log_pis, raw_actions, actions))
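
A standalone numerical sketch of what `log_pis_fn` computes, using the same TFP pieces outside of Keras; the shapes and values are illustrative, and `tfp.bijectors.Affine` matches the (older) TFP API used above.

import tensorflow as tf
import tensorflow_probability as tfp

output_shape = (2, )
shift = tf.zeros((5, 2))           # batch of means
log_scale_diag = tf.zeros((5, 2))  # batch of log standard deviations
actions = tf.fill((5, 2), 0.1)     # already-squashed actions in (-1, 1)

base_distribution = tfp.distributions.MultivariateNormalDiag(
    loc=tf.zeros(output_shape), scale_diag=tf.ones(output_shape))
bijector = tfp.bijectors.Chain((
    tfp.bijectors.Tanh(),
    tfp.bijectors.Affine(shift=shift, scale_diag=tf.exp(log_scale_diag)),
))
distribution = tfp.distributions.TransformedDistribution(
    distribution=base_distribution, bijector=bijector)

log_pis = distribution.log_prob(actions)[:, None]  # shape (5, 1)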
Example #18
    def _create_inputs(self, input_shapes):
        self._inputs = create_inputs(input_shapes)