Example #1
    def __init__(self, x_in, ob_space, ac_space, lstm_class, lstm_layers):

        # Flatten and expand with fake time dim to feed to LSTM bank:
        x = tf.expand_dims(batch_flatten(x_in), [0])
        # x = tf.expand_dims(self.flatten_homebrew(x_in), [0])
        # Reuse the `train_phase` flag if it is already defined; otherwise create a default placeholder below:
        try:
            if self.train_phase is not None:
                pass

        except AttributeError:
            self.train_phase = tf.placeholder_with_default(
                tf.constant(False, dtype=tf.bool),
                shape=(),
                name='train_phase_flag_pl'
            )
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        #print('GOT HERE 2, x:', x.shape)
        #print('GOT HERE 2, train_phase:', self.train_phase.shape)
        #print('GOT HERE 2, update_ops:', self.update_ops)

        # Define LSTM layers:
        lstm = []
        for size in lstm_layers:
            lstm += [lstm_class(size, state_is_tuple=True)]

        self.lstm = rnn.MultiRNNCell(lstm, state_is_tuple=True)
        # self.lstm = lstm[0]

        # Get time_dimension as [1]-shaped tensor:
        step_size = tf.expand_dims(tf.shape(x)[1], [0])
        #step_size = tf.shape(self.x)[:1]
        #print('GOT HERE 3')
        self.lstm_init_state = self.lstm.zero_state(1, dtype=tf.float32)

        lstm_state_pl = self.rnn_placeholders(self.lstm.zero_state(1, dtype=tf.float32))
        self.lstm_state_pl_flatten = flatten_nested(lstm_state_pl)

        #print('GOT HERE 4, x:', x.shape)
        lstm_outputs, self.lstm_state_out = tf.nn.dynamic_rnn(
            self.lstm,
            x,
            initial_state=lstm_state_pl,
            sequence_length=step_size,
            time_major=False
        )
        #print('GOT HERE 5')
        x = tf.reshape(lstm_outputs, [-1, lstm_layers[-1]])

        self.logits = self.linear(x, ac_space, "action", self.normalized_columns_initializer(0.01))
        self.vf = tf.reshape(self.linear(x, 1, "value", self.normalized_columns_initializer(1.0)), [-1])
        self.sample = self.categorical_sample(self.logits, ac_space)[0, :]
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)

        # Add moving averages to save list (meant for Batch_norm layer):
        moving_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, tf.get_variable_scope().name + '.*moving.*')
        renorm_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, tf.get_variable_scope().name + '.*renorm.*')

        self.var_list += moving_var_list + renorm_var_list
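
A minimal run-time sketch of the recurrent feed pattern this constructor sets up. The names `policy` (an instance of this class), `sess` (a `tf.Session`) and `obs_feed` (a dict mapping the observation placeholder to data) are assumptions, not defined in the snippet above; `flatten_nested` is the same helper used there.

# Illustrative sketch only: step the policy while threading the LSTM state through.
state = sess.run(policy.lstm_init_state)
feed_dict = dict(obs_feed)
feed_dict.update(
    {pl: v for pl, v in zip(policy.lstm_state_pl_flatten, flatten_nested(state))}
)
action, value, state = sess.run(
    [policy.sample, policy.vf, policy.lstm_state_out],
    feed_dict=feed_dict,
)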
Example #2
def conv2d_autoencoder(inputs,
                       layer_config,
                       resize_method=tf.image.ResizeMethod.BILINEAR,
                       pad='SAME',
                       linear_layer_ref=linear,
                       name='base_conv2d_autoencoder',
                       reuse=False,
                       **kwargs):
    """
    Basic convolutional autoencoder.
    Hidden state is passed through dense linear layer.

    Args:
        inputs:             input tensor
        layer_config:       layers configuration list: [layer_1_config, layer_2_config,...], where:
                            layer_i_config = [num_filters(int), filter_size(list), stride(list)];
                            this list represents the encoder part of the autoencoder bottleneck;
                            the decoder part is inferred symmetrically
        resize_method:      up-sampling method, one of supported tf.image.ResizeMethod's
        pad:                str, padding scheme: 'SAME' or 'VALID'
        linear_layer_ref:   linear layer class to use
        name:               str, name scope
        reuse:              bool

    Returns:
        list of tensors holding encoded features, layer-wise from outer to inner
        tensor holding batch-wise flattened hidden state vector
        list of tensors holding decoded features, layer-wise from inner to outer
        tensor holding reconstructed output
        None value

    """
    with tf.variable_scope(name, reuse=reuse):
        # Encode:
        encoder_layers, shapes = conv2d_encoder(x=inputs,
                                                layer_config=layer_config,
                                                pad=pad,
                                                reuse=reuse)
        # Flatten hidden state, pass through dense layer:
        z = batch_flatten(encoder_layers[-1])
        h, w, c = encoder_layers[-1].get_shape().as_list()[1:]

        z = linear_layer_ref(x=z,
                             size=h * w * c,
                             name='hidden_dense',
                             initializer=normalized_columns_initializer(1.0),
                             reuse=reuse)
        # Reshape back and feed to decoder:
        decoder_layers = conv2d_decoder(z=tf.reshape(z, [-1, h, w, c]),
                                        layer_config=layer_config,
                                        layer_shapes=shapes,
                                        pad=pad,
                                        resize_method=resize_method,
                                        reuse=reuse)
        y_hat = decoder_layers[-1]
        return encoder_layers, z, decoder_layers, y_hat, None
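
A usage sketch for `conv2d_autoencoder` under assumed shapes; the 84x84x3 input, the two-entry `layer_config` and the MSE objective below are illustrative choices, not taken from the source.

# Illustrative sketch only (assumes `tf` and the helpers above are in scope):
inputs = tf.placeholder(tf.float32, [None, 84, 84, 3], name='ae_input_pl')
layer_config = [
    [32, [3, 3], [2, 2]],  # [num_filters, filter_size, stride]
    [64, [3, 3], [2, 2]],
]
encoder_layers, z, decoder_layers, y_hat, _ = conv2d_autoencoder(
    inputs=inputs,
    layer_config=layer_config,
    name='conv2d_ae_example',
)
# A plain reconstruction objective could then be:
reconstruction_loss = tf.losses.mean_squared_error(labels=inputs, predictions=y_hat)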
Example #3
def lstm_network(x,
                 a_r,
                 lstm_class=rnn.BasicLSTMCell,
                 lstm_layers=(256, ),
                 reuse=False):
    """Stage2 network: from features to flattened LSTM output.
    Defines [multi-layered] dynamic [possibly shared] LSTM network.

    Returns:
         batch-wise flattened output tensor;
         lstm initial state tensor;
         lstm state output tensor;
         lstm flattened feed placeholders as tuple.
    """
    with tf.variable_scope('lstm', reuse=reuse):
        # Flatten, add action/reward and expand with fake time dim to feed LSTM bank:
        x = tf.concat([batch_flatten(x), a_r], axis=-1)
        x = tf.expand_dims(x, [0])

        # Define LSTM layers:
        lstm = []
        for size in lstm_layers:
            lstm += [lstm_class(size, state_is_tuple=True)]

        lstm = rnn.MultiRNNCell(lstm, state_is_tuple=True)
        # Get time_dimension as [1]-shaped tensor:
        step_size = tf.expand_dims(tf.shape(x)[1], [0])

        lstm_init_state = lstm.zero_state(1, dtype=tf.float32)

        lstm_state_pl = rnn_placeholders(lstm.zero_state(1, dtype=tf.float32))
        lstm_state_pl_flatten = flatten_nested(lstm_state_pl)

        lstm_outputs, lstm_state_out = tf.nn.dynamic_rnn(
            lstm,
            x,
            initial_state=lstm_state_pl,
            sequence_length=step_size,
            time_major=False)
        x_out = tf.reshape(lstm_outputs, [-1, lstm_layers[-1]])
    return x_out, lstm_init_state, lstm_state_out, lstm_state_pl_flatten
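
A wiring sketch for `lstm_network`; the feature and action/reward placeholder shapes are assumptions for illustration, and the state-feeding note mirrors the placeholder pattern set up above.

# Illustrative sketch only (assumes `tf` and the helpers above are in scope):
features = tf.placeholder(tf.float32, [None, 11, 1, 32], name='features_pl')
a_r = tf.placeholder(tf.float32, [None, 4 + 1], name='action_reward_pl')

x_out, lstm_init_state, lstm_state_out, lstm_state_pl_flatten = lstm_network(features, a_r)

# At run time, `lstm_state_pl_flatten` is fed with the previous step's
# `lstm_state_out` values (initially with `lstm_init_state`), e.g.:
# feed = {pl: v for pl, v in zip(lstm_state_pl_flatten, flatten_nested(prev_state))}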
Example #4
    def __init__(self, ob_space, ac_space, ff_size=64, **kwargs):
        """
        Simple and computationally cheap feed-forward policy.

        Args:
            ob_space:           dictionary of observation state shapes
            ac_space:           discrete action space shape (length)
            ff_size:            feed-forward dense layer size
            **kwargs            not used
        """
        kwargs.update(dict(
            conv_2d_filter_size=[3, 1],
            conv_2d_stride=[2, 1],
        ))

        self.ob_space = ob_space
        self.ac_space = ac_space
        self.aux_estimate = False
        self.callback = {}

        # Placeholders for obs. state input:
        self.on_state_in = nested_placeholders(ob_space,
                                               batch_dim=None,
                                               name='on_policy_state_in')
        self.off_state_in = nested_placeholders(ob_space,
                                                batch_dim=None,
                                                name='off_policy_state_in_pl')
        self.rp_state_in = nested_placeholders(ob_space,
                                               batch_dim=None,
                                               name='rp_state_in')

        # Placeholders for concatenated action [one-hot] and reward [scalar]:
        self.on_a_r_in = tf.placeholder(tf.float32, [None, ac_space + 1],
                                        name='on_policy_action_reward_in_pl')
        self.off_a_r_in = tf.placeholder(tf.float32, [None, ac_space + 1],
                                         name='off_policy_action_reward_in_pl')

        # Placeholders for rnn batch and time-step dimensions:
        self.on_batch_size = tf.placeholder(tf.int32,
                                            name='on_policy_batch_size')
        self.on_time_length = tf.placeholder(tf.int32,
                                             name='on_policy_sequence_size')

        self.off_batch_size = tf.placeholder(tf.int32,
                                             name='off_policy_batch_size')
        self.off_time_length = tf.placeholder(tf.int32,
                                              name='off_policy_sequence_size')

        # Base on-policy AAC network:
        # Conv. layers:
        on_aac_x = conv_2d_network(self.on_state_in['external'],
                                   ob_space['external'], ac_space, **kwargs)

        # Disabled recurrent-style reshaping, kept for reference:
        if False:
            # Reshape rnn inputs for batch training as [rnn_batch_dim, rnn_time_dim, flattened_depth]:
            x_shape_dynamic = tf.shape(on_aac_x)
            max_seq_len = tf.cast(x_shape_dynamic[0] / self.on_batch_size,
                                  tf.int32)
            x_shape_static = on_aac_x.get_shape().as_list()

            on_a_r_in = tf.reshape(
                self.on_a_r_in,
                [self.on_batch_size, max_seq_len, ac_space + 1])
            on_aac_x = tf.reshape(
                on_aac_x,
                [self.on_batch_size, max_seq_len,
                 np.prod(x_shape_static[1:])])

            # Feed last action_reward [, internal obs. state] into LSTM along with external state features:
            on_stage2_input = [on_aac_x, on_a_r_in]

            if 'internal' in list(self.on_state_in.keys()):
                x_int_shape_static = self.on_state_in['internal'].get_shape(
                ).as_list()
                x_int = tf.reshape(self.on_state_in['internal'], [
                    self.on_batch_size, max_seq_len,
                    np.prod(x_int_shape_static[1:])
                ])
                on_stage2_input.append(x_int)

            on_aac_x = tf.concat(on_stage2_input, axis=-1)

        on_aac_x = batch_flatten(on_aac_x)

        # Dense layer:
        on_x_dense_out = tf.nn.elu(
            linear(on_aac_x,
                   ff_size,
                   'dense_pi_v',
                   normalized_columns_initializer(0.01),
                   reuse=False))

        # Dummy LSTM state attributes, kept so this feed-forward policy exposes the recurrent policy interface:
        self.on_lstm_init_state = (LSTMStateTuple(c=np.zeros((1, 1)),
                                                  h=np.zeros((1, 1))), )
        self.on_lstm_state_out = (LSTMStateTuple(c=np.zeros((1, 1)),
                                                 h=np.zeros((1, 1))), )
        self.on_lstm_state_pl_flatten = [
            tf.placeholder(shape=(None, 1), dtype=tf.float32, name='dummy_c'),
            tf.placeholder(shape=(None, 1), dtype=tf.float32, name='dummy_h')
        ]

        # Aac policy and value outputs and action-sampling function:
        [self.on_logits, self.on_vf,
         self.on_sample] = dense_aac_network(on_x_dense_out, ac_space)

        # Batch-norm related (useless, ignore):
        try:
            if self.train_phase is not None:
                pass

        except AttributeError:
            self.train_phase = tf.placeholder_with_default(
                tf.constant(False, dtype=tf.bool),
                shape=(),
                name='train_phase_flag_pl')
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # Add moving averages to save list:
        moving_var_list = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES,
            tf.get_variable_scope().name + '.*moving.*')
        renorm_var_list = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES,
            tf.get_variable_scope().name + '.*renorm.*')

        # What to save:
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)
        self.var_list += moving_var_list + renorm_var_list
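
A run-time sketch for the feed-forward policy above. `policy`, `sess`, `obs` and `last_a_r` are assumed names, and `ob_space` is taken to be a single-level dict of shapes.

# Illustrative sketch only: single on-policy forward pass.
feed_dict = {pl: obs[key] for key, pl in policy.on_state_in.items()}
feed_dict[policy.on_a_r_in] = last_a_r  # shape [1, ac_space + 1]: one-hot action + reward
action, value = sess.run([policy.on_sample, policy.on_vf], feed_dict=feed_dict)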
Example #5
    def __init__(
            self,
            ob_space,
            ac_space,
            rp_sequence_size,
            lstm_class=tf.contrib.rnn.LayerNormBasicLSTMCell,
            #lstm_class=rnn.BasicLSTMCell,
            lstm_layers=(256, 256),
            aux_estimate=False,
            encode_internal_state=False,
            **kwargs):
        """
        Defines [partially shared] on/off-policy networks for estimating action-logits, value function,
        reward and state 'pixel_change' predictions.
        Expects a multi-modal observation as a dictionary of arrays with shapes given by `ob_space`.

        Args:
            ob_space:           dictionary of observation state shapes
            ac_space:           discrete action space shape (length)
            rp_sequence_size:   reward prediction sample length
            lstm_class:         tf.nn.lstm class
            lstm_layers:        tuple of LSTM layers sizes
            aux_estimate:       (bool), if True - add auxiliary task estimates to the `self.callback` dictionary.
            **kwargs            not used
        """
        # 1D plug-in:
        kwargs.update(
            dict(
                conv_2d_filter_size=[3, 1],
                conv_2d_stride=[2, 1],
                conv_2d_num_filters=32,
                pc_estimator_stride=[2, 1],
                duell_pc_x_inner_shape=(6, 1,
                                        32),  # [6,3,32] if swapping W-C dims
                duell_pc_filter_size=(4, 1),
                duell_pc_stride=(2, 1),
            ))

        self.ob_space = ob_space
        self.ac_space = ac_space
        self.rp_sequence_size = rp_sequence_size
        self.lstm_class = lstm_class
        self.lstm_layers = lstm_layers
        self.aux_estimate = aux_estimate
        self.callback = {}
        self.encode_internal_state = encode_internal_state
        self.debug = {}

        # Placeholders for obs. state input:
        self.on_state_in = nested_placeholders(ob_space,
                                               batch_dim=None,
                                               name='on_policy_state_in')
        self.off_state_in = nested_placeholders(ob_space,
                                                batch_dim=None,
                                                name='off_policy_state_in_pl')
        self.rp_state_in = nested_placeholders(ob_space,
                                               batch_dim=None,
                                               name='rp_state_in')

        # Placeholders for concatenated action [one-hot] and reward [scalar]:
        self.on_a_r_in = tf.placeholder(tf.float32, [None, ac_space + 1],
                                        name='on_policy_action_reward_in_pl')
        self.off_a_r_in = tf.placeholder(tf.float32, [None, ac_space + 1],
                                         name='off_policy_action_reward_in_pl')

        # Placeholders for rnn batch and time-step dimensions:
        self.on_batch_size = tf.placeholder(tf.int32,
                                            name='on_policy_batch_size')
        self.on_time_length = tf.placeholder(tf.int32,
                                             name='on_policy_sequence_size')

        self.off_batch_size = tf.placeholder(tf.int32,
                                             name='off_policy_batch_size')
        self.off_time_length = tf.placeholder(tf.int32,
                                              name='off_policy_sequence_size')

        # Base on-policy AAC network:
        # Conv. layers:
        on_aac_x = conv_2d_network(self.on_state_in['external'],
                                   ob_space['external'],
                                   ac_space,
                                   name='conv1d_external',
                                   **kwargs)

        # Aux min/max_loss:
        if 'raw_state' in list(self.on_state_in.keys()):
            self.raw_state = self.on_state_in['raw_state']
            self.state_min_max = tf.nn.elu(
                linear(batch_flatten(on_aac_x), 2, "min_max",
                       normalized_columns_initializer(0.01)))
        else:
            self.raw_state = None
            self.state_min_max = None

        # Reshape rnn inputs for batch training as [rnn_batch_dim, rnn_time_dim, flattened_depth]:
        x_shape_dynamic = tf.shape(on_aac_x)
        max_seq_len = tf.cast(x_shape_dynamic[0] / self.on_batch_size,
                              tf.int32)
        x_shape_static = on_aac_x.get_shape().as_list()

        on_a_r_in = tf.reshape(self.on_a_r_in,
                               [self.on_batch_size, max_seq_len, ac_space + 1])
        on_aac_x = tf.reshape(
            on_aac_x,
            [self.on_batch_size, max_seq_len,
             np.prod(x_shape_static[1:])])
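        # Worked example of the reshape above (illustrative numbers, not from the source):
        # if the conv stack emits 40 flattened feature rows and on_batch_size == 2,
        # then max_seq_len == 20 and on_aac_x becomes
        # [2, 20, np.prod(x_shape_static[1:])], i.e. [rnn_batch_dim, rnn_time_dim, flattened_depth].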

        # Prepare `internal` state, if any:
        if 'internal' in list(self.on_state_in.keys()):
            if self.encode_internal_state:
                # Use convolution encoder:
                on_x_internal = conv_2d_network(
                    self.on_state_in['internal'],
                    ob_space['internal'],
                    ac_space,
                    name='conv1d_internal',
                    # conv_2d_layer_ref=conv2d_dw,
                    conv_2d_num_filters=32,
                    conv_2d_num_layers=2,
                    conv_2d_filter_size=[3, 1],
                    conv_2d_stride=[2, 1],
                )
                x_int_shape_static = on_x_internal.get_shape().as_list()
                on_x_internal = [
                    tf.reshape(on_x_internal, [
                        self.on_batch_size, max_seq_len,
                        np.prod(x_int_shape_static[1:])
                    ])
                ]
                self.debug['state_internal_enc'] = tf.shape(on_x_internal)

            else:
                # Feed as is:
                x_int_shape_static = self.on_state_in['internal'].get_shape(
                ).as_list()
                on_x_internal = tf.reshape(self.on_state_in['internal'], [
                    self.on_batch_size, max_seq_len,
                    np.prod(x_int_shape_static[1:])
                ])
                self.debug['state_internal'] = tf.shape(
                    self.on_state_in['internal'])
                on_x_internal = [on_x_internal]

        else:
            on_x_internal = []

        # Not used:
        if 'reward' in list(self.on_state_in.keys()):
            x_rewards_shape_static = self.on_state_in['reward'].get_shape(
            ).as_list()
            x_rewards = tf.reshape(self.on_state_in['reward'], [
                self.on_batch_size, max_seq_len,
                np.prod(x_rewards_shape_static[1:])
            ])
            self.debug['rewards'] = tf.shape(x_rewards)
            x_rewards = [x_rewards]

        else:
            x_rewards = []

        self.debug['conv_input_to_lstm1'] = tf.shape(on_aac_x)

        # Feed last_reward into LSTM_1 layer along with encoded `external` state features:
        on_stage2_1_input = [on_aac_x,
                             on_a_r_in[..., -1][..., None]]  #+ on_x_internal

        # Feed last_action, encoded `external` state,  `internal` state into LSTM_2:
        on_stage2_2_input = [on_aac_x, on_a_r_in] + on_x_internal

        # LSTM_1 full input:
        on_aac_x = tf.concat(on_stage2_1_input, axis=-1)

        self.debug['concat_input_to_lstm1'] = tf.shape(on_aac_x)

        # First LSTM layer takes encoded `external` state:
        [on_x_lstm_1_out, self.on_lstm_1_init_state, self.on_lstm_1_state_out, self.on_lstm_1_state_pl_flatten] =\
            lstm_network(on_aac_x, self.on_time_length, lstm_class, (lstm_layers[0],), name='lstm_1')

        self.debug['on_x_lstm_1_out'] = tf.shape(on_x_lstm_1_out)
        self.debug['self.on_lstm_1_state_out'] = tf.shape(
            self.on_lstm_1_state_out)
        self.debug['self.on_lstm_1_state_pl_flatten'] = tf.shape(
            self.on_lstm_1_state_pl_flatten)

        # For time_flat only: Reshape on_lstm_1_state_out from [1,2,20,size] -->[20,1,2,size] --> [20,1, 2xsize]:
        reshape_lstm_1_state_out = tf.transpose(self.on_lstm_1_state_out,
                                                [2, 0, 1, 3])
        reshape_lstm_1_state_out_shape_static = reshape_lstm_1_state_out.get_shape(
        ).as_list()
        reshape_lstm_1_state_out = tf.reshape(
            reshape_lstm_1_state_out,
            [
                self.on_batch_size, max_seq_len,
                np.prod(reshape_lstm_1_state_out_shape_static[-2:])
            ],
        )
        #self.debug['reshape_lstm_1_state_out'] = tf.shape(reshape_lstm_1_state_out)

        # Take policy logits off first LSTM-dense layer:
        # Reshape back to [batch, flattened_depth], where batch = rnn_batch_dim * rnn_time_dim:
        x_shape_static = on_x_lstm_1_out.get_shape().as_list()
        rsh_on_x_lstm_1_out = tf.reshape(
            on_x_lstm_1_out, [x_shape_dynamic[0], x_shape_static[-1]])

        self.debug['reshaped_on_x_lstm_1_out'] = tf.shape(rsh_on_x_lstm_1_out)

        # Aac policy output and action-sampling function:
        [self.on_logits, _,
         self.on_sample] = dense_aac_network(rsh_on_x_lstm_1_out,
                                             ac_space,
                                             name='aac_dense_pi')

        # Second LSTM layer takes concatenated encoded 'external' state, LSTM_1 output,
        # last_action and `internal_state` (if present) tensors:
        on_stage2_2_input += [on_x_lstm_1_out]

        # Try: feed context instead of output
        #on_stage2_2_input = [reshape_lstm_1_state_out] + on_stage2_1_input

        # LSTM_2 full input:
        on_aac_x = tf.concat(on_stage2_2_input, axis=-1)

        self.debug['on_stage2_2_input'] = tf.shape(on_aac_x)

        [on_x_lstm_2_out, self.on_lstm_2_init_state, self.on_lstm_2_state_out, self.on_lstm_2_state_pl_flatten] = \
            lstm_network(on_aac_x, self.on_time_length, lstm_class, (lstm_layers[-1],), name='lstm_2')

        self.debug['on_x_lstm_2_out'] = tf.shape(on_x_lstm_2_out)
        self.debug['self.on_lstm_2_state_out'] = tf.shape(
            self.on_lstm_2_state_out)
        self.debug['self.on_lstm_2_state_pl_flatten'] = tf.shape(
            self.on_lstm_2_state_pl_flatten)

        # Reshape back to [batch, flattened_depth], where batch = rnn_batch_dim * rnn_time_dim:
        x_shape_static = on_x_lstm_2_out.get_shape().as_list()
        on_x_lstm_out = tf.reshape(on_x_lstm_2_out,
                                   [x_shape_dynamic[0], x_shape_static[-1]])

        self.debug['reshaped_on_x_lstm_out'] = tf.shape(on_x_lstm_out)

        # Aac value function:
        [_, self.on_vf, _] = dense_aac_network(on_x_lstm_out,
                                               ac_space,
                                               name='aac_dense_vfn')

        # Concatenate LSTM placeholders, init. states and context:
        self.on_lstm_init_state = (self.on_lstm_1_init_state,
                                   self.on_lstm_2_init_state)
        self.on_lstm_state_out = (self.on_lstm_1_state_out,
                                  self.on_lstm_2_state_out)
        self.on_lstm_state_pl_flatten = self.on_lstm_1_state_pl_flatten + self.on_lstm_2_state_pl_flatten

        #if False: # Temp. disable

        # Off-policy AAC network (shared):
        off_aac_x = conv_2d_network(self.off_state_in['external'],
                                    ob_space['external'],
                                    ac_space,
                                    name='conv1d_external',
                                    reuse=True,
                                    **kwargs)
        # Reshape rnn inputs for batch training as [rnn_batch_dim, rnn_time_dim, flattened_depth]:
        x_shape_dynamic = tf.shape(off_aac_x)
        max_seq_len = tf.cast(x_shape_dynamic[0] / self.off_batch_size,
                              tf.int32)
        x_shape_static = off_aac_x.get_shape().as_list()

        off_a_r_in = tf.reshape(
            self.off_a_r_in, [self.off_batch_size, max_seq_len, ac_space + 1])
        off_aac_x = tf.reshape(
            off_aac_x,
            [self.off_batch_size, max_seq_len,
             np.prod(x_shape_static[1:])])

        # Prepare `internal` state, if any:
        if 'internal' in list(self.off_state_in.keys()):
            if self.encode_internal_state:
                # Use convolution encoder:
                off_x_internal = conv_2d_network(
                    self.off_state_in['internal'],
                    ob_space['internal'],
                    ac_space,
                    name='conv1d_internal',
                    # conv_2d_layer_ref=conv2d_dw,
                    conv_2d_num_filters=32,
                    conv_2d_num_layers=2,
                    conv_2d_filter_size=[3, 1],
                    conv_2d_stride=[2, 1],
                    reuse=True,
                )
                x_int_shape_static = off_x_internal.get_shape().as_list()
                off_x_internal = [
                    tf.reshape(off_x_internal, [
                        self.off_batch_size, max_seq_len,
                        np.prod(x_int_shape_static[1:])
                    ])
                ]
            else:
                x_int_shape_static = self.off_state_in['internal'].get_shape(
                ).as_list()
                off_x_internal = tf.reshape(self.off_state_in['internal'], [
                    self.off_batch_size, max_seq_len,
                    np.prod(x_int_shape_static[1:])
                ])
                off_x_internal = [off_x_internal]

        else:
            off_x_internal = []

        off_stage2_1_input = [off_aac_x,
                              off_a_r_in[..., -1][...,
                                                  None]]  #+ off_x_internal

        off_stage2_2_input = [off_aac_x, off_a_r_in] + off_x_internal

        off_aac_x = tf.concat(off_stage2_1_input, axis=-1)

        [off_x_lstm_1_out, _, _, self.off_lstm_1_state_pl_flatten] =\
            lstm_network(off_aac_x, self.off_time_length, lstm_class, (lstm_layers[0],), name='lstm_1', reuse=True)

        # Reshape back to [batch, flattened_depth], where batch = rnn_batch_dim * rnn_time_dim:
        x_shape_static = off_x_lstm_1_out.get_shape().as_list()
        rsh_off_x_lstm_1_out = tf.reshape(
            off_x_lstm_1_out, [x_shape_dynamic[0], x_shape_static[-1]])

        [self.off_logits, _, _] =\
            dense_aac_network(rsh_off_x_lstm_1_out, ac_space, name='aac_dense_pi', reuse=True)

        off_stage2_2_input += [off_x_lstm_1_out]

        # LSTM_2 full input:
        off_aac_x = tf.concat(off_stage2_2_input, axis=-1)

        [off_x_lstm_2_out, _, _, self.off_lstm_2_state_pl_flatten] = \
            lstm_network(off_aac_x, self.off_time_length, lstm_class, (lstm_layers[-1],), name='lstm_2', reuse=True)

        # Reshape back to [batch, flattened_depth], where batch = rnn_batch_dim * rnn_time_dim:
        x_shape_static = off_x_lstm_2_out.get_shape().as_list()
        off_x_lstm_out = tf.reshape(off_x_lstm_2_out,
                                    [x_shape_dynamic[0], x_shape_static[-1]])

        # Aac value function:
        [_, self.off_vf, _] = dense_aac_network(off_x_lstm_out,
                                                ac_space,
                                                name='aac_dense_vfn',
                                                reuse=True)

        # Concatenate LSTM states:
        self.off_lstm_state_pl_flatten = self.off_lstm_1_state_pl_flatten + self.off_lstm_2_state_pl_flatten

        # Aux1:
        # `Pixel control` network.
        #
        # Define pixels-change estimation function:
        # Yes, it is rather env-specific, but for the Atari case it is handy to do it here; see self.get_pc_target():
        [self.pc_change_state_in, self.pc_change_last_state_in, self.pc_target] =\
            pixel_change_2d_estimator(ob_space['external'], **kwargs)

        self.pc_batch_size = self.off_batch_size
        self.pc_time_length = self.off_time_length

        self.pc_state_in = self.off_state_in
        self.pc_a_r_in = self.off_a_r_in
        self.pc_lstm_state_pl_flatten = self.off_lstm_state_pl_flatten

        # Shared conv and lstm nets, same off-policy batch:
        pc_x = off_x_lstm_out

        # PC duelling Q-network, outputs [None, 20, 20, ac_size] Q-features tensor:
        self.pc_q = duelling_pc_network(pc_x, self.ac_space, **kwargs)

        # Aux2:
        # `Value function replay` network.
        #
        # VR network is fully shared with the ppo network but has a `value`-only output,
        # and uses the same off-policy batch pass as the off_ppo network:
        self.vr_batch_size = self.off_batch_size
        self.vr_time_length = self.off_time_length

        self.vr_state_in = self.off_state_in
        self.vr_a_r_in = self.off_a_r_in

        self.vr_lstm_state_pl_flatten = self.off_lstm_state_pl_flatten
        self.vr_value = self.off_vf

        # Aux3:
        # `Reward prediction` network.
        self.rp_batch_size = tf.placeholder(tf.int32, name='rp_batch_size')

        # Shared conv. output:
        rp_x = conv_2d_network(self.rp_state_in['external'],
                               ob_space['external'],
                               ac_space,
                               name='conv1d_external',
                               reuse=True,
                               **kwargs)

        # Flatten batch-wise:
        rp_x_shape_static = rp_x.get_shape().as_list()
        rp_x = tf.reshape(rp_x, [
            self.rp_batch_size,
            np.prod(rp_x_shape_static[1:]) * (self.rp_sequence_size - 1)
        ])

        # RP output:
        self.rp_logits = dense_rp_network(rp_x)

        # Batch-norm related (useless, ignore):
        try:
            if self.train_phase is not None:
                pass

        except AttributeError:
            self.train_phase = tf.placeholder_with_default(
                tf.constant(False, dtype=tf.bool),
                shape=(),
                name='train_phase_flag_pl')
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # Add moving averages to save list:
        moving_var_list = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES,
            tf.get_variable_scope().name + '.*moving.*')
        renorm_var_list = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES,
            tf.get_variable_scope().name + '.*renorm.*')

        # What to save:
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)
        self.var_list += moving_var_list + renorm_var_list

        # Callbacks:
        if self.aux_estimate:
            self.callback['pixel_change'] = self.get_pc_target
Example #6
def beta_var_conv2d_autoencoder(
        inputs,
        layer_config,
        resize_method=tf.image.ResizeMethod.BILINEAR,
        pad='SAME',
        linear_layer_ref=linear,
        name='vae_conv2d',
        max_batch_size=256,
        reuse=False
    ):
    """
    Variational autoencoder.

    Papers:
        https://arxiv.org/pdf/1312.6114.pdf
        https://arxiv.org/pdf/1606.05908.pdf
        http://www.matthey.me/pdf/betavae_iclr_2017.pdf


    Args:
        inputs:             input tensor
        layer_config:       layers configuration list: [layer_1_config, layer_2_config,...], where:
                            layer_i_config = [num_filters(int), filter_size(list), stride(list)];
                            this list represents the encoder part of the autoencoder bottleneck;
                            the decoder part is inferred symmetrically
        resize_method:      up-sampling method, one of supported tf.image.ResizeMethod's
        pad:                str, padding scheme: 'SAME' or 'VALID'
        linear_layer_ref:   linear layer class - not used
        name:               str, name scope
        max_batch_size:     int, dynamic batch size should be no greater than this value
        reuse:              bool

    Returns:
        list of tensors holding encoded features, layer-wise from outer to inner
        tensor holding batch-wise flattened hidden state vector
        list of tensors holding decoded features, layer-wise from inner to outer
        tensor holding reconstructed output
        tensor holding estimated KL divergence

    """
    with tf.variable_scope(name, reuse=reuse):

        # Encode:
        encoder_layers, shapes = conv2d_encoder(
            x=inputs,
            layer_config=layer_config,
            pad=pad,
            reuse=reuse
        )
        # Flatten hidden state, pass through dense:
        z_flat = batch_flatten(encoder_layers[-1])

        h, w, c = encoder_layers[-1].get_shape().as_list()[1:]

        z = tf.nn.elu(
            linear(
                x=z_flat,
                size=h * w * c,
                name='enc_dense',
                initializer=normalized_columns_initializer(1.0),
                reuse=reuse
            )
        )
        # TODO: revert back to doubled Z-size
        # half_size_z = h * w * c
        # size_z = 2 * half_size_z

        size_z = int(h * w * c/2)
        z = tf.nn.elu(
            linear(
                #x=z_flat,
                x=z,
                #size=size_z,
                size=size_z * 2,
                name='hidden_dense',
                initializer=normalized_columns_initializer(1.0),
                reuse=reuse
            )
        )
        # Get sample parameters:
        #mu, log_sigma = tf.split(z, [half_size_z, half_size_z], axis=-1)
        mu, log_sigma = tf.split(z, [size_z, size_z], axis=-1)

        # Oversized noise generator:
        #eps = tf.random_normal(shape=[max_batch_size, half_size_z], mean=0., stddev=1.)
        eps = tf.random_normal(shape=[max_batch_size, size_z], mean=0., stddev=1.)
        eps = eps[:tf.shape(z)[0],:]

        # Get sample z ~ Q(z|X):
        z_sampled = mu + tf.exp(log_sigma / 2) * eps

        # D_KL(Q(z|X) || P(z|X)):
        # TODO: where is sum?!
        d_kl = 0.5 * (tf.exp(log_sigma) + tf.square(mu) - 1. - log_sigma)
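        # Note: for a diagonal Gaussian Q(z|X) = N(mu, exp(log_sigma)) against the standard
        # normal prior, the closed form is
        #   D_KL = 0.5 * sum_i(exp(log_sigma_i) + mu_i**2 - 1 - log_sigma_i),
        # so the tensor above holds the per-dimension terms; the sum over latent
        # dimensions (the TODO above) is presumably applied later, in the loss.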

        # Reshape back and feed to decoder:

        z_sampled_dec = tf.nn.elu(
            linear(
                x=z_sampled,
                size=h * w * c,
                name='dec_dense',
                initializer=normalized_columns_initializer(1.0),
                reuse=reuse
            )
        )

        decoder_layers = conv2d_decoder(
            z=tf.reshape(z_sampled_dec, [-1, h, w, c]),
            layer_config=layer_config,
            layer_shapes=shapes,
            pad=pad,
            resize_method=resize_method,
            reuse=reuse
        )
        y_hat = decoder_layers[-1]
        return encoder_layers, z_sampled, decoder_layers, y_hat, d_kl
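
A loss-construction sketch for the returned tensors, following the beta-VAE formulation cited in the docstring; `inputs` and `layer_config` are assumed to be defined as in the earlier autoencoder sketch, and the `beta` value and reductions are illustrative choices, not from the source.

# Illustrative sketch only: reconstruction + beta-weighted KL objective.
encoder_layers, z_sampled, decoder_layers, y_hat, d_kl = beta_var_conv2d_autoencoder(
    inputs=inputs,
    layer_config=layer_config,
)
beta = 4.0  # assumed hyper-parameter
reconstruction_loss = tf.reduce_mean(tf.square(y_hat - inputs))
kl_loss = tf.reduce_mean(tf.reduce_sum(d_kl, axis=-1))  # sum over latent dims, mean over batch
vae_loss = reconstruction_loss + beta * kl_loss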