def __init__(self, x_in, ob_space, ac_space, lstm_class, lstm_layers):
    # Flatten and expand with fake time dim to feed to LSTM bank:
    x = tf.expand_dims(batch_flatten(x_in), [0])

    # Define train phase flag and batch-norm update ops only once per instance:
    try:
        if self.train_phase is not None:
            pass

    except AttributeError:
        self.train_phase = tf.placeholder_with_default(
            tf.constant(False, dtype=tf.bool),
            shape=(),
            name='train_phase_flag_pl'
        )
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # Define LSTM layers:
    lstm = []
    for size in lstm_layers:
        lstm += [lstm_class(size, state_is_tuple=True)]

    self.lstm = rnn.MultiRNNCell(lstm, state_is_tuple=True)

    # Get time_dimension as [1]-shaped tensor:
    step_size = tf.expand_dims(tf.shape(x)[1], [0])

    self.lstm_init_state = self.lstm.zero_state(1, dtype=tf.float32)

    lstm_state_pl = self.rnn_placeholders(self.lstm.zero_state(1, dtype=tf.float32))
    self.lstm_state_pl_flatten = flatten_nested(lstm_state_pl)

    lstm_outputs, self.lstm_state_out = tf.nn.dynamic_rnn(
        self.lstm,
        x,
        initial_state=lstm_state_pl,
        sequence_length=step_size,
        time_major=False
    )
    x = tf.reshape(lstm_outputs, [-1, lstm_layers[-1]])

    self.logits = self.linear(x, ac_space, "action", self.normalized_columns_initializer(0.01))
    self.vf = tf.reshape(self.linear(x, 1, "value", self.normalized_columns_initializer(1.0)), [-1])
    self.sample = self.categorical_sample(self.logits, ac_space)[0, :]

    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)

    # Add moving averages to save list (meant for Batch_norm layer):
    moving_var_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES,
        tf.get_variable_scope().name + '.*moving.*'
    )
    renorm_var_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES,
        tf.get_variable_scope().name + '.*renorm.*'
    )
    self.var_list += moving_var_list + renorm_var_list
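
# The `rnn_placeholders` helper used above is not defined in this excerpt. A minimal
# sketch of the assumed behavior, for illustration only (not the actual implementation):
# it recursively mirrors an arbitrary nested RNN zero-state with feed placeholders that
# default to the given zero values, so rollouts can either feed a saved context or none.
# Assumes the module-level `tf` import and `rnn` (tf.contrib.rnn) alias.
def rnn_placeholders(state):
    """Recursively converts an RNN zero-state structure to same-shaped placeholders."""
    if isinstance(state, rnn.LSTMStateTuple):
        c = tf.placeholder_with_default(state.c, state.c.shape)
        h = tf.placeholder_with_default(state.h, state.h.shape)
        return rnn.LSTMStateTuple(c, h)

    elif isinstance(state, tf.Tensor):
        return tf.placeholder_with_default(state, state.shape)

    else:
        # Tuple of per-layer states, e.g. from MultiRNNCell:
        return tuple(rnn_placeholders(s) for s in state)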
def conv2d_autoencoder(
        inputs,
        layer_config,
        resize_method=tf.image.ResizeMethod.BILINEAR,
        pad='SAME',
        linear_layer_ref=linear,
        name='base_conv2d_autoencoder',
        reuse=False,
        **kwargs):
    """
    Basic convolutional autoencoder.
    Hidden state is passed through a dense linear layer.

    Args:
        inputs:             input tensor
        layer_config:       layers configuration list: [layer_1_config, layer_2_config,...], where:
                            layer_i_config = [num_filters(int), filter_size(list), stride(list)];
                            this list defines the encoder part of the autoencoder bottleneck;
                            the decoder part is inferred symmetrically
        resize_method:      up-sampling method, one of supported tf.image.ResizeMethod's
        pad:                str, padding scheme: 'SAME' or 'VALID'
        linear_layer_ref:   linear layer class to use
        name:               str, name scope
        reuse:              bool

    Returns:
        list of tensors holding encoded features, layer-wise from outer to inner
        tensor holding batch-wise flattened hidden state vector
        list of tensors holding decoded features, layer-wise from inner to outer
        tensor holding reconstructed output
        None value
    """
    with tf.variable_scope(name, reuse=reuse):
        # Encode:
        encoder_layers, shapes = conv2d_encoder(
            x=inputs,
            layer_config=layer_config,
            pad=pad,
            reuse=reuse
        )
        # Flatten hidden state, pass through dense:
        z = batch_flatten(encoder_layers[-1])
        h, w, c = encoder_layers[-1].get_shape().as_list()[1:]

        z = linear_layer_ref(
            x=z,
            size=h * w * c,
            name='hidden_dense',
            initializer=normalized_columns_initializer(1.0),
            reuse=reuse
        )
        # Reshape back and feed to decoder:
        decoder_layers = conv2d_decoder(
            z=tf.reshape(z, [-1, h, w, c]),
            layer_config=layer_config,
            layer_shapes=shapes,
            pad=pad,
            resize_method=resize_method,
            reuse=reuse
        )
        y_hat = decoder_layers[-1]

        return encoder_layers, z, decoder_layers, y_hat, None
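
# A minimal usage sketch for the autoencoder above, assuming 84x84 RGB inputs; the
# placeholder name and layer_config values are illustrative, following the docstring
# convention [num_filters, filter_size, stride] per encoder layer:
ae_input_pl = tf.placeholder(tf.float32, [None, 84, 84, 3], name='ae_input_pl')
ae_layer_config = [
    [32, [3, 3], [2, 2]],
    [64, [3, 3], [2, 2]],
]
ae_enc, ae_hidden, ae_dec, ae_reconstruction, _ = conv2d_autoencoder(
    inputs=ae_input_pl,
    layer_config=ae_layer_config,
)
# Plain MSE reconstruction objective over the decoded output:
ae_loss = tf.losses.mean_squared_error(labels=ae_input_pl, predictions=ae_reconstruction)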
def lstm_network(x, a_r, lstm_class=rnn.BasicLSTMCell, lstm_layers=(256,), reuse=False):
    """
    Stage2 network: from features to flattened LSTM output.
    Defines [multi-layered] dynamic [possibly shared] LSTM network.

    Returns:
        batch-wise flattened output tensor;
        lstm initial state tensor;
        lstm state output tensor;
        lstm flattened feed placeholders as tuple.
    """
    with tf.variable_scope('lstm', reuse=reuse):
        # Flatten, add action/reward and expand with fake time dim to feed LSTM bank:
        x = tf.concat([batch_flatten(x), a_r], axis=-1)
        x = tf.expand_dims(x, [0])

        # Define LSTM layers:
        lstm = []
        for size in lstm_layers:
            lstm += [lstm_class(size, state_is_tuple=True)]

        lstm = rnn.MultiRNNCell(lstm, state_is_tuple=True)

        # Get time_dimension as [1]-shaped tensor:
        step_size = tf.expand_dims(tf.shape(x)[1], [0])

        lstm_init_state = lstm.zero_state(1, dtype=tf.float32)

        lstm_state_pl = rnn_placeholders(lstm.zero_state(1, dtype=tf.float32))
        lstm_state_pl_flatten = flatten_nested(lstm_state_pl)

        lstm_outputs, lstm_state_out = tf.nn.dynamic_rnn(
            lstm,
            x,
            initial_state=lstm_state_pl,
            sequence_length=step_size,
            time_major=False
        )
        x_out = tf.reshape(lstm_outputs, [-1, lstm_layers[-1]])

    return x_out, lstm_init_state, lstm_state_out, lstm_state_pl_flatten
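
# A single-step inference sketch for lstm_network, assuming TF 1.x, the module-level
# `np` import, and that `flatten_nested` also flattens numpy state structures returned
# by session.run; the placeholder names and shapes are illustrative:
features_pl = tf.placeholder(tf.float32, [None, 64], name='features_pl')
a_r_pl = tf.placeholder(tf.float32, [None, 5], name='a_r_pl')  # one-hot action + reward

x_out, init_state, state_out, state_pl_flatten = lstm_network(features_pl, a_r_pl)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Numpy zero context to start the episode:
    context = sess.run(init_state)
    feed = {features_pl: np.zeros([1, 64]), a_r_pl: np.zeros([1, 5])}
    # Zip flattened state placeholders with the flattened numpy context:
    feed.update({pl: value for pl, value in zip(state_pl_flatten, flatten_nested(context))})
    # Carry `context` over to the next step:
    output, context = sess.run([x_out, state_out], feed)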
def __init__(self, ob_space, ac_space, ff_size=64, **kwargs):
    """
    Simple and computationally cheap feed-forward policy.

    Args:
        ob_space:   dictionary of observation state shapes
        ac_space:   discrete action space shape (length)
        ff_size:    feed-forward dense layer size
        **kwargs:   not used
    """
    kwargs.update(
        dict(
            conv_2d_filter_size=[3, 1],
            conv_2d_stride=[2, 1],
        )
    )
    self.ob_space = ob_space
    self.ac_space = ac_space
    self.aux_estimate = False
    self.callback = {}

    # Placeholders for obs. state input:
    self.on_state_in = nested_placeholders(ob_space, batch_dim=None, name='on_policy_state_in')
    self.off_state_in = nested_placeholders(ob_space, batch_dim=None, name='off_policy_state_in_pl')
    self.rp_state_in = nested_placeholders(ob_space, batch_dim=None, name='rp_state_in')

    # Placeholders for concatenated action [one-hot] and reward [scalar]:
    self.on_a_r_in = tf.placeholder(tf.float32, [None, ac_space + 1], name='on_policy_action_reward_in_pl')
    self.off_a_r_in = tf.placeholder(tf.float32, [None, ac_space + 1], name='off_policy_action_reward_in_pl')

    # Placeholders for rnn batch and time-step dimensions:
    self.on_batch_size = tf.placeholder(tf.int32, name='on_policy_batch_size')
    self.on_time_length = tf.placeholder(tf.int32, name='on_policy_sequence_size')
    self.off_batch_size = tf.placeholder(tf.int32, name='off_policy_batch_size')
    self.off_time_length = tf.placeholder(tf.int32, name='off_policy_sequence_size')

    # Base on-policy AAC network:
    # Conv. layers:
    on_aac_x = conv_2d_network(self.on_state_in['external'], ob_space['external'], ac_space, **kwargs)

    if False:  # Disabled recurrent-style input reshaping, kept for reference:
        # Reshape rnn inputs for batch training as [rnn_batch_dim, rnn_time_dim, flattened_depth]:
        x_shape_dynamic = tf.shape(on_aac_x)
        max_seq_len = tf.cast(x_shape_dynamic[0] / self.on_batch_size, tf.int32)
        x_shape_static = on_aac_x.get_shape().as_list()

        on_a_r_in = tf.reshape(self.on_a_r_in, [self.on_batch_size, max_seq_len, ac_space + 1])
        on_aac_x = tf.reshape(on_aac_x, [self.on_batch_size, max_seq_len, np.prod(x_shape_static[1:])])

        # Feed last action_reward [, internal obs. state] into LSTM along with external state features:
        on_stage2_input = [on_aac_x, on_a_r_in]

        if 'internal' in list(self.on_state_in.keys()):
            x_int_shape_static = self.on_state_in['internal'].get_shape().as_list()
            x_int = tf.reshape(
                self.on_state_in['internal'],
                [self.on_batch_size, max_seq_len, np.prod(x_int_shape_static[1:])]
            )
            on_stage2_input.append(x_int)

        on_aac_x = tf.concat(on_stage2_input, axis=-1)

    on_aac_x = batch_flatten(on_aac_x)

    # Dense layer:
    on_x_dense_out = tf.nn.elu(
        linear(on_aac_x, ff_size, 'dense_pi_v', normalized_columns_initializer(0.01), reuse=False)
    )

    # Dummy LSTM state entries, kept to satisfy the recurrent policy interface:
    self.on_lstm_init_state = (LSTMStateTuple(c=np.zeros((1, 1)), h=np.zeros((1, 1))),)
    self.on_lstm_state_out = (LSTMStateTuple(c=np.zeros((1, 1)), h=np.zeros((1, 1))),)
    self.on_lstm_state_pl_flatten = [
        tf.placeholder(shape=(None, 1), dtype=tf.float32, name='dummy_c'),
        tf.placeholder(shape=(None, 1), dtype=tf.float32, name='dummy_h')
    ]

    # AAC policy and value outputs and action-sampling function:
    [self.on_logits, self.on_vf, self.on_sample] = dense_aac_network(on_x_dense_out, ac_space)

    # Batch-norm related (unused here, kept for interface compatibility):
    try:
        if self.train_phase is not None:
            pass

    except AttributeError:
        self.train_phase = tf.placeholder_with_default(
            tf.constant(False, dtype=tf.bool),
            shape=(),
            name='train_phase_flag_pl'
        )
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # Add moving averages to save list:
    moving_var_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES,
        tf.get_variable_scope().name + '.*moving.*'
    )
    renorm_var_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES,
        tf.get_variable_scope().name + '.*renorm.*'
    )

    # What to save:
    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
    self.var_list += moving_var_list + renorm_var_list
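
# The dummy LSTM state entries above exist only so this feed-forward policy exposes the
# same rollout-runner interface as its recurrent siblings. Schematically, a runner would
# feed the (ignored) placeholders like this; `policy` is an illustrative instance:
def dummy_state_feeder(policy):
    context = (LSTMStateTuple(c=np.zeros((1, 1)), h=np.zeros((1, 1))),)
    return {pl: value for pl, value in zip(policy.on_lstm_state_pl_flatten, flatten_nested(context))}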
def __init__(
        self,
        ob_space,
        ac_space,
        rp_sequence_size,
        lstm_class=tf.contrib.rnn.LayerNormBasicLSTMCell,  # alternative: rnn.BasicLSTMCell
        lstm_layers=(256, 256),
        aux_estimate=False,
        encode_internal_state=False,
        **kwargs):
    """
    Defines [partially shared] on/off-policy networks for estimating action-logits, value function,
    reward and state 'pixel_change' predictions.
    Expects multi-modal observation as array of shape `ob_space`.

    Args:
        ob_space:               dictionary of observation state shapes
        ac_space:               discrete action space shape (length)
        rp_sequence_size:       reward prediction sample length
        lstm_class:             tf.nn.lstm class
        lstm_layers:            tuple of LSTM layers sizes
        aux_estimate:           bool, if True - add auxiliary tasks estimations to self.callback dictionary
        encode_internal_state:  bool, if True - pass `internal` state through a convolution encoder
        **kwargs:               not used
    """
    # 1D plug-in:
    kwargs.update(
        dict(
            conv_2d_filter_size=[3, 1],
            conv_2d_stride=[2, 1],
            conv_2d_num_filters=32,
            pc_estimator_stride=[2, 1],
            duell_pc_x_inner_shape=(6, 1, 32),  # [6, 3, 32] if swapping W-C dims
            duell_pc_filter_size=(4, 1),
            duell_pc_stride=(2, 1),
        )
    )
    self.ob_space = ob_space
    self.ac_space = ac_space
    self.rp_sequence_size = rp_sequence_size
    self.lstm_class = lstm_class
    self.lstm_layers = lstm_layers
    self.aux_estimate = aux_estimate
    self.callback = {}
    self.encode_internal_state = encode_internal_state
    self.debug = {}

    # Placeholders for obs. state input:
    self.on_state_in = nested_placeholders(ob_space, batch_dim=None, name='on_policy_state_in')
    self.off_state_in = nested_placeholders(ob_space, batch_dim=None, name='off_policy_state_in_pl')
    self.rp_state_in = nested_placeholders(ob_space, batch_dim=None, name='rp_state_in')

    # Placeholders for concatenated action [one-hot] and reward [scalar]:
    self.on_a_r_in = tf.placeholder(tf.float32, [None, ac_space + 1], name='on_policy_action_reward_in_pl')
    self.off_a_r_in = tf.placeholder(tf.float32, [None, ac_space + 1], name='off_policy_action_reward_in_pl')

    # Placeholders for rnn batch and time-step dimensions:
    self.on_batch_size = tf.placeholder(tf.int32, name='on_policy_batch_size')
    self.on_time_length = tf.placeholder(tf.int32, name='on_policy_sequence_size')
    self.off_batch_size = tf.placeholder(tf.int32, name='off_policy_batch_size')
    self.off_time_length = tf.placeholder(tf.int32, name='off_policy_sequence_size')

    # Base on-policy AAC network:
    # Conv. layers:
    on_aac_x = conv_2d_network(
        self.on_state_in['external'],
        ob_space['external'],
        ac_space,
        name='conv1d_external',
        **kwargs
    )

    # Aux min/max_loss:
    if 'raw_state' in list(self.on_state_in.keys()):
        self.raw_state = self.on_state_in['raw_state']
        self.state_min_max = tf.nn.elu(
            linear(batch_flatten(on_aac_x), 2, 'min_max', normalized_columns_initializer(0.01))
        )
    else:
        self.raw_state = None
        self.state_min_max = None

    # Reshape rnn inputs for batch training as [rnn_batch_dim, rnn_time_dim, flattened_depth]:
    x_shape_dynamic = tf.shape(on_aac_x)
    max_seq_len = tf.cast(x_shape_dynamic[0] / self.on_batch_size, tf.int32)
    x_shape_static = on_aac_x.get_shape().as_list()

    on_a_r_in = tf.reshape(self.on_a_r_in, [self.on_batch_size, max_seq_len, ac_space + 1])
    on_aac_x = tf.reshape(on_aac_x, [self.on_batch_size, max_seq_len, np.prod(x_shape_static[1:])])

    # Prepare `internal` state, if any:
    if 'internal' in list(self.on_state_in.keys()):
        if self.encode_internal_state:
            # Use convolution encoder:
            on_x_internal = conv_2d_network(
                self.on_state_in['internal'],
                ob_space['internal'],
                ac_space,
                name='conv1d_internal',
                conv_2d_num_filters=32,
                conv_2d_num_layers=2,
                conv_2d_filter_size=[3, 1],
                conv_2d_stride=[2, 1],
            )
            x_int_shape_static = on_x_internal.get_shape().as_list()
            on_x_internal = [
                tf.reshape(on_x_internal, [self.on_batch_size, max_seq_len, np.prod(x_int_shape_static[1:])])
            ]
            self.debug['state_internal_enc'] = tf.shape(on_x_internal)

        else:
            # Feed as is:
            x_int_shape_static = self.on_state_in['internal'].get_shape().as_list()
            on_x_internal = tf.reshape(
                self.on_state_in['internal'],
                [self.on_batch_size, max_seq_len, np.prod(x_int_shape_static[1:])]
            )
            self.debug['state_internal'] = tf.shape(self.on_state_in['internal'])
            on_x_internal = [on_x_internal]

    else:
        on_x_internal = []

    # Not used:
    if 'reward' in list(self.on_state_in.keys()):
        x_rewards_shape_static = self.on_state_in['reward'].get_shape().as_list()
        x_rewards = tf.reshape(
            self.on_state_in['reward'],
            [self.on_batch_size, max_seq_len, np.prod(x_rewards_shape_static[1:])]
        )
        self.debug['rewards'] = tf.shape(x_rewards)
        x_rewards = [x_rewards]

    else:
        x_rewards = []

    self.debug['conv_input_to_lstm1'] = tf.shape(on_aac_x)

    # Feed last_reward into LSTM_1 layer along with encoded `external` state features:
    on_stage2_1_input = [on_aac_x, on_a_r_in[..., -1][..., None]]  # + on_x_internal

    # Feed last_action, encoded `external` state, `internal` state into LSTM_2:
    on_stage2_2_input = [on_aac_x, on_a_r_in] + on_x_internal

    # LSTM_1 full input:
    on_aac_x = tf.concat(on_stage2_1_input, axis=-1)

    self.debug['concat_input_to_lstm1'] = tf.shape(on_aac_x)

    # First LSTM layer takes encoded `external` state:
    [on_x_lstm_1_out, self.on_lstm_1_init_state, self.on_lstm_1_state_out, self.on_lstm_1_state_pl_flatten] = \
        lstm_network(on_aac_x, self.on_time_length, lstm_class, (lstm_layers[0],), name='lstm_1')

    self.debug['on_x_lstm_1_out'] = tf.shape(on_x_lstm_1_out)
    self.debug['self.on_lstm_1_state_out'] = tf.shape(self.on_lstm_1_state_out)
    self.debug['self.on_lstm_1_state_pl_flatten'] = tf.shape(self.on_lstm_1_state_pl_flatten)

    # For time_flat only: reshape on_lstm_1_state_out from [1, 2, 20, size] --> [20, 1, 2, size] --> [20, 1, 2 x size]:
    reshape_lstm_1_state_out = tf.transpose(self.on_lstm_1_state_out, [2, 0, 1, 3])
    reshape_lstm_1_state_out_shape_static = reshape_lstm_1_state_out.get_shape().as_list()
    reshape_lstm_1_state_out = tf.reshape(
        reshape_lstm_1_state_out,
        [self.on_batch_size, max_seq_len, np.prod(reshape_lstm_1_state_out_shape_static[-2:])],
    )

    # Take policy logits off first LSTM-dense layer:
    # Reshape back to [batch, flattened_depth], where batch = rnn_batch_dim * rnn_time_dim:
    x_shape_static = on_x_lstm_1_out.get_shape().as_list()
    rsh_on_x_lstm_1_out = tf.reshape(on_x_lstm_1_out, [x_shape_dynamic[0], x_shape_static[-1]])

    self.debug['reshaped_on_x_lstm_1_out'] = tf.shape(rsh_on_x_lstm_1_out)

    # Aac policy output and action-sampling function:
    [self.on_logits, _, self.on_sample] = dense_aac_network(rsh_on_x_lstm_1_out, ac_space, name='aac_dense_pi')

    # Second LSTM layer takes concatenated encoded `external` state, LSTM_1 output,
    # last_action and `internal_state` (if present) tensors:
    on_stage2_2_input += [on_x_lstm_1_out]

    # Try: feed context instead of output:
    # on_stage2_2_input = [reshape_lstm_1_state_out] + on_stage2_1_input

    # LSTM_2 full input:
    on_aac_x = tf.concat(on_stage2_2_input, axis=-1)

    self.debug['on_stage2_2_input'] = tf.shape(on_aac_x)

    [on_x_lstm_2_out, self.on_lstm_2_init_state, self.on_lstm_2_state_out, self.on_lstm_2_state_pl_flatten] = \
        lstm_network(on_aac_x, self.on_time_length, lstm_class, (lstm_layers[-1],), name='lstm_2')

    self.debug['on_x_lstm_2_out'] = tf.shape(on_x_lstm_2_out)
    self.debug['self.on_lstm_2_state_out'] = tf.shape(self.on_lstm_2_state_out)
    self.debug['self.on_lstm_2_state_pl_flatten'] = tf.shape(self.on_lstm_2_state_pl_flatten)

    # Reshape back to [batch, flattened_depth], where batch = rnn_batch_dim * rnn_time_dim:
    x_shape_static = on_x_lstm_2_out.get_shape().as_list()
    on_x_lstm_out = tf.reshape(on_x_lstm_2_out, [x_shape_dynamic[0], x_shape_static[-1]])

    self.debug['reshaped_on_x_lstm_out'] = tf.shape(on_x_lstm_out)

    # Aac value function:
    [_, self.on_vf, _] = dense_aac_network(on_x_lstm_out, ac_space, name='aac_dense_vfn')

    # Concatenate LSTM placeholders, init. states and context:
    self.on_lstm_init_state = (self.on_lstm_1_init_state, self.on_lstm_2_init_state)
    self.on_lstm_state_out = (self.on_lstm_1_state_out, self.on_lstm_2_state_out)
    self.on_lstm_state_pl_flatten = self.on_lstm_1_state_pl_flatten + self.on_lstm_2_state_pl_flatten

    # Off-policy AAC network (shared):
    off_aac_x = conv_2d_network(
        self.off_state_in['external'],
        ob_space['external'],
        ac_space,
        name='conv1d_external',
        reuse=True,
        **kwargs
    )

    # Reshape rnn inputs for batch training as [rnn_batch_dim, rnn_time_dim, flattened_depth]:
    x_shape_dynamic = tf.shape(off_aac_x)
    max_seq_len = tf.cast(x_shape_dynamic[0] / self.off_batch_size, tf.int32)
    x_shape_static = off_aac_x.get_shape().as_list()

    off_a_r_in = tf.reshape(self.off_a_r_in, [self.off_batch_size, max_seq_len, ac_space + 1])
    off_aac_x = tf.reshape(off_aac_x, [self.off_batch_size, max_seq_len, np.prod(x_shape_static[1:])])

    # Prepare `internal` state, if any:
    if 'internal' in list(self.off_state_in.keys()):
        if self.encode_internal_state:
            # Use convolution encoder:
            off_x_internal = conv_2d_network(
                self.off_state_in['internal'],
                ob_space['internal'],
                ac_space,
                name='conv1d_internal',
                conv_2d_num_filters=32,
                conv_2d_num_layers=2,
                conv_2d_filter_size=[3, 1],
                conv_2d_stride=[2, 1],
                reuse=True,
            )
            x_int_shape_static = off_x_internal.get_shape().as_list()
            off_x_internal = [
                tf.reshape(off_x_internal, [self.off_batch_size, max_seq_len, np.prod(x_int_shape_static[1:])])
            ]
        else:
            x_int_shape_static = self.off_state_in['internal'].get_shape().as_list()
            off_x_internal = tf.reshape(
                self.off_state_in['internal'],
                [self.off_batch_size, max_seq_len, np.prod(x_int_shape_static[1:])]
            )
            off_x_internal = [off_x_internal]

    else:
        off_x_internal = []

    off_stage2_1_input = [off_aac_x, off_a_r_in[..., -1][..., None]]  # + off_x_internal

    off_stage2_2_input = [off_aac_x, off_a_r_in] + off_x_internal

    off_aac_x = tf.concat(off_stage2_1_input, axis=-1)

    [off_x_lstm_1_out, _, _, self.off_lstm_1_state_pl_flatten] = \
        lstm_network(off_aac_x, self.off_time_length, lstm_class, (lstm_layers[0],), name='lstm_1', reuse=True)

    # Reshape back to [batch, flattened_depth], where batch = rnn_batch_dim * rnn_time_dim:
    x_shape_static = off_x_lstm_1_out.get_shape().as_list()
    rsh_off_x_lstm_1_out = tf.reshape(off_x_lstm_1_out, [x_shape_dynamic[0], x_shape_static[-1]])

    [self.off_logits, _, _] = \
        dense_aac_network(rsh_off_x_lstm_1_out, ac_space, name='aac_dense_pi', reuse=True)

    off_stage2_2_input += [off_x_lstm_1_out]

    # LSTM_2 full input:
    off_aac_x = tf.concat(off_stage2_2_input, axis=-1)

    [off_x_lstm_2_out, _, _, self.off_lstm_2_state_pl_flatten] = \
        lstm_network(off_aac_x, self.off_time_length, lstm_class, (lstm_layers[-1],), name='lstm_2', reuse=True)

    # Reshape back to [batch, flattened_depth], where batch = rnn_batch_dim * rnn_time_dim:
    x_shape_static = off_x_lstm_2_out.get_shape().as_list()
    off_x_lstm_out = tf.reshape(off_x_lstm_2_out, [x_shape_dynamic[0], x_shape_static[-1]])

    # Aac value function:
    [_, self.off_vf, _] = dense_aac_network(off_x_lstm_out, ac_space, name='aac_dense_vfn', reuse=True)

    # Concatenate LSTM states:
    self.off_lstm_state_pl_flatten = self.off_lstm_1_state_pl_flatten + self.off_lstm_2_state_pl_flatten

    # Aux1: `Pixel control` network.
    #
    # Define pixels-change estimation function:
    # Yes, it is rather env-specific, but for the Atari case it is handy to do it here,
    # see self.get_pc_target():
    [self.pc_change_state_in, self.pc_change_last_state_in, self.pc_target] = \
        pixel_change_2d_estimator(ob_space['external'], **kwargs)

    self.pc_batch_size = self.off_batch_size
    self.pc_time_length = self.off_time_length

    self.pc_state_in = self.off_state_in
    self.pc_a_r_in = self.off_a_r_in
    self.pc_lstm_state_pl_flatten = self.off_lstm_state_pl_flatten

    # Shared conv and lstm nets, same off-policy batch:
    pc_x = off_x_lstm_out

    # PC duelling Q-network, outputs [None, 20, 20, ac_size] Q-features tensor:
    self.pc_q = duelling_pc_network(pc_x, self.ac_space, **kwargs)

    # Aux2: `Value function replay` network.
    #
    # VR network is fully shared with the off-policy network but with `value`-only output,
    # and uses the same off-policy batch pass:
    self.vr_batch_size = self.off_batch_size
    self.vr_time_length = self.off_time_length

    self.vr_state_in = self.off_state_in
    self.vr_a_r_in = self.off_a_r_in

    self.vr_lstm_state_pl_flatten = self.off_lstm_state_pl_flatten
    self.vr_value = self.off_vf

    # Aux3: `Reward prediction` network.
    self.rp_batch_size = tf.placeholder(tf.int32, name='rp_batch_size')

    # Shared conv. output:
    rp_x = conv_2d_network(
        self.rp_state_in['external'],
        ob_space['external'],
        ac_space,
        name='conv1d_external',
        reuse=True,
        **kwargs
    )

    # Flatten batch-wise:
    rp_x_shape_static = rp_x.get_shape().as_list()
    rp_x = tf.reshape(rp_x, [self.rp_batch_size, np.prod(rp_x_shape_static[1:]) * (self.rp_sequence_size - 1)])

    # RP output:
    self.rp_logits = dense_rp_network(rp_x)

    # Batch-norm related (unused here, kept for interface compatibility):
    try:
        if self.train_phase is not None:
            pass

    except AttributeError:
        self.train_phase = tf.placeholder_with_default(
            tf.constant(False, dtype=tf.bool),
            shape=(),
            name='train_phase_flag_pl'
        )
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # Add moving averages to save list:
    moving_var_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES,
        tf.get_variable_scope().name + '.*moving.*'
    )
    renorm_var_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES,
        tf.get_variable_scope().name + '.*renorm.*'
    )

    # What to save:
    self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
    self.var_list += moving_var_list + renorm_var_list

    # Callbacks:
    if self.aux_estimate:
        self.callback['pixel_change'] = self.get_pc_target
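
# A sketch of assembling an on-policy feed for this two-LSTM policy, assuming TF 1.x.
# `feed_dict_from_nested` is an assumed helper that zips the nested observation
# placeholders with a matching numpy structure; `rollout_states`, `rollout_a_r` and
# `context` are illustrative rollout-runner outputs:
def make_on_policy_feed(policy, rollout_states, rollout_a_r, context, time_length):
    feed = feed_dict_from_nested(policy.on_state_in, rollout_states)
    # Flattened placeholders cover both LSTM_1 and LSTM_2 contexts:
    feed.update(
        {pl: value for pl, value in zip(policy.on_lstm_state_pl_flatten, flatten_nested(context))}
    )
    feed.update(
        {
            policy.on_a_r_in: rollout_a_r,
            policy.on_batch_size: 1,
            policy.on_time_length: time_length,
            policy.train_phase: True,
        }
    )
    return feed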
def beta_var_conv2d_autoencoder(
        inputs,
        layer_config,
        resize_method=tf.image.ResizeMethod.BILINEAR,
        pad='SAME',
        linear_layer_ref=linear,
        name='vae_conv2d',
        max_batch_size=256,
        reuse=False):
    """
    Variational autoencoder.

    Papers:
        https://arxiv.org/pdf/1312.6114.pdf
        https://arxiv.org/pdf/1606.05908.pdf
        http://www.matthey.me/pdf/betavae_iclr_2017.pdf

    Args:
        inputs:             input tensor
        layer_config:       layers configuration list: [layer_1_config, layer_2_config,...], where:
                            layer_i_config = [num_filters(int), filter_size(list), stride(list)];
                            this list defines the encoder part of the autoencoder bottleneck;
                            the decoder part is inferred symmetrically
        resize_method:      up-sampling method, one of supported tf.image.ResizeMethod's
        pad:                str, padding scheme: 'SAME' or 'VALID'
        linear_layer_ref:   linear layer class - not used
        name:               str, name scope
        max_batch_size:     int, dynamic batch size should be no greater than this value
        reuse:              bool

    Returns:
        list of tensors holding encoded features, layer-wise from outer to inner
        tensor holding batch-wise flattened hidden state vector
        list of tensors holding decoded features, layer-wise from inner to outer
        tensor holding reconstructed output
        tensor holding estimated KL divergence
    """
    with tf.variable_scope(name, reuse=reuse):
        # Encode:
        encoder_layers, shapes = conv2d_encoder(
            x=inputs,
            layer_config=layer_config,
            pad=pad,
            reuse=reuse
        )
        # Flatten hidden state, pass through dense:
        z_flat = batch_flatten(encoder_layers[-1])
        h, w, c = encoder_layers[-1].get_shape().as_list()[1:]

        z = tf.nn.elu(
            linear(
                x=z_flat,
                size=h * w * c,
                name='enc_dense',
                initializer=normalized_columns_initializer(1.0),
                reuse=reuse
            )
        )
        # TODO: revert back to doubled Z-size (size_z = h * w * c)
        size_z = int(h * w * c / 2)

        z = tf.nn.elu(
            linear(
                x=z,
                size=size_z * 2,
                name='hidden_dense',
                initializer=normalized_columns_initializer(1.0),
                reuse=reuse
            )
        )
        # Get sample parameters:
        mu, log_sigma = tf.split(z, [size_z, size_z], axis=-1)

        # Oversized noise generator, sliced down to the dynamic batch size:
        eps = tf.random_normal(shape=[max_batch_size, size_z], mean=0., stddev=1.)
        eps = eps[:tf.shape(z)[0], :]

        # Get sample z ~ Q(z|X):
        z_sampled = mu + tf.exp(log_sigma / 2) * eps

        # Element-wise D_KL(Q(z|X) || P(z)):
        # TODO: should be summed over latent dimensions
        d_kl = 0.5 * (tf.exp(log_sigma) + tf.square(mu) - 1. - log_sigma)

        # Reshape back and feed to decoder:
        z_sampled_dec = tf.nn.elu(
            linear(
                x=z_sampled,
                size=h * w * c,
                name='dec_dense',
                initializer=normalized_columns_initializer(1.0),
                reuse=reuse
            )
        )
        decoder_layers = conv2d_decoder(
            z=tf.reshape(z_sampled_dec, [-1, h, w, c]),
            layer_config=layer_config,
            layer_shapes=shapes,
            pad=pad,
            resize_method=resize_method,
            reuse=reuse
        )
        y_hat = decoder_layers[-1]

        return encoder_layers, z_sampled, decoder_layers, y_hat, d_kl
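
# A sketch of wiring the returned KL term into a beta-VAE objective, assuming TF 1.x;
# placeholder names, shapes and the layer_config values are illustrative. The KL term
# is summed over latent dimensions and averaged over the batch, which addresses the
# TODO above; beta > 1 follows the betaVAE paper linked in the docstring:
vae_input_pl = tf.placeholder(tf.float32, [None, 84, 84, 3], name='vae_input_pl')
vae_layer_config = [[32, [3, 3], [2, 2]], [64, [3, 3], [2, 2]]]

_, vae_z_sampled, _, vae_y_hat, vae_d_kl = beta_var_conv2d_autoencoder(vae_input_pl, vae_layer_config)

beta = 4.0  # beta > 1 encourages disentangled latents
reconstruction_loss = tf.reduce_mean(tf.reduce_sum(tf.square(vae_input_pl - vae_y_hat), axis=[1, 2, 3]))
kl_loss = tf.reduce_mean(tf.reduce_sum(vae_d_kl, axis=-1))
vae_loss = reconstruction_loss + beta * kl_loss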