def __init__(self, session, action_dist, obs_vectorizer, make_cell):
    """Build the recurrent actor-critic graph around an RNN cell.

    Args:
      session: forwarded unchanged to the parent constructor.
      action_dist: forwarded unchanged to the parent constructor.
      obs_vectorizer: forwarded to the parent constructor; its `out_shape`
        provides the trailing dimensions of the observation placeholder.
      make_cell: zero-argument callable returning the RNN cell. It is
        invoked inside the 'cell' variable scope.
    """
    super(RNNCellAC, self).__init__(session, action_dist, obs_vectorizer)

    # Two leading unknown dimensions followed by the vectorized obs shape.
    self.obs_ph = tf.placeholder(
        tf.float32, (None, None) + obs_vectorizer.out_shape)
    self.mask_ph = tf.placeholder(tf.float32, (None, None))

    with tf.variable_scope('cell_input'):
        rnn_inputs = self.cell_input_sequences()

    with tf.variable_scope('cell'):
        cell = make_cell()

    with tf.variable_scope('states'):
        # State fields are always declared with a flat (non-nested) layout.
        state_spec = cell.state_size
        if isinstance(state_spec, tuple):
            state_spec = tuple(nest.flatten(state_spec))
        self.create_state_fields(tf.float32, state_spec)

    start_state = mix_init_states(self.is_init_state_ph,
                                  self.init_state_vars,
                                  self.first_state_ph)

    with tf.variable_scope('base'):
        if isinstance(start_state, tuple):
            # Re-nest the flat tuple into the structure the cell expects.
            start_state = nest.pack_sequence_as(cell.state_size, start_state)
        self.base_out, final_states = tf.nn.dynamic_rnn(
            cell,
            rnn_inputs,
            sequence_length=self.seq_lens_ph,
            initial_state=start_state)

    # Expose the final states in the same flat layout as the state fields.
    self.states_out = (tuple(nest.flatten(final_states))
                       if isinstance(final_states, tuple)
                       else final_states)

    with tf.variable_scope('actor'):
        self.actor_out = self.actor(self.base_out)

    with tf.variable_scope('critic'):
        self.critic_out = self.critic(self.base_out)
def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined):
    """Convert `memory` to tensors and, when lengths are given, mask it.

    Args:
      memory: `Tensor` (or nested structure of them), shaped
        `[batch_size, max_time, ...]`.
      memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`, or
        `None` to skip masking.
      check_inner_dims_defined: Python boolean. If `True`, every memory
        tensor must have all dimensions after the first two fully defined.

    Returns:
      A structure like `memory` holding the converted tensors, each
      multiplied by the sequence mask when `memory_sequence_length` is given.

    Raises:
      ValueError: If `check_inner_dims_defined` is `True` and a memory tensor
        has `shape[2:]` that is not fully defined.
    """
    memory = nest.map_structure(
        lambda t: tf.convert_to_tensor(t, name="memory"), memory)
    has_lengths = memory_sequence_length is not None
    if has_lengths:
        memory_sequence_length = tf.convert_to_tensor(
            memory_sequence_length, name="memory_sequence_length")

    if check_inner_dims_defined:
        def _require_static_inner_dims(t):
            if t.get_shape()[2:].is_fully_defined():
                return
            raise ValueError(
                "Expected memory %s to have fully defined inner dims, "
                "but saw shape: %s" % (t.name, t.get_shape()))

        nest.map_structure(_require_static_inner_dims, memory)

    seq_len_mask = None
    if has_lengths:
        # The first flattened memory tensor supplies max_time and the dtype.
        first_memory = nest.flatten(memory)[0]
        seq_len_mask = tf.sequence_mask(
            memory_sequence_length,
            maxlen=tf.shape(first_memory)[1],
            dtype=first_memory.dtype)
        # Prefer the static batch size; fall back to the dynamic one.
        seq_len_batch_size = (
            tf.dimension_value(memory_sequence_length.shape[0])
            or tf.shape(memory_sequence_length)[0])

    def _mask_one(t, mask):
        static_rank = t.get_shape().ndims
        rank = tf.rank(t) if static_rank is None else static_rank
        trailing_ones = tf.ones(rank - 2, dtype=tf.int32)
        t_batch_size = tf.dimension_value(t.shape[0]) or tf.shape(t)[0]
        if not has_lengths:
            return t
        message = ("memory_sequence_length and memory tensor batch sizes do not "
                   "match.")
        batch_check = tf.assert_equal(
            seq_len_batch_size, t_batch_size, message=message)
        with tf.control_dependencies([batch_check]):
            # Pad the mask shape with 1s so it broadcasts over inner dims.
            mask = tf.reshape(
                mask, tf.concat((tf.shape(mask), trailing_ones), 0))
            return t * mask

    return nest.map_structure(lambda t: _mask_one(t, seq_len_mask), memory)
def step(self, step_type, reward, observation):
    """Run one step through the session, syncing first when it is due.

    Args:
      step_type: value handed to `step_preprocess`; the processed result is
        fed to the step-type placeholder.
      reward: value handed to `step_preprocess`; the processed result is fed
        to the reward placeholder.
      observation: nested structure matching `self._obs_ph`; its leaves are
        fed to the observation placeholders.

    Returns:
      The value of `self._step_output` produced by `self.sess.run`.
    """
    if self._sync_checker.should_sync(self._step_number):
        self.sync()

    # Pass the batch through pre-processing.
    # NOTE(review): the pre-processed `obs` is never used below; the raw
    # `observation` is what gets fed -- confirm this is intended.
    step_type, reward, obs, next_state = self._agent.step_preprocess(
        step_type, reward, observation, self.next_state)

    nest.assert_same_structure(self._obs_ph, observation)
    feed = {
        self._step_type_ph: step_type,
        self._reward_ph: reward,
        self._next_state_ph: next_state,
    }
    feed.update(zip(nest.flatten(self._obs_ph), nest.flatten(observation)))

    step_output = self.sess.run(self._step_output, feed_dict=feed)

    if self.verbose and self._step_number % 100 == 0:
        print(step_output)

    # Carry the recurrent state over to the next call.
    self._next_state = step_output.next_state
    self._step_number += 1
    return step_output
def forward_pass(self):
    """Build the time-axis RNN, the note-axis RNN and the output layer.

    Side effects: sets `self.time_out` and `self.note_input`, and adds the
    final RNN state tensors to the 'time_state_output' and
    'note_state_output' collections.

    Returns:
      A tuple `(time_state, note_state, prediction)` where `prediction` is
      the sigmoid of the dense-layer output reshaped to
      `[batch_size, -1, note_len, 2]`.
    """
    time_width = self.time_sizes[1]

    # Swap axes 1 and 2, then merge the leading two so every note is an
    # independent sequence for the time model.
    per_note = tf.transpose(self.inputs, perm=[0, 2, 1, 3])
    per_note = tf.reshape(per_note, [self.batch_size * NOTE_LEN, -1, 80])

    # time model
    with tf.variable_scope('time_model'):
        time_out, time_state = tf.nn.dynamic_rnn(
            cell=self.time_lstm_cell,
            inputs=per_note,
            initial_state=self.time_state,
            dtype=tf.float32)
    for state_tensor in nest.flatten(time_state):
        tf.add_to_collection('time_state_output', state_tensor)

    # reshape from note invariant to time invariant
    per_step = tf.reshape(
        time_out, [self.batch_size, NOTE_LEN, -1, time_width])
    per_step = tf.transpose(per_step, perm=[0, 2, 1, 3])
    per_step = tf.reshape(per_step, [-1, NOTE_LEN, time_width])
    self.time_out = per_step

    # Shift the labels by one note: a zero "start" label is prepended and the
    # last note's label is dropped.
    start_label = tf.zeros([self.batch_size * SEQ_LEN, 1, 2])
    shifted_labels, _ = tf.split(self.labels, [NOTE_LEN - 1, 1], 2)
    shifted_labels = tf.reshape(
        shifted_labels, [self.batch_size * SEQ_LEN, NOTE_LEN - 1, 2])
    shifted_labels = tf.concat([start_label, shifted_labels], 1)
    per_step = tf.concat([per_step, shifted_labels], 2)

    self.note_input = tf.placeholder_with_default(
        per_step, shape=[None, None, time_width + 2])
    note_len = tf.shape(self.note_input)[1]

    # note model
    with tf.variable_scope('note_model'):
        note_out, note_state = tf.nn.dynamic_rnn(
            cell=self.note_lstm_cell,
            inputs=self.note_input,
            initial_state=self.note_state,
            dtype=tf.float32)
    for state_tensor in nest.flatten(note_state):
        tf.add_to_collection('note_state_output', state_tensor)

    # dense layer
    weights = tf.Variable(
        tf.random_normal([self.note_sizes[-1], 2],
                         stddev=0.01,
                         dtype=tf.float32))
    bias = tf.Variable(tf.random_normal([2], stddev=0.01, dtype=tf.float32))
    projected = tf.tensordot(note_out, weights, axes=[[2], [0]]) + bias
    projected = tf.reshape(projected, [self.batch_size, -1, note_len, 2])

    return time_state, note_state, tf.nn.sigmoid(projected)
def _feed_obs(obs, action):
    """Prepare the feed_dict for one observation and an optional action.

    Every fed value is wrapped in a one-element list.

    NOTE(review): the body reads `self` although it is not a parameter;
    presumably it is captured from an enclosing scope -- confirm.

    Args:
      obs: nested observation, flattened in lockstep with `inputs_ph.X`.
      action: nested action flattened in lockstep with `inputs_ph.A`, or
        `None` to feed no action.

    Returns:
      A dict mapping placeholders to the values to feed.
    """
    feed_dict = {}
    for ob_ph, ob_np in zip(nest.flatten(self.inputs_ph.X),
                            nest.flatten(obs)):
        feed_dict[ob_ph] = [ob_np]

    if self._state is not None:
        feed_dict[self.inputs_ph.S] = [self._state]
        # always one-step, non-terminal
        feed_dict[self.inputs_ph.M] = [np.zeros(shape=())]

    if action is None:
        return feed_dict

    assert self.inputs_ph.A is not None
    feed_dict.update(
        (ac_ph, [ac_np])
        for ac_ph, ac_np in zip(nest.flatten(self.inputs_ph.A),
                                nest.flatten(action)))
    return feed_dict
def rnn_model(inputs, shape, embedding_matrix):
    """Build a stacked-GRU model whose state persists between evaluations.

    Not optimised for GPU.

    Args:
      inputs: indices looked up in `embedding_matrix`.
      shape: iterable of GRU layer widths, one per stacked cell.
      embedding_matrix: embedding table; its first dimension is used as the
        vocabulary size of the output layer.

    Returns:
      Output tensor reshaped to `[batch (or -1), time, vocab_size]`.
    """
    with tf.variable_scope('rnn'):
        embedded = tf.nn.embedding_lookup(embedding_matrix, inputs)
        static_shape = embedded.get_shape().as_list()
        vocab_size = embedding_matrix.get_shape()[0].value

        layers = [tf.nn.rnn_cell.GRUCell(width) for width in shape]
        stacked = tf.nn.rnn_cell.MultiRNNCell(layers)

        # One zero-initialised variable per layer holds the carried state.
        # won't work with LSTMs
        state_vars = []
        for index, layer in enumerate(layers):
            state_vars.append(
                tf.get_variable('state_{}'.format(index),
                                shape=[static_shape[0], layer.state_size],
                                dtype=tf.float32,
                                initializer=tf.zeros_initializer()))
        initial_state = tuple(state_vars)

        outputs, final_state = tf.nn.dynamic_rnn(
            stacked, embedded, initial_state=initial_state)

        # we're always going to roll this over every time the output is
        # evaluated
        state_updates = nest.flatten(
            nest.map_structure(tf.assign, initial_state, final_state))
        with tf.control_dependencies(state_updates):
            outputs = tf.reshape(outputs, [-1, shape[-1]])

        outputs = tf.layers.dense(outputs, vocab_size, activation=None)
        outputs = tf.reshape(
            outputs, [static_shape[0] or -1, static_shape[1], vocab_size])
    return outputs
def _apply_rnn_encoder_output_layer(output_layer, time_major, hparams, mode,
                                    cell_outputs, cell_output_size):
    """Apply the output layer(s) to every tensor of the cell outputs.

    Args:
      output_layer: forwarded to `_forward_output_layers`.
      time_major: forwarded to `_forward_output_layers`.
      hparams: forwarded to `_forward_output_layers`.
      mode: forwarded to `_forward_output_layers`.
      cell_outputs: (possibly nested) structure of output tensors.
      cell_output_size: structure of sizes that flattens in lockstep with
        `cell_outputs`.

    Returns:
      `(outputs, output_size)`, both packed like `cell_outputs`.
    """
    def _project(tensor, size):
        return _forward_output_layers(
            inputs=tensor,
            input_size=size,
            output_layer=output_layer,
            time_major=time_major,
            hparams=hparams,
            mode=mode)

    flat_outputs = nest.flatten(cell_outputs)
    flat_sizes = nest.flatten(cell_output_size)
    projected = [_project(t, s) for t, s in zip(flat_outputs, flat_sizes)]

    # Transpose the list of (output, size) pairs into two parallel tuples.
    outputs_flat, output_size_flat = zip(*projected)
    return (nest.pack_sequence_as(cell_outputs, outputs_flat),
            nest.pack_sequence_as(cell_outputs, output_size_flat))
def _forward(self, obs, fetches, action=None):
    """Run `fetches` for one observation and advance the stored state.

    Args:
      obs: observation handed to `self._feed_obs`.
      fetches: dict of fetches; a 'state' entry is added in place when a
        state is being tracked.
      action: optional nested action, flattened in lockstep with
        `inputs_ph.A`; each leaf is fed wrapped in a one-element list.

    Returns:
      The result of `self.sess.run(fetches, ...)`.
    """
    feed_dict = self._feed_obs(obs)

    if self._state is not None:
        fetches['state'] = self.net_out.S

    if action is not None:
        feed_dict.update(
            (ac_ph, [ac_np])
            for ac_ph, ac_np in zip(nest.flatten(self.inputs_ph.A),
                                    nest.flatten(action)))

    ret = self.sess.run(fetches, feed_dict=feed_dict)
    # NOTE: do not squeeze the batch_size dim here!

    self._last_state = self._state
    if self._state is None:
        self._state = None
    else:
        self._state = _squeeze_batch_size_singleton_dim(ret['state'])
    return ret
def _feed_obs(self, obs):
    """Prepare the feed_dict for a single observation.

    Args:
      obs: nested observation, flattened in lockstep with `inputs_ph.X`.

    Returns:
      A dict mapping placeholders to the values to feed.
    """
    # NOTE1: nest.flatten is safe even for an OrderedDict, as long as .X and
    # obs carry exactly the same keys and hence the same flattening order.
    # NOTE2: wrapping each value as [ob_np] is a quick hack for native python
    # types (int32, float, etc.); in effect it inserts a leading
    # batch_size = 1 dim.
    feed_dict = {}
    for ob_ph, ob_np in zip(nest.flatten(self.inputs_ph.X),
                            nest.flatten(obs)):
        feed_dict[ob_ph] = [ob_np]

    if self._state is None:
        return feed_dict

    feed_dict[self.inputs_ph.S] = [self._state]
    # always one-step, non-terminal
    feed_dict[self.inputs_ph.M] = [np.zeros(shape=())]
    return feed_dict
def __init__(self, inputs, labels, keep_prob, time_sizes=None,
             note_sizes=None):
    """Build the time/note LSTM stacks, their state inputs and the heads.

    Args:
      inputs: input tensor, shape (batch, time, note, feature).
      labels: label tensor, shape (batch, time, note, out).
      keep_prob: output keep probability for the dropout wrappers.
      time_sizes: LSTM layer widths of the time model. Defaults to
        `[300, 300]`.
      note_sizes: LSTM layer widths of the note model. Defaults to
        `[100, 50]`.
    """
    # Fresh lists per instance: the sizes are stored on `self`, so a shared
    # mutable default would leak edits from one model into the next.
    if time_sizes is None:
        time_sizes = [300, 300]
    if note_sizes is None:
        note_sizes = [100, 50]

    self.inputs = inputs  # input shape (batch, time, note, feature)
    self.labels = labels  # label shape (batch, time, note, out)
    self.batch_size = tf.shape(self.inputs)[0]
    self.keep_prob = keep_prob
    self.time_sizes = time_sizes
    self.note_sizes = note_sizes

    self.time_lstm_cell = tf.contrib.rnn.MultiRNNCell([
        tf.contrib.rnn.DropoutWrapper(
            tf.contrib.rnn.LSTMCell(sz, state_is_tuple=True),
            output_keep_prob=self.keep_prob)
        for sz in self.time_sizes
    ], state_is_tuple=True)
    # Zero state wrapped in placeholders-with-default so a caller can feed a
    # previous state instead.
    self.time_state = nest.map_structure(
        lambda x: tf.placeholder_with_default(x, x.shape, x.op.name),
        self.time_lstm_cell.zero_state(self.batch_size * NOTE_LEN,
                                       tf.float32))
    for tensor in nest.flatten(self.time_state):
        tf.add_to_collection('time_state_input', tensor)

    self.note_lstm_cell = tf.contrib.rnn.MultiRNNCell([
        tf.contrib.rnn.DropoutWrapper(
            tf.contrib.rnn.LSTMCell(sz, state_is_tuple=True),
            output_keep_prob=self.keep_prob)
        for sz in self.note_sizes
    ], state_is_tuple=True)
    self.note_state = nest.map_structure(
        lambda x: tf.placeholder_with_default(x, x.shape, x.op.name),
        self.note_lstm_cell.zero_state(self.batch_size * SEQ_LEN,
                                       tf.float32))
    for tensor in nest.flatten(self.note_state):
        tf.add_to_collection('note_state_input', tensor)

    self.final_time_state, self.final_note_state, self.prediction \
        = self.forward_pass()
    self.loss = self.loss_function()
    self.optimize = self.optimizer()
def split_reshape(x, i, n):
    """Split axis `i` of `x` into two axes of sizes `shape[i] // n` and `n`.

    Args:
      x: tensor to reshape.
      i: index of the axis to split.
      n: size of the new inner axis; must divide `shape[i]`.

    Returns:
      `x` reshaped so that axis `i` is replaced by the two axes
      `(shape[i] // n, n)`.

    Raises:
      AssertionError: if `shape[i]` is a Python/NumPy integer that is not
        divisible by `n`.
    """
    with tf.name_scope('split_reshape', [x, i, n]):
        s = tf_shape(x)
        # Divisibility is only checked when the dim is a Python/NumPy integer.
        if np.issubdtype(type(s[i]), np.integer):
            assert (s[i] % n) == 0, 'x.shape[i] must be divisible by n! {}/{}'.format(
                s[i], n)
        # Floor division keeps the new dimension integral; true division
        # would yield a float, which is not a valid reshape dimension.
        s = nest.flatten([s[:i], (s[i] // n), n, s[i + 1:]])
        x = tf.reshape(x, s)
        return x
def step(self, step_type, reward, observation):
    """Run one step through the session and advance the stored state.

    Args:
      step_type: value handed to `step_preprocess`; the processed result is
        fed to the step-type placeholder.
      reward: value handed to `step_preprocess`; the processed result is fed
        to the reward placeholder.
      observation: nested structure matching `self._obs_ph`; its leaves are
        fed to the observation placeholders.

    Returns:
      The value of `self._step_output` produced by `self.sess.run`.
    """
    # Pass the batch through pre-processing.
    # NOTE(review): the pre-processed `obs` is never used below; the raw
    # `observation` is what gets fed -- confirm this is intended.
    step_type, reward, obs, next_state = self._agent.step_preprocess(
        step_type, reward, observation, self.next_state)

    nest.assert_same_structure(self._obs_ph, observation)
    feed = {
        self._step_type_ph: step_type,
        self._reward_ph: reward,
        self._next_state_ph: next_state,
    }
    feed.update(zip(nest.flatten(self._obs_ph), nest.flatten(observation)))

    step_output = self.sess.run(self._step_output, feed_dict=feed)

    # Carry the recurrent state over to the next call.
    self._next_state = step_output.next_state
    self._step_number += 1
    return step_output
def build_loss(self, model, input_data):
    """Combine the model's loss terms into one scalar training loss.

    The total is `pg_loss + value_loss - entropy + distill_loss + ep_loss`,
    with each term scaled by its coefficient. A coefficient given as a list
    is applied element-wise; otherwise it scales the summed term.

    Args:
      model: object whose `loss` attribute carries the individual losses.
      input_data: unused by this method.

    Returns:
      `(loss, value_loss, model.loss.loss_endpoints.values())`.
    """
    # Entropy bonus.
    entropy_terms = nest.flatten(model.loss.entropy_loss)
    if not isinstance(self.ent_coef, list):
        entropy = tf.reduce_sum(entropy_terms) * self.ent_coef
    else:
        assert len(entropy_terms) == len(
            self.ent_coef), 'Lengths of ent and ent_coef mismatch.'
        print('ent_coef: {}'.format(self.ent_coef))
        entropy = tf.reduce_sum(
            [term * coef for term, coef in zip(entropy_terms, self.ent_coef)])

    # Distillation term; stays at zero unless distillation is enabled.
    distill_loss = tf.constant(0, dtype=tf.float32)
    if self.distillation:
        distill_terms = nest.flatten(model.loss.distill_loss)
        if not isinstance(self.distill_coef, list):
            distill_loss = tf.reduce_sum(distill_terms) * self.distill_coef
        else:
            assert len(distill_terms) == len(
                self.distill_coef
            ), 'Lengths of distill and distill_coef mismatch.'
            print('distill_coef: {}'.format(self.distill_coef))
            distill_loss = tf.reduce_sum([
                term * coef
                for term, coef in zip(distill_terms, self.distill_coef)
            ])

    # Value loss.
    raw_value_loss = model.loss.value_loss
    if not isinstance(self.vf_coef, list):
        value_loss = tf.reduce_sum(raw_value_loss) * self.vf_coef
    else:
        value_shape = raw_value_loss.shape
        assert len(value_shape) == 1 and value_shape[0] == len(
            self.vf_coef)
        print('vf_coef: {}'.format(self.vf_coef))
        value_loss = tf.reduce_sum(raw_value_loss *
                                   tf.constant(self.vf_coef))

    # Extra endpoint losses, each scaled by its configured coefficient.
    endpoints = model.loss.loss_endpoints
    ep_loss = tf.constant(0, dtype=tf.float32)
    for loss_name, loss_coef in self.ep_loss_coef.items():
        ep_loss += endpoints[loss_name] * loss_coef

    loss = (model.loss.pg_loss + value_loss - entropy + distill_loss +
            ep_loss)
    return loss, value_loss, model.loss.loss_endpoints.values()
def while_fn(*args):
    """Loop body: run `loop_fn` on one time slice and record its output.

    Args:
      *args: `(iteration, persistent_values, transient_values,
        tensor_arrays)`, read by position.

    Returns:
      `(iteration + 1, new_persistent, new_transient, new_tensor_arrays)`,
      where each tensor array has the new persistent value written at index
      `iteration`.
    """
    step = args[0]
    persistent = args[1]
    transient = args[2]
    arrays = args[3]

    # Time is the leading axis when time_major, otherwise the second axis.
    step_inputs = inputs[step] if time_major else inputs[:, step]

    new_persistent, new_transient = loop_fn(step_inputs, persistent,
                                            transient)

    # Write each new persistent value into its matching tensor array.
    persistent_leaves = nest.flatten(new_persistent)
    written = []
    for array, value in zip(nest.flatten(arrays), persistent_leaves):
        written.append(array.write(step, value))
    new_arrays = nest.pack_sequence_as(arrays, written)

    return step + 1, new_persistent, new_transient, new_arrays
def multi_head_xe_loss(inputs_action_logits, inputs_action_labels,
                       inputs_mask_weights):
    """Multi-head cross entropy loss.

    Args:
      inputs_action_logits: logits, organized in a nest.map_structure
        compatible structure of Tensors.
      inputs_action_labels: labels, in the same structure as the logits.
      inputs_mask_weights: per-head weights, in the same structure.

    Returns:
      A Tensor, total loss (the sum over all heads).
      A structure of Tensor, per-head loss. The same structure as inputs.
    """

    def _head_loss(logits, labels, weight):
        # Append trailing singleton dims until the weight can broadcast
        # against the labels.
        while labels.shape.rank > weight.shape.rank:
            weight = tf.expand_dims(weight, axis=-1)

        shared_kwargs = dict(
            logits=logits,
            weights=weight,
            reduction=tf.losses.Reduction.NONE,  # keep the batch_size dim
        )
        if logits.shape.rank == labels.shape.rank:
            # Equal ranks: treated as MultiBinary (multi label, each 0/1),
            # e.g., labels: (bs, 600), logits: (bs, 600)
            loss = tf.losses.sigmoid_cross_entropy(
                multi_class_labels=labels, **shared_kwargs)
        else:
            # Otherwise: treated as Discrete (mutually exclusive classes),
            # e.g., labels: (bs, d1,..., dM), logits: (bs, d1,..., dM, K)
            loss = tf.losses.sparse_softmax_cross_entropy(
                labels=labels, **shared_kwargs)

        # Sum away trailing dims until the loss has rank at most 1.
        while loss.shape.rank > 1:
            loss = tf.reduce_sum(loss, axis=-1)
        return loss

    per_head = nest.map_structure(_head_loss, inputs_action_logits,
                                  inputs_action_labels, inputs_mask_weights)
    total = tf.add_n(nest.flatten(per_head))
    return total, per_head
def _local_spatial_attention(query):
    """Compute local spatial attention weights for `query`.

    Args:
      query: a tensor or a (possibly nested) iterable of tensors; every
        tensor with a known, non-zero rank must be rank 2.

    Returns:
      Softmax-normalised attention weights.
    """
    # If the query is a tuple (when stacked RNN/LSTM), flatten it and
    # concatenate the pieces along axis 1.
    if hasattr(query, "__iter__"):
        pieces = nest.flatten(query)
        for piece in pieces:
            rank = piece.get_shape().ndims
            if rank:
                assert rank == 2
        query = tf.concat(pieces, 1)

    with tf.variable_scope('local_spatial_attn_Wl'):
        projection = tf.keras.layers.Dense(units=local_attn_size,
                                           use_bias=True)
        projected = tf.reshape(projection(query),
                               [-1, 1, 1, local_attn_size])
        # Ul_x is 4-D, so the score is reduced over the last two axes.
        score = tf.reduce_sum(vl * tf.nn.tanh(projected + Ul_x), [2, 3])
        attention_weights = tf.nn.softmax(score)
    return attention_weights
def _attention(query):
    """Compute the attention context vector for `query`.

    Args:
      query: a tensor or a (possibly nested) iterable of tensors; every
        tensor with a known, non-zero rank must be rank 2.

    Returns:
      The context vector reshaped to `[-1, attn_size]`.
    """
    if hasattr(query, "__iter__"):
        pieces = nest.flatten(query)
        for piece in pieces:
            # Check that ndims == 2 if specified.
            rank = piece.get_shape().ndims
            if rank:
                assert rank == 2
        query = tf.concat(pieces, 1)

    with tf.variable_scope('attention'):
        projection = tf.keras.layers.Dense(units=attn_size, use_bias=True)
        projected = tf.reshape(projection(query), [-1, 1, 1, attn_size])
        score = tf.reduce_sum(v_d * tf.nn.tanh(projected + W_h), [2, 3])
        attention_weights = tf.nn.softmax(score)

        # Weight h_o by the attention weights and sum over axes 1 and 2.
        weights_4d = tf.reshape(attention_weights, [-1, attn_length, 1, 1])
        context_vector = tf.reduce_sum(weights_4d * h_o, [1, 2])
        context_vector = tf.reshape(context_vector, [-1, attn_size])
    return context_vector
def __call__(self, inputs):
    """Build the embedding + stacked-RNN language model graph.

    Args:
      inputs: token-id tensor to embed; the graph assumes it carries
        `params.batch_size` rows of `params.num_steps` steps.

    Returns:
      `(logits, end_points)`: logits shaped
      `[batch_size, num_steps, vocab_size]` and a dict of intermediate
      tensors/objects keyed by name.
    """
    end_points = {}

    # Embedding layer
    with tf.device("/cpu:0"):
        inputs = tf.contrib.layers.embed_sequence(
            inputs,
            vocab_size=self.params.vocab_size,
            embed_dim=self.params.hidden_size,
            scope="embedding")
    end_points["inputs"] = inputs

    if self.params.keep_prob < 1 and self.training:
        # tf.layers.dropout takes the fraction of units to DROP, so the keep
        # probability has to be converted before it is passed as `rate`.
        inputs = tf.layers.dropout(inputs,
                                   rate=1.0 - self.params.keep_prob,
                                   training=self.training)

    # RNN graph
    cell = tf.contrib.rnn.MultiRNNCell(
        [self.make_cell() for _ in range(self.params.num_layers)],
        state_is_tuple=True)
    end_points["cell"] = cell

    initial_state = cell.zero_state(self.params.batch_size, tf.float32)
    end_points["initial_state"] = initial_state
    for tensor in nest.flatten(initial_state):
        tf.add_to_collection('rnn_input_state', tensor)

    # static_rnn expects a Python list with one tensor per time step.
    inputs = tf.unstack(inputs, num=self.params.num_steps, axis=1)
    outputs, state = tf.nn.static_rnn(cell,
                                      inputs,
                                      initial_state=initial_state)

    output = tf.reshape(tf.concat(outputs, 1),
                        [-1, self.params.hidden_size])
    end_points["output"] = output
    end_points["output_state"] = state

    logits = tf.layers.dense(output, self.params.vocab_size)
    logits = tf.reshape(logits, [
        self.params.batch_size, self.params.num_steps,
        self.params.vocab_size
    ])
    end_points["logits"] = logits
    return logits, end_points
def decoder(feed_previous_bool):
    """Build the attention decoder for one `feed_previous` setting.

    Args:
      feed_previous_bool: Python bool forwarded as `feed_previous`. When
        False, the current variable scope is re-entered with `reuse=True`.

    Returns:
      The decoder outputs concatenated with the flattened state.
    """
    if feed_previous_bool:
        reuse = None
    else:
        reuse = True

    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
        outputs, state = embedding_attention_decoder(
            decoder_inputs,
            encoder_state,
            attention_states,
            dec_cell,
            num_decoder_symbols,
            embedding_size,
            num_heads=num_heads,
            output_size=output_size,
            output_projection=output_projection,
            feed_previous=feed_previous_bool,
            update_embedding_for_previous=False,
            initial_state_attention=initial_state_attention)
        # A nested state is flattened; a single tensor becomes a 1-list.
        flat_state = (nest.flatten(state)
                      if nest.is_sequence(state) else [state])
        return outputs + flat_state
def sgdstore_model(in_seqs):
    """
    Apply an sgdstore model to the sequences.

    The initial state is created for a batch of one, stored in variables,
    and tiled along the first axis to BATCH_SIZE.
    """
    controller = tf.contrib.rnn.BasicRNNCell(64)
    layer = sgdstore.Stack(sgdstore.FC(4, 32), sgdstore.FC(32, 4))
    sgdcell = sgdstore.Cell(layer, train_batch=4, query_batch=4)
    cell = tf.contrib.rnn.MultiRNNCell([controller, sgdcell])

    single_state = (controller.zero_state(1, tf.float32),
                    sgdcell.random_state(1, tf.float32))
    init_vars = [tf.Variable(leaf) for leaf in nest.flatten(single_state)]

    def _tile_to_batch(var):
        # Repeat along the first axis only; every other axis keeps its size.
        trailing = [1] * (len(var.get_shape()) - 1)
        return tf.tile(var, multiples=[BATCH_SIZE] + trailing)

    repeated_vars = [_tile_to_batch(var) for var in init_vars]
    init_state = nest.pack_sequence_as(cell.state_size, repeated_vars)

    query_res, _ = tf.nn.dynamic_rnn(cell, in_seqs, initial_state=init_state)
    return tf.contrib.layers.fully_connected(query_res,
                                             1,
                                             activation_fn=None)
def __call__(self, inputs):
    """Build the embedding + stacked-RNN language model graph.

    Args:
      inputs: token-id tensor to embed; the graph assumes it carries
        `params.batch_size` rows of `params.num_steps` steps.

    Returns:
      `(logits, end_points)`: logits shaped
      `[batch_size, num_steps, vocab_size]` and a dict of intermediate
      tensors/objects keyed by name.
    """
    end_points = {}

    # Embedding layer
    with tf.device("/cpu:0"):
        inputs = tf.contrib.layers.embed_sequence(
            inputs,
            vocab_size=self.params.vocab_size,
            embed_dim=self.params.hidden_size,
            scope="embedding")
    end_points["inputs"] = inputs

    if self.params.keep_prob < 1 and self.training:
        # `rate` is the fraction of units dropped, i.e. 1 - keep_prob;
        # passing keep_prob directly would invert the dropout strength.
        inputs = tf.layers.dropout(
            inputs,
            rate=1.0 - self.params.keep_prob,
            training=self.training)

    # RNN graph
    cell = tf.contrib.rnn.MultiRNNCell(
        [self.make_cell() for _ in range(self.params.num_layers)],
        state_is_tuple=True)
    end_points["cell"] = cell

    initial_state = cell.zero_state(self.params.batch_size, tf.float32)
    end_points["initial_state"] = initial_state
    for tensor in nest.flatten(initial_state):
        tf.add_to_collection('rnn_input_state', tensor)

    # static_rnn expects a Python list with one tensor per time step.
    inputs = tf.unstack(inputs, num=self.params.num_steps, axis=1)
    outputs, state = tf.nn.static_rnn(
        cell, inputs, initial_state=initial_state)

    output = tf.reshape(tf.concat(outputs, 1),
                        [-1, self.params.hidden_size])
    end_points["output"] = output
    end_points["output_state"] = state

    logits = tf.layers.dense(output, self.params.vocab_size)
    logits = tf.reshape(logits,
                        [self.params.batch_size,
                         self.params.num_steps,
                         self.params.vocab_size])
    end_points["logits"] = logits
    return logits, end_points
def axial_reshape(x, ix):
    """Transpose and merge the axes of `x` according to `ix`.

    Args:
      x: tensor to reshape.
      ix: list whose entries are either a single axis index or a list of
        axis indices; each entry becomes one output axis, built by folding
        the sizes of its axes with `merge_dim`.

    Returns:
      The transposed (when the flattened indices are not consecutive) and
      reshaped tensor.
    """
    with tf.name_scope('axial_reshape', [x, ix]):
        # rectify ix: wrap scalar entries so every entry is a list of axes
        axis_groups = [([e] if np.isscalar(e) else e) for e in ix]

        # resolve input shape
        # NOTE(review): the result of tf_shape(x) is overwritten by the
        # static shape right away; the call is kept only because it may have
        # side effects -- confirm whether it is needed at all.
        dims = tf_shape(x)
        dims = x.get_shape().as_list()

        flat_axes = nest.flatten(axis_groups)
        assert (len(dims) - 1) == np.max(flat_axes)  # assert input correctness

        # transpose if necessary
        already_ordered = np.all(np.diff(flat_axes) == 1)
        if not already_ordered:
            x = tf.transpose(x, flat_axes)

        # reshape
        grouped_dims = nest.pack_sequence_as(
            axis_groups, [dims[axis] for axis in flat_axes])
        merged_shape = [reduce(merge_dim, group, 1) for group in grouped_dims]
        return tf.reshape(x, merged_shape)
def multi_head_neglogp_loss(inputs_action_pds,
                            inputs_action_labels,
                            inputs_mask_weights,
                            set_loss=False):
    """Multi-head neglogp loss.

    Args:
      inputs_action_pds: pds, organized in a nest.map_structure compatible
        structure.
      inputs_action_labels: labels, in the same structure as the pds.
      inputs_mask_weights: per-head weights, in the same structure.
      set_loss: when True, `MaskSeqCategoricalPd` heads use `set_xe`
        instead of `neglogp`.

    Returns:
      A Tensor, total loss (the sum over all heads).
      A structure of Tensor, per-head loss. The same structure as inputs.
    """

    def _head_loss(pd, labels, weight):
        is_seq_pd = isinstance(pd, MaskSeqCategoricalPd)

        # Make the weight broadcast-able; a sequence pd needs one dimension
        # fewer (the bool counts as 0/1 in the comparison).
        while labels.shape.rank > weight.shape.rank + is_seq_pd:
            weight = tf.expand_dims(weight, axis=-1)

        if not is_seq_pd:
            loss = pd.neglogp(labels) * weight
        elif set_loss:
            loss = pd.set_xe(labels, mean=True) * weight
        else:
            loss = pd.neglogp(labels, mean=True) * weight

        # Sum away trailing dims until the loss has rank at most 1.
        while loss.shape.rank > 1:
            loss = tf.reduce_sum(loss, axis=-1)
        return loss

    per_head = nest.map_structure(_head_loss, inputs_action_pds,
                                  inputs_action_labels, inputs_mask_weights)
    total = tf.add_n(nest.flatten(per_head))
    return total, per_head
def _global_spatial_attention(query):
    """Compute global spatial attention weights for `query`.

    Args:
      query: a tensor or a (possibly nested) iterable of tensors; every
        tensor with a known, non-zero rank must be rank 2.

    Returns:
      Softmax-normalised attention weights.
    """
    if hasattr(query, "__iter__"):
        pieces = nest.flatten(query)
        for piece in pieces:
            # Check that ndims == 2 if specified.
            rank = piece.get_shape().ndims
            if rank:
                assert rank == 2
        query = tf.concat(pieces, 1)

    with tf.variable_scope('global_spatial_attn_Wl'):
        projection = tf.keras.layers.Dense(units=global_attn_size,
                                           use_bias=True)
        projected = tf.reshape(projection(query),
                               [-1, 1, 1, global_attn_size])
        score = tf.reduce_sum(vg * tf.nn.tanh(projected + Wg_Xl_ug), [2, 3])
        attention_weights = tf.nn.softmax(score)
        # It is not always easy to find a measurement of similarity between
        # sensors, so that prior knowledge (eq.[4]) is omitted here.
        # To encode similarity info, use
        #   "a = nn_ops.softmax((1-lambda)*s + lambda*sim)"
        # where:
        #   sim: a vector with length n_sensors, describing the sim between
        #        the target sensor and the others
        #   lambda: a trade-off.
        # i.e. softmax((1-self.sm_rate)*score
        #              + self.sm_rate*self.similarity_graph)
    return attention_weights
def mnet_v6d6_loss(inputs: MNetV6Inputs,
                   outer_fed_heads,
                   value_head,
                   consts: MNetV6Consts,
                   nc: MNetV6Config,
                   net_level_scope: str,
                   structured_mw=None,
                   scope=None):
  """Build the losses selected by `nc.use_loss_type`.

  Supported loss types: 'il' (imitation learning), 'rl' / 'rl_ppo' (PPO),
  'rl_ppo2' (PPO2 with lambda-return) and 'rl_vtrace'. Any other value only
  prints a message and leaves every loss except the regularization loss as
  `None`.

  Args:
    inputs: network inputs; `A`, `X`, `logits`, `neglogp`, `R`, `V`, `r` and
      `discount` are read depending on the loss type.
    outer_fed_heads: structure of action heads; `.pd` and `.ent` are read
      from each head.
    value_head: value prediction tensor; split into `nc.n_v` columns for the
      'rl_ppo2' and 'rl_vtrace' losses.
    consts: network constants; only `arg_mask` is read, and only when
      `structured_mw` is None.
    nc: network config.
    net_level_scope: scope prefix used to collect the regularization losses.
    structured_mw: optional pre-built structure of per-head mask weights;
      built from `inputs.A['A_AB']` and `consts.arg_mask` when None.
    scope: optional variable scope name; defaults to 'mnet_v6_losses'.

  Returns:
    An `MNetV6Losses` holding the regularization loss, the loss(es) built
    for the selected type (the others stay `None`) and `loss_endpoints`.
  """
  # regularization loss. Only `variable`s are involved, so it is safe to
  # collect them using regular expression, e.g., 'mnet_v5.*', regardless
  # of the current name_scope (e.g., 'mnet_v5_1', 'mnet_v5_2', ...)
  total_reg_loss = tf.losses.get_regularization_loss(
    scope='{}.*'.format(net_level_scope))
  total_il_loss = None
  pg_loss = None
  value_loss = None
  entropy_loss = None
  distill_loss = None
  loss_endpoints = {}
  # structure template (all leaves None) following the action space
  example_ac_sp = tp_utils.map_gym_space_to_structure(
    lambda x: None, nc.ac_space)
  with tf.variable_scope(scope, default_name='mnet_v6_losses'):
    if nc.use_loss_type in ['il', 'rl', 'rl_ppo', 'rl_ppo2', 'rl_vtrace']:
      # head masks and structure template
      if structured_mw is None:
        mw = _action_mask_weights(inputs_ab=inputs.A['A_AB'],
                                  inputs_arg_mask=consts.arg_mask,
                                  weights_include_ab=True)
        structured_mw = tp_utils.pack_sequence_as_structure_like_gym_space(
          nc.ac_space, mw)
      outer_fed_head_pds = nest.map_structure_up_to(
        example_ac_sp, lambda head: head.pd, outer_fed_heads)

    if nc.use_loss_type == 'il':
      # build imitation learning loss the cross entropy
      total_il_loss, head_xe_loss = tp_losses.multi_head_neglogp_loss(
        inputs_action_pds=outer_fed_head_pds,
        inputs_action_labels=inputs.A,
        inputs_mask_weights=structured_mw,
        set_loss=nc.il_multi_label_loss,
      )
      assert type(head_xe_loss) == OrderedDict
      loss_endpoints = head_xe_loss
    elif nc.use_loss_type in ['rl', 'rl_ppo', 'rl_ppo2', 'rl_vtrace']:
      # build rl losses

      # the entropy regularizer
      entropy_loss = nest.map_structure_up_to(
        example_ac_sp, lambda head, mask: tf.reduce_mean(head.ent * mask),
        outer_fed_heads, structured_mw)

      # distillation loss, i.e., the teacher-student KL regularizer
      distill_loss = None
      ab_distill_loss = None
      if nc.distillation:
        outer_fed_head_pds = nest.map_structure_up_to(
          example_ac_sp, lambda head: head.pd, outer_fed_heads)
        distill_loss = tp_losses.distill_loss(
          student_pds=outer_fed_head_pds,
          teacher_logits=inputs.logits,
          masks=structured_mw)
        ab_pd = outer_fed_head_pds['A_AB']
        teacher_logit = inputs.logits['A_AB']
        # TODO: this is from definition of position encoding, remove it?
        first_4mins_mask = tf.cast(
          inputs.X['X_VEC_GAME_PROG'][:, -1] >= np.cos(
            60 * 4 * np.power(10000, -62 / 64)), tf.float32)
        # additionally require a non-empty Z_BUILD_ORDER (sum over axes 1, 2)
        first_4mins_mask *= tf.cast((tf.reduce_sum(
          inputs.X['Z_BUILD_ORDER'], axis=[1, 2]) > 0), tf.float32)
        ab_distill_loss = tp_losses.distill_loss(
          ab_pd, teacher_logit, first_4mins_mask)

      # the main policy gradient loss
      outer_fed_head_neglogp = nest.map_structure_up_to(
        example_ac_sp,
        lambda head, ac: head.pd.neglogp(ac),
        outer_fed_heads,
        inputs.A)
      loss_endpoints = {}
      if nc.use_loss_type == 'rl' or nc.use_loss_type == 'rl_ppo':
        # PPO loss
        pg_loss, value_loss = tp_losses.ppo_loss(
          outer_fed_head_neglogp,
          inputs.neglogp,
          value_head,
          inputs.R,
          inputs.V,
          masks=structured_mw,
          reward_weights=nc.reward_weights,
          merge_pi=nc.merge_pi,
          adv_normalize=nc.adv_normalize,
          clip_range=nc.clip_range,
          sync_statistics=nc.sync_statistics,
        )
      elif nc.use_loss_type in ['rl_ppo2', 'rl_vtrace']:
        # Note: we need convert the shape (batch_size, ...) to the shape
        # (T, B, ...) where T=nc.rollout_len, B=nc.nrollout, batch_size=B*T
        # When computing ppo2-loss and value-loss, only T-1 time steps are
        # used due to the value bootstrap at the tail. When doing so, the
        # [:-1] indexing, leading to (T - 1, B, ...) tensor slice, makes life
        # much easier
        def _batch_to_TB(tsr):
          return tf.transpose(
            tf.reshape(tsr, shape=(nc.nrollout, nc.rollout_len)))

        # make the len=n_action_heads lists for action-head stuff
        # for tensor entry, shape (batch_size, ...) -> shape (T, B, ...)
        neglogp_list = [
          _batch_to_TB(neglogp)
          for neglogp in nest.flatten(outer_fed_head_neglogp)
        ]
        oldneglogp_list = [
          _batch_to_TB(oldneglogp)
          for oldneglogp in nest.flatten(inputs.neglogp)
        ]
        mask_list = [
          _batch_to_TB(mw) for mw in nest.flatten(structured_mw)
        ]
        # make the len=n_v lists for value-head stuff
        # for tensor entry, shape (batch_size, ...) -> shape (T, B, ...)
        # as aforementioned
        vpred_list = [
          _batch_to_TB(v) for v in tf.split(value_head, nc.n_v, axis=1)
        ]
        reward_list = [
          _batch_to_TB(r) for r in tf.split(inputs.r, nc.n_v, axis=1)
        ]
        discounts = _batch_to_TB(inputs.discount)
        # upgo_loss only use the win_loss, i.e, v[0]
        upgo_loss = tp_losses.upgo_loss(
          tf.stack(neglogp_list, axis=-1),
          tf.stack(oldneglogp_list, axis=-1),
          tf.stack(mask_list, axis=-1),
          vpred_list[0], reward_list[0], discounts)
        loss_endpoints['upgo_loss'] = upgo_loss

        if nc.use_loss_type == 'rl_ppo2':
          # PPO2 loss
          # reward_weights size should be consistent with n_v
          reward_weights = tf.squeeze(
            tf.convert_to_tensor(nc.reward_weights, tf.float32))
          assert reward_weights.shape.as_list()[0] == len(reward_list), (
            'For ppo2 loss, reward_weight size must be the same with number of'
            ' value head: each reward_weight element must correspond to one '
            'value-head exactly.')

          # lambda for td-lambda or lambda-return
          assert nc.lam is not None, (
            'building rl_ppo2, but lam for '
            'lambda-return is None.')
          lam = tf.convert_to_tensor(nc.lam, tf.float32)

          # for each value-head, compute the corresponding policy gradient loss
          # and the value loss
          pg_loss, value_loss = [], []
          for vpred, reward in zip(vpred_list, reward_list):
            # compute the lambda-Return `R` in shape (T - 1, B)
            # [:-1] means discarding the last one,
            # [1:] means an off-one alignment.
            # back_prop=False means R = tf.stop_gradient(R)
            with tf.device("/cpu:0"):
              R = multistep_forward_view(reward[:-1], discounts[:-1],
                                         vpred[1:], lambda_=lam,
                                         back_prop=False)
            # compute the ppo2 loss using this value-head for each of the
            # n_action_heads action-head; then reduce them
            # [:-1] means discarding the last one and using only T - 1 time
            # steps
            _ploss = [
              tp_losses.ppo2_loss(
                neglogp[:-1],
                oldneglogp[:-1],
                tf.stop_gradient(vpred)[:-1],
                R,  # has been stop_gradient above
                mask[:-1],
                adv_normalize=nc.adv_normalize,
                clip_range=nc.clip_range,
                sync_statistics=nc.sync_statistics)
              for neglogp, oldneglogp, mask in zip(
                neglogp_list, oldneglogp_list, mask_list)
            ]
            pg_loss.append(tf.reduce_sum(_ploss))
            # compute the value loss for this value-head
            value_loss.append(
              tf.reduce_mean(0.5 * tf.square(R - vpred[:-1])))
          # element-wise times reward_weight and the pg_loss for that value-head
          pg_loss = tf.stack(
            pg_loss) * reward_weights  # shape (n_v,)
          # make the final pg_loss, value_loss in desired format
          pg_loss = tf.reduce_sum(pg_loss)
          value_loss = tf.stack(value_loss)
        else:
          # vtrace loss
          # lambda for td-lambda or lambda-return
          assert nc.lam is not None, (
            'building rl_vtrace, but lam for '
            'td-lambda is None.')
          lam = tf.convert_to_tensor(nc.lam, tf.float32)
          value_loss = []
          for values, rewards in zip(vpred_list, reward_list):
            value_loss.append(
              tp_losses.td_lambda(values, rewards, discounts, lam=lam))
          # combine the value heads / rewards through the reward weights
          shaped_values = tf.matmul(value_head, nc.reward_weights,
                                    transpose_b=True)
          shaped_rewards = tf.matmul(inputs.r, nc.reward_weights,
                                     transpose_b=True)
          values = tf.transpose(
            tf.reshape(shaped_values, shape=(nc.nrollout, nc.rollout_len)))
          rewards = tf.transpose(
            tf.reshape(shaped_rewards, shape=(nc.nrollout, nc.rollout_len)))
          pg_loss = tf.reduce_sum([
            tp_losses.vtrace_loss(neglogp, oldneglogp, mask, values,
                                  rewards, discounts, 1.0, 1.0)
            for oldneglogp, neglogp, mask in zip(
              oldneglogp_list, neglogp_list, mask_list)
          ])
          value_loss = tf.stack(value_loss)

      # TODO: maybe more rl endpoints
      # policy gradient loss must be scalar
      loss_endpoints['pg_loss'] = pg_loss
      # value loss can be scalar or vector
      if len(value_loss.shape) == 0:
        loss_endpoints['value_loss'] = value_loss
      else:
        for i in range(value_loss.shape[0]):
          loss_endpoints['value_loss_' + str(i)] = value_loss[i]
      for k, v in entropy_loss.items():
        loss_endpoints['ent_' + k] = v
      if nc.distillation:
        for k, v in distill_loss.items():
          loss_endpoints['distill_' + k] = v
        loss_endpoints['distill_ab_bf4mins'] = ab_distill_loss
    else:
      print('use_loss_type: {}. Nothing done.'.format(nc.use_loss_type))
      pass

  return MNetV6Losses(total_reg_loss=total_reg_loss,
                      total_il_loss=total_il_loss,
                      pg_loss=pg_loss,
                      value_loss=value_loss,
                      entropy_loss=entropy_loss,
                      distill_loss=distill_loss,
                      loss_endpoints=loss_endpoints)
def __init__(
        self,
        cell,
        attention_mechanism,
        is_manual_attention,  # added argument
        manual_alignments,  # added argument
        attention_layer_size=None,
        alignment_history=False,
        cell_input_fn=None,
        output_attention=True,
        initial_cell_state=None,
        name=None):
    """Construct the `AttentionWrapper`.

    **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
    `AttentionWrapper`, then you must ensure that:

    - The encoder output has been tiled to `beam_width` via
      @{tf.contrib.seq2seq.tile_batch} (NOT `tf.tile`).
    - The `batch_size` argument passed to the `zero_state` method of this
      wrapper is equal to `true_batch_size * beam_width`.
    - The initial state created with `zero_state` above contains a
      `cell_state` value containing properly tiled final state from the
      encoder.

    An example:

    ```
    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
        encoder_outputs, multiplier=beam_width)
    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
        encoder_final_state, multiplier=beam_width)
    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
        sequence_length, multiplier=beam_width)
    attention_mechanism = MyFavoriteAttentionMechanism(
        num_units=attention_depth,
        memory=tiled_encoder_outputs,
        memory_sequence_length=tiled_sequence_length)
    attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
    decoder_initial_state = attention_cell.zero_state(
        dtype, batch_size=true_batch_size * beam_width)
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=tiled_encoder_final_state)
    ```

    Args:
      cell: An instance of `RNNCell`.
      attention_mechanism: A list of `AttentionMechanism` instances or a
        single instance.
      is_manual_attention: Stored unchanged as `self.is_manual_attention`;
        the constructor does not otherwise use it. Presumably a flag that
        makes the wrapper use `manual_alignments` instead of the computed
        alignments -- TODO confirm against the `call` implementation.
      manual_alignments: Stored unchanged as `self.manual_alignments`.
        Presumably the alignments applied when `is_manual_attention` is set
        -- TODO confirm against the `call` implementation.
      attention_layer_size: A list of Python integers or a single Python
        integer, the depth of the attention (output) layer(s). If None
        (default), use the context as attention at each time step.
        Otherwise, feed the context and cell output into the attention
        layer to generate attention at each time step. If
        attention_mechanism is a list, attention_layer_size must be a list
        of the same length.
      alignment_history: Python boolean, whether to store alignment history
        from all time steps in the final output state (currently stored as
        a time major `TensorArray` on which you must call `stack()`).
      cell_input_fn: (optional) A `callable`. The default is:
        `lambda inputs, attention: tf.concat([inputs, attention], -1)`.
      output_attention: Python bool. If `True` (default), the output at
        each time step is the attention value. This is the behavior of
        Luong-style attention mechanisms. If `False`, the output at each
        time step is the output of `cell`. This is the behavior of
        Bahdanau-style attention mechanisms. In both cases, the
        `attention` tensor is propagated to the next time step via the
        state and is used there. This flag only controls whether the
        attention mechanism is propagated up to the next cell in an RNN
        stack or to the top RNN output.
      initial_cell_state: The initial state value to use for the cell when
        the user calls `zero_state()`. Note that if this value is provided
        now, and the user uses a `batch_size` argument of `zero_state`
        which does not match the batch size of `initial_cell_state`, proper
        behavior is not guaranteed.
      name: Name to use when creating ops.

    Raises:
      TypeError: if `attention_mechanism` is neither an
        `AttentionMechanism` nor a list/tuple made only of them, or if
        `cell_input_fn` is given but is not callable.
      ValueError: if `attention_layer_size` is not None and the number of
        sizes it provides differs from the number of attention mechanisms.
    """
    super(AttentionWrapper, self).__init__(name=name)
    self.is_manual_attention = is_manual_attention
    self.manual_alignments = manual_alignments
    rnn_cell_impl.assert_like_rnncell("cell", cell)

    # Normalise `attention_mechanism` into a sequence, validating each entry.
    is_multi = isinstance(attention_mechanism, (list, tuple))
    self._is_multi = is_multi
    if is_multi:
        attention_mechanisms = attention_mechanism
        for mechanism in attention_mechanisms:
            if not isinstance(mechanism, AttentionMechanism):
                raise TypeError(
                    "attention_mechanism must contain only instances of "
                    "AttentionMechanism, saw type: %s"
                    % type(mechanism).__name__)
    else:
        if not isinstance(attention_mechanism, AttentionMechanism):
            raise TypeError(
                "attention_mechanism must be an AttentionMechanism or list of "
                "multiple AttentionMechanism instances, saw type: %s"
                % type(attention_mechanism).__name__)
        attention_mechanisms = (attention_mechanism,)

    # Default cell input: concatenate inputs and attention on the last axis.
    if cell_input_fn is None:
        def cell_input_fn(inputs, attention):
            return tf.concat([inputs, attention], -1)
    elif not callable(cell_input_fn):
        raise TypeError(
            "cell_input_fn must be callable, saw type: %s"
            % type(cell_input_fn).__name__)

    if attention_layer_size is None:
        # No projection layers: the attention size is the summed depth of
        # every mechanism's `values`.
        self._attention_layers = None
        self._attention_layer_size = sum(
            mechanism.values.get_shape()[-1].value
            for mechanism in attention_mechanisms)
    else:
        if isinstance(attention_layer_size, (list, tuple)):
            attention_layer_sizes = tuple(attention_layer_size)
        else:
            attention_layer_sizes = (attention_layer_size,)
        if len(attention_layer_sizes) != len(attention_mechanisms):
            raise ValueError(
                "If provided, attention_layer_size must contain exactly one "
                "integer per attention_mechanism, saw: %d vs %d"
                % (len(attention_layer_sizes), len(attention_mechanisms)))
        # One bias-free dense layer per mechanism, in mechanism order.
        self._attention_layers = tuple(
            layers_core.Dense(layer_size,
                              name="attention_layer",
                              use_bias=False,
                              dtype=mechanism.dtype)
            for mechanism, layer_size in zip(attention_mechanisms,
                                             attention_layer_sizes))
        self._attention_layer_size = sum(attention_layer_sizes)

    self._cell = cell
    self._attention_mechanisms = attention_mechanisms
    self._cell_input_fn = cell_input_fn
    self._output_attention = output_attention
    self._alignment_history = alignment_history

    with tf.name_scope(name, "AttentionWrapperInit"):
        if initial_cell_state is None:
            self._initial_cell_state = None
        else:
            # The batch size is read from the last flattened state tensor,
            # statically when known and dynamically otherwise.
            last_state_tensor = nest.flatten(initial_cell_state)[-1]
            state_batch_size = (last_state_tensor.shape[0].value
                                or tf.shape(last_state_tensor)[0])
            message_prefix = (
                "When constructing AttentionWrapper %s: " % self._base_name)
            error_message = message_prefix + (
                "Non-matching batch sizes between the memory "
                "(encoder output) and initial_cell_state. Are you using "
                "the BeamSearchDecoder? You may need to tile your initial state "
                "via the tf.contrib.seq2seq.tile_batch function with argument "
                "multiple=beam_width.")

            def _checked_identity(state):
                return tf.identity(state, name="check_initial_cell_state")

            # The identity ops are created under the batch-size checks so
            # the checks run whenever the stored initial state is used.
            batch_checks = self._batch_size_checks(state_batch_size,
                                                   error_message)
            with tf.control_dependencies(batch_checks):
                self._initial_cell_state = nest.map_structure(
                    _checked_identity, initial_cell_state)
def main(self):
    """Run the training loop for `self.config.n_train_steps` iterations.

    Each iteration:
      1. fetches a batch from `self._exp_fetcher`;
      2. feeds it to `self._agent.update`, requesting a full trace when
         `self.global_step` equals `self._profile_step` and saving the
         profile afterwards;
      3. writes the returned values to every logger in `self._loggers`;
      4. sends the metagraph when `self.global_step == 1`, publishes
         variables when `self._publish_tracker` says so, and checkpoints
         every `self._checkpoint_every` global steps;
      5. writes timing statistics to every logger in
         `self._system_loggers`.

    The time spent logging cannot be known until the system loggers have
    finished writing, so `log_time_sec` is reported one iteration late: each
    system log carries the logging time of the previous iteration. The
    first iteration therefore has no `log_time_sec` entry.

    After the loop, `None` is put on `self._publish_queue`.
    """
    # Logging time measured in the previous iteration. Previously this value
    # was stored in `system_logs` after the loggers had already written and
    # was then discarded when `system_logs` was recreated.
    prev_log_time_sec = None

    for _ in range(self.config.n_train_steps):
        system_logs = dict()
        if prev_log_time_sec is not None:
            system_logs['log_time_sec'] = prev_log_time_sec

        # fetch the next training batch
        with U.Timer() as batch_timer:
            batch = self._exp_fetcher.get()

        with U.Timer() as step_timer:
            # run update step on the sampled batch
            feed_dict = {
                ph: val
                for ph, val in zip(nest.flatten(self._traj_phs),
                                   nest.flatten(batch))
            }

            profile_kwargs = {}
            if self.global_step == self._profile_step:
                profile_kwargs = dict(
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=tf.RunMetadata())

            log_vals = self._agent.update(self.sess, feed_dict,
                                          profile_kwargs)

            if profile_kwargs:
                self._save_profile(**profile_kwargs)

        with U.Timer() as log_timer:
            for logger in self._loggers:
                logger.write(log_vals)

        # after first sess.run finishes send the metagraph.
        if self.global_step == 1:
            self._send_metagraph()

        # publish the variables if required.
        if self._publish_tracker.track_increment():
            with U.Timer() as publish_timer:
                self._publish_variables()
            system_logs['publish_time_sec'] = publish_timer.to_seconds()

        # Checkpoint if required
        if self.global_step % self._checkpoint_every == 0:
            with U.Timer() as ckpt_timer:
                self._create_ckpt()
            system_logs['ckpt_time_sec'] = ckpt_timer.to_seconds()

        with U.Timer() as system_log_timer:
            # log system profile
            for logger in self._system_loggers:
                logger.write(
                    dict(global_step=self.global_step,
                         sps=self._batch_size * self._traj_length /
                         float(step_timer.to_seconds()),
                         per_step_time_sec=step_timer.to_seconds(),
                         batch_fetch_time_sec=batch_timer.to_seconds(),
                         **system_logs))

        # Remember the total logging time so the next iteration reports it.
        prev_log_time_sec = (log_timer.to_seconds() +
                             system_log_timer.to_seconds())

    self._publish_queue.put(None)  # exit the thread once training ends.
def update(self, batch):
    """Run one agent update on `batch` and return the agent's log values.

    The flattened trajectory placeholders in `self._traj_phs` are paired
    positionally with the flattened `batch` to form the feed dict, which is
    passed to `self._agent.update` together with `self.sess` and an empty
    dict as the third argument.
    """
    placeholders = nest.flatten(self._traj_phs)
    batch_values = nest.flatten(batch)
    feed_dict = dict(zip(placeholders, batch_values))
    return self._agent.update(self.sess, feed_dict, {})
def ppo_loss(neglogp, oldneglogp, vpred, R, V, masks=None,
             reward_weights=None, merge_pi=True, adv_normalize=True,
             clip_range=0.1, sync_statistics=None):
    """PPO loss. Not recommended, use `ppo2_loss` instead.

    Use it only for backwards compatibility. This ppo loss impl
    * can handle the structured action heads that it sums (NOT mean) the pg
      loss over each action head
    * couples the pg loss and value loss, returning them both, which is
      slightly over-engineered but convenient to call.

    Args:
      neglogp: neglogp, structure as ac_space.
      oldneglogp: neglogp of pi_old, same structure with neglogp.
      vpred: predicted v, in shape [batch_size, n_v]
      R: return from actor, in shape [batch_size, n_v]
      V: value from actor, in shape [batch_size, n_v]
      masks: action logits mask in 0/1 value. The same structure and shape
        with neglogp
      reward_weights: reward weights in shape [1, n_v] (merge_pi=True or
        False) or [len(neglogp), n_v] (MUST merge_pi=False)
      merge_pi: whether to merge pi, if True, original PPO, else split PPO
        (shared adv or independent adv decided by the shape of
        reward_weights)
      adv_normalize: whether to rescale the (always mean-centred) advantage
      clip_range: clip range
      sync_statistics: if synchronize statistics across multi GPUs (if any);
        only the value 'horovod' triggers synchronization.

    Returns:
      pg_loss: policy loss, a scalar in shape []
      vf_loss: value loss, a Tensor in shape [n_v,]
    """
    nest.assert_same_structure(neglogp, oldneglogp)

    # Per-head log probability ratio: log pi(a) - log pi_old(a).
    old_stacked = tf.stack(nest.flatten(oldneglogp), axis=1)
    new_stacked = tf.stack(nest.flatten(neglogp), axis=1)
    log_ratio = old_stacked - new_stacked
    if masks is not None:
        nest.assert_same_structure(neglogp, masks)
        log_ratio = tf.stack(nest.flatten(masks), axis=1) * log_ratio
    if merge_pi:
        # Joint ratio over all heads, kept as a single column.
        log_ratio = tf.reduce_sum(log_ratio, axis=-1, keepdims=True)
    ratio = tf.exp(log_ratio)

    # normalize ADV
    advantage = R - V
    if reward_weights is not None:
        advantage = tf.matmul(advantage, reward_weights, transpose_b=True)
    adv_mean = tf.reduce_mean(advantage, axis=0)
    adv_sq_mean = tf.reduce_mean(tf.square(advantage), axis=0)
    if sync_statistics == 'horovod':
        # https://github.com/tensorpack/tensorpack/blob/07783edb998cec3ec91c4312b39bd754cf9ececa/tensorpack/models/batch_norm.py#L226-L231
        import horovod.tensorflow as hvd
        adv_mean = hvd.allreduce(adv_mean, average=True)
        adv_sq_mean = hvd.allreduce(adv_sq_mean, average=True)
    # The mean is subtracted whether or not `adv_normalize` is set.
    advantage = advantage - adv_mean
    if adv_normalize:
        # NOTE(review): the scale is sqrt(E[adv^2]) taken before centring,
        # not the standard deviation -- confirm this is intended.
        advantage = advantage / tf.sqrt(adv_sq_mean + 1e-8)

    # Value loss: the larger of the unclipped and clipped squared errors.
    value_delta = tf.clip_by_value(vpred - V, -clip_range, clip_range)
    vpred_clipped = V + value_delta
    vf_unclipped = tf.square(vpred - R)
    vf_clipped = tf.square(vpred_clipped - R)
    # TODO: add sample weight here. also pg_loss, distill_loss, entropy
    vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_unclipped, vf_clipped),
                                  axis=0)

    # Policy loss: positive advantages use the larger of the two surrogates,
    # all other entries use the clipped surrogate only.
    surrogate = -advantage * ratio
    surrogate_clipped = -advantage * tf.clip_by_value(
        ratio, 1.0 - clip_range, 1.0 + clip_range)
    adv_is_positive = tf.greater(
        tf.tile(advantage, [1, ratio.shape[-1]]), 0)
    per_entry_loss = tf.where(adv_is_positive,
                              tf.maximum(surrogate, surrogate_clipped),
                              surrogate_clipped)
    pg_loss = tf.reduce_sum(tf.reduce_mean(per_entry_loss, axis=0))
    return pg_loss, vf_loss
def _build(self,
           inputs,
           sequence_length=None,
           initial_state=None,
           time_major=False,
           mode=None,
           **kwargs):
    """Feeds the inputs through the network and makes classification.

    The arguments are the same as in
    :class:`~texar.tf.modules.UnidirectionalRNNEncoder`.

    Args:
        inputs: A 3D Tensor of shape `[batch_size, max_time, dim]`.
            The first two dimensions
            `batch_size` and `max_time` may be exchanged if
            `time_major=True` is specified.
        sequence_length (optional): A 1D int tensor of shape `[batch_size]`.
            Sequence lengths
            of the batch inputs. Used to copy-through state and zero-out
            outputs when past a batch element's sequence length.
        initial_state (optional): Initial state of the RNN.
        time_major (bool): The shape format of the :attr:`inputs` and
            :attr:`outputs` Tensors. If `True`, these tensors are of shape
            `[max_time, batch_size, depth]`. If `False` (default),
            these tensors are of shape `[batch_size, max_time, depth]`.
        mode (optional): A tensor taking value in
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, including
            `TRAIN`, `EVAL`, and `PREDICT`. Controls output layer dropout
            if the output layer is specified with :attr:`hparams`.
            If `None` (default), :func:`texar.tf.global_mode()`
            is used.
        **kwargs: Optional keyword arguments of
            :tf_main:`tf.nn.dynamic_rnn <nn/dynamic_rnn>`,
            such as `swap_memory`, `dtype`, `parallel_iterations`, etc.

    Returns:
        A tuple `(logits, pred)`, containing the logits over classes and
        the predictions, respectively.

        - If "clas_strategy"=="final_time" or "all_time"

            - If "num_classes"==1, `logits` and `pred` are both of
              shape `[batch_size]`
            - If "num_classes">1, `logits` is of shape
              `[batch_size, num_classes]` and `pred` is of shape
              `[batch_size]`.

        - If "clas_strategy"=="time_wise",

            - If "num_classes"==1, `logits` and `pred` are both of
              shape `[batch_size, max_time]`
            - If "num_classes">1, `logits` is of shape
              `[batch_size, max_time, num_classes]` and `pred` is of shape
              `[batch_size, max_time]`.
            - If `time_major` is `True`, the batch and time dimensions are
              exchanged.

    Raises:
        ValueError: If "clas_strategy" is "all_time" and either the logit
            layer or `hparams.max_seq_length` is `None`, or if
            "clas_strategy" is not one of the known strategies.
    """
    enc_outputs, _, enc_output_size = self._encoder(
        inputs=inputs,
        sequence_length=sequence_length,
        initial_state=initial_state,
        time_major=time_major,
        mode=mode,
        return_output_size=True,
        **kwargs)

    # Flatten enc_outputs: every (possibly nested) output is collapsed to
    # rank 3 and the pieces are concatenated on the feature axis.
    enc_outputs_flat = nest.flatten(enc_outputs)
    enc_output_size_flat = nest.flatten(enc_output_size)
    enc_output_dims_flat = [np.prod(xs) for xs in enc_output_size_flat]
    enc_outputs_flat = [
        shapes.flatten(x, 2, xs)
        for x, xs in zip(enc_outputs_flat, enc_output_dims_flat)
    ]
    if len(enc_outputs_flat) == 1:
        enc_outputs_flat = enc_outputs_flat[0]
    else:
        enc_outputs_flat = tf.concat(enc_outputs_flat, axis=2)

    # Compute logits
    stra = self._hparams.clas_strategy
    if stra == 'time_wise':
        logits = enc_outputs_flat
    elif stra == 'final_time':
        # NOTE(review): this takes the last time index of the padded batch,
        # independent of `sequence_length` -- confirm that is intended.
        if time_major:
            logits = enc_outputs_flat[-1, :, :]
        else:
            logits = enc_outputs_flat[:, -1, :]
    elif stra == 'all_time':
        if self._logit_layer is None:
            raise ValueError(
                'logit layer must not be `None` if '
                'clas_strategy="all_time". Specify the logit layer by '
                'either passing the layer in the constructor or '
                'specifying the hparams.')
        if self._hparams.max_seq_length is None:
            raise ValueError(
                'hparams.max_seq_length must not be `None` if '
                'clas_strategy="all_time"')
    else:
        raise ValueError(
            'Unknown classification strategy: {}'.format(stra))

    if self._logit_layer is not None:
        logit_input_dim = np.sum(enc_output_dims_flat)
        if stra == 'time_wise':
            logits, _ = _forward_single_output_layer(
                logits, logit_input_dim, self._logit_layer)
        elif stra == 'final_time':
            logits = self._logit_layer(logits)
        elif stra == 'all_time':
            # Pad `enc_outputs_flat` to have max_seq_length before flatten.
            # The reshape below puts one example per row, so the tensor
            # must be batch-major with time on axis 1. Previously axis 1
            # was always treated as time, which padded and flattened the
            # batch axis when `time_major=True`.
            if time_major:
                logit_input = tf.transpose(enc_outputs_flat, [1, 0, 2])
                cur_length = tf.shape(inputs)[0]
            else:
                logit_input = enc_outputs_flat
                cur_length = tf.shape(inputs)[1]
            length_diff = self._hparams.max_seq_length - cur_length
            length_diff = tf.reshape(length_diff, [1, 1])
            # Set `paddings = [[0, 0], [0, length_diff], [0, 0]]`
            paddings = tf.pad(length_diff, paddings=[[1, 1], [1, 0]])
            logit_input = tf.pad(logit_input, paddings=paddings)
            logit_input_dim *= self._hparams.max_seq_length
            logit_input = tf.reshape(logit_input, [-1, logit_input_dim])
            logits = self._logit_layer(logit_input)

    # Compute predictions
    num_classes = self._hparams.num_classes
    is_binary = num_classes == 1
    is_binary = is_binary or (num_classes <= 0 and logits.shape[-1] == 1)

    if stra == 'time_wise':
        if is_binary:
            pred = tf.squeeze(tf.greater(logits, 0), -1)
            logits = tf.squeeze(logits, -1)
        else:
            pred = tf.argmax(logits, axis=-1)
    else:
        if is_binary:
            pred = tf.greater(logits, 0)
            logits = tf.reshape(logits, [-1])
        else:
            pred = tf.argmax(logits, axis=-1)
        pred = tf.reshape(pred, [-1])
    pred = tf.cast(pred, tf.int64)

    if not self._built:
        self._add_internal_trainable_variables()
        # Add trainable variables of `self._logit_layer`
        # which may be constructed externally.
        if self._logit_layer:
            self._add_trainable_variable(
                self._logit_layer.trainable_variables)
        self._built = True

    return logits, pred
# Small end-to-end example: a two-layer LSTM whose initial state tensors are
# overridable placeholders (collected under 'rnn_state_input') and whose
# final state tensors are collected under 'rnn_state_output'.
batch_size = 64
inputs, targets = (np.random.rand(1000, 2, 1).astype(np.float32),
                   np.random.rand(1000, 1).astype(np.float32))
# NOTE: 1000 is not a multiple of `batch_size`, so a later batch would be
# smaller than the fixed-size initial state; only one batch is evaluated.
dataset = tf.data.Dataset.from_tensor_slices(
    (inputs, targets)).batch(batch_size).repeat()
x, y = dataset.make_one_shot_iterator().get_next()

cell = MultiRNNCell([LSTMCell(2), LSTMCell(2)])
# Wrap every zero-state tensor in a placeholder that defaults to it, so a
# caller may feed a previous state or fall back to zeros.
state = nest.map_structure(
    lambda zero: tf.placeholder_with_default(zero, zero.shape, zero.op.name),
    cell.zero_state(batch_size, tf.float32))
for tensor in nest.flatten(state):
    tf.add_to_collection('rnn_state_input', tensor)

out, new_state = tf.nn.dynamic_rnn(cell, x, initial_state=state)
# Flatten time and unit axes per example so the dense layer yields one
# prediction per label. The previous `[-1, 1]` reshape produced
# batch * time * units rows, which cannot be matched against the labels.
out = tf.reshape(out, [batch_size, -1])
pred = tf.layers.Dense(units=1)(out)
# The loss is taken on `pred`; it used to be taken on `out`, leaving the
# dense layer unused.
loss = tf.losses.mean_squared_error(predictions=pred, labels=y)
for tensor in nest.flatten(new_state):
    tf.add_to_collection('rnn_state_output', tensor)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(loss))