def loop_fn(time, prev_output, prev_state, array_targets: tf.TensorArray, array_outputs: tf.TensorArray):
    """Run one step of the autoregressive decoder loop.

    The previous step's prediction is fed directly to the RNN cell as this
    step's input; no additional features are concatenated.

    :param time: Step number; also the write index into both buffer arrays.
    :param prev_output: Output (prediction) from the previous step.
    :param prev_state: RNN state tensor from the previous step.
    :param array_targets: Predictions buffer; one element is written per step.
    :param array_outputs: Raw RNN outputs buffer (for regularization losses);
        written only when ``return_raw_outputs`` is truthy.
    :return: ``(time + 1, projected_output, state, array_targets, array_outputs)``.
    """
    rnn_output, new_state = cell(prev_output, prev_state)
    prediction = project_output(rnn_output)

    if return_raw_outputs:
        array_outputs = array_outputs.write(time, rnn_output)
    array_targets = array_targets.write(time, prediction)

    return time + 1, prediction, new_state, array_targets, array_outputs
def body(t, output_ta_t: tf.TensorArray, penalty_ta_t: tf.TensorArray):
    """Build a soft adjacency matrix and a trace penalty for element ``t``.

    Reads element ``t`` from ``proj_a_ta`` and ``proj_b_ta``. Both are trimmed
    to the first ``n = sequence_lengths_ta.read(t)`` rows. Every pair of rows is
    scored with a dot product and each row is softmax-normalised. The penalty is
    ``trace(expm(edges)) - n``. The (n, n) matrix is zero-padded to (L, L) so all
    written elements share one shape.

    :return: ``(t + 1, output_ta, penalty_ta)`` with index ``t`` written in both.
    """
    a_full = proj_a_ta.read(t)
    b_full = proj_b_ta.read(t)
    n = sequence_lengths_ta.read(t)

    # Drop padding rows beyond the element's true length.
    a = a_full[:n, :]
    b = b_full[:n, :]

    # energy[i, j] = <a_i, b_j>  -> shape (n, n)
    energy = tf.tensordot(a, b, axes=[(1, ), (1, )])
    edges = tf.nn.softmax(energy, axis=-1)

    # expm of an all-zero matrix is the identity, whose trace is n, so the
    # penalty is measured relative to that baseline.
    penalty = tf.trace(tf.linalg.expm(input=edges)) - tf.cast(n, tf.float32)

    pad = L - n
    edges_padded = tf.pad(tensor=edges, paddings=[(0, pad), (0, pad)])

    next_output_ta = output_ta_t.write(value=edges_padded, index=t)
    next_penalty_ta = penalty_ta_t.write(value=penalty, index=t)
    return t + 1, next_output_ta, next_penalty_ta
def body(step: int, current_input: tf.Tensor, previous_states: tf.Tensor, outputs: tf.TensorArray):
    """Run one time step through a stack of RNN cells and record a sigmoid read-out.

    Each cell in ``cells`` receives the previous cell's output as its input and
    takes its own previous state from ``previous_states[:, i, ...]``. The new
    per-cell states are re-stacked along axis 1. The last cell's output goes
    through a 1-unit sigmoid fully-connected layer, and that value is written
    to ``outputs`` at index ``step``.

    :param step: Loop counter; write index into ``outputs``.
    :param current_input: Input fed to the first cell.
    :param previous_states: States of all cells, indexed on axis 1 by cell
        (presumably [batch, num_cells, state_size] -- TODO confirm).
    :param outputs: TensorArray collecting one read-out per step.
    :return: ``(step + 1, last_cell_output, current_states, outputs)``.
    """
    current_states_list = []
    for i, cell in enumerate(cells):
        with tf.variable_scope("rnn_cell_%d" % i):
            cell_previous_hidden_vector = tf.squeeze(
                tf.slice(previous_states, [0, i, 0], [-1, 1, -1]), axis=[1])
            output, state = cell(current_input, cell_previous_hidden_vector)
            # Set current_input to output for next cell iteration
            current_input = output
            current_states_list.append(state)

    current_states = tf.concat(
        [tf.expand_dims(state, axis=1) for state in current_states_list], axis=1)

    with tf.variable_scope("fully_connected_output"):
        final_output = tf.contrib.layers.fully_connected(
            current_input, num_outputs=1, activation_fn=tf.nn.sigmoid)

    # TensorArray.write returns a NEW TensorArray carrying the updated flow.
    # The result must be kept and returned; otherwise the write is not part
    # of the while_loop's dataflow and the value is lost.
    outputs = outputs.write(step, final_output)
    return step + 1, current_input, current_states, outputs
def loop_fn(time, prev_state, array_targets: tf.TensorArray, array_outputs: tf.TensorArray):
    """Run one step of the main RNN loop.

    The RNN input for step ``time`` is ``inputs_by_time[time]`` concatenated
    with ``self.vm_id`` along axis 1. No previous prediction is fed back.

    :param time: Day number; index into ``inputs_by_time`` and the buffers.
    :param prev_state: RNN state tensor from the previous step.
    :param array_targets: Predictions buffer; one element is written per step.
    :param array_outputs: Raw RNN outputs buffer (for regularization losses);
        written only when ``return_raw_outputs`` is truthy.
    :return: ``(time + 1, state, array_targets, array_outputs)``.
    """
    step_features = inputs_by_time[time]
    step_input = tf.concat([step_features, self.vm_id], axis=1)

    rnn_output, new_state = cell(step_input, prev_state)
    prediction = project_output(rnn_output)

    if return_raw_outputs:
        array_outputs = array_outputs.write(time, rnn_output)
    array_targets = array_targets.write(time, prediction)

    return time + 1, new_state, array_targets, array_outputs
def loop_fn(time, prev_output, prev_state, array_targets: tf.TensorArray, array_outputs: tf.TensorArray):
    """Run one step of the autoregressive decoder loop.

    The RNN input is built by concatenating, along axis 1:
    the previous prediction, ``inputs_by_time[time]`` and, when
    ``attn_features`` is not None, ``attn_features[:, time, :]``.

    :param time: Day number; index into the per-step inputs and the buffers.
    :param prev_output: Output (prediction) from the previous step.
    :param prev_state: RNN state tensor from the previous step.
    :param array_targets: Predictions buffer; one element is written per step.
    :param array_outputs: Raw RNN outputs buffer (for regularization losses);
        written only when ``return_raw_outputs`` is truthy.
    :return: ``(time + 1, projected_output, state, array_targets, array_outputs)``.
    """
    pieces = [prev_output, inputs_by_time[time]]
    if attn_features is not None:
        pieces.append(attn_features[:, time, :])
    step_input = tf.concat(pieces, axis=1)

    rnn_output, new_state = cell(step_input, prev_state)
    prediction = project_output(rnn_output)

    if return_raw_outputs:
        array_outputs = array_outputs.write(time, rnn_output)
    array_targets = array_targets.write(time, prediction)

    return time + 1, prediction, new_state, array_targets, array_outputs
def loop_fn(time, prev_output, prev_state, array_targets: tf.TensorArray, array_outputs: tf.TensorArray):
    """Decoder step that feeds the previous prediction back into ``cell``.

    Both the raw cell output and its projection through ``project_fn`` are
    written at index ``time``.

    :return: ``(time + 1, projected_output, state, array_targets, array_outputs)``.
    """
    rnn_out, new_state = cell(prev_output, prev_state)
    projected = project_fn(rnn_out)
    updated_outputs = array_outputs.write(time, rnn_out)
    updated_targets = array_targets.write(time, projected)
    return time + 1, projected, new_state, updated_targets, updated_outputs
def loop_fn_train(time, prev_output, prev_state, array_targets: tf.TensorArray, array_outputs: tf.TensorArray):
    """Training-time decoder step using ground-truth inputs (teacher forcing).

    The cell input is ``previous_y[:, time]`` reshaped to ``(-1, 3)``.
    ``prev_output`` is accepted for loop-signature compatibility but not used.

    :return: ``(time + 1, projected_output, state, array_targets, array_outputs)``.
    """
    truth_input = tf.reshape(previous_y[:, time], (-1, 3))
    rnn_out, new_state = decoder_cell(truth_input, prev_state)
    projected = project_fn(rnn_out)
    updated_outputs = array_outputs.write(time, rnn_out)
    updated_targets = array_targets.write(time, projected)
    return time + 1, projected, new_state, updated_targets, updated_outputs
def get_energies(self, y: tf.Tensor, weights_in_time: tf.TensorArray):
    """Compute coverage-aware attention logits.

    Coverage is the sum of all attention weights written to
    ``weights_in_time`` so far (0.0 if nothing has been written). It is divided
    by ``self.fertility`` and multiplied by ``self.attention_mask``. The
    logits are ``sum(similarity_bias_vector * tanh(hidden_features + y +
    coverage_weights * coverage[..., None, None]))`` over axes 2 and 3.

    :param y: Query term added inside the tanh.
    :param weights_in_time: Attention weights from previous steps.
    :return: Logits tensor reduced over axes [2, 3].
    """
    def _summed_history():
        return tf.reduce_sum(weights_in_time.stack(), axis=0)

    # stack() on an empty TensorArray fails, so fall back to 0.0 first.
    has_history = tf.greater(weights_in_time.size(), 0)
    weight_sum = tf.cond(has_history, _summed_history, lambda: 0.0)

    coverage = weight_sum / self.fertility * self.attention_mask
    coverage_expanded = tf.expand_dims(tf.expand_dims(coverage, -1), -1)
    activation = tf.tanh(
        self.hidden_features + y + self.coverage_weights * coverage_expanded)
    return tf.reduce_sum(self.similarity_bias_vector * activation, [2, 3])
def loop_fn_inference(time, prev_output, prev_state, array_targets: tf.TensorArray, array_outputs: tf.TensorArray):
    """Inference-time decoder step.

    The raw projection is stored in ``array_targets``. Its softmax is what
    gets fed back as the next step's ``prev_output``.

    :return: ``(time + 1, softmax(projected_output), state, array_targets, array_outputs)``.
    """
    rnn_out, new_state = decoder_cell(prev_output, prev_state)
    projected = project_fn(rnn_out)
    updated_outputs = array_outputs.write(time, rnn_out)
    updated_targets = array_targets.write(time, projected)
    next_feed = tf.nn.softmax(projected)
    return time + 1, next_feed, new_state, updated_targets, updated_outputs
def loop_fn(time, prev_output, prev_state, array_targets: tf.TensorArray, array_outputs: tf.TensorArray):
    """Main decoder loop step.

    The RNN input is the previous prediction concatenated with
    ``inputs_by_time[time]`` along axis 1.

    Args:
        time: time series step number; write index into ``array_targets``.
        prev_output: Output (prediction) from the previous step.
        prev_state: RNN state tensor from the previous step.
        array_targets: Predictions; each step writes one new value.
        array_outputs: Raw RNN outputs buffer. This step does NOT write to it;
            it is passed through unchanged.

    Returns:
        (time + 1, projected_output, state, array_targets, array_outputs)
        projected_output: the prediction for this step.
        state: the updated RNN state.
        array_targets: the targets array with index ``time`` written.
        array_outputs: the input ``array_outputs``, unchanged.
    """
    step_input = tf.concat([prev_output, inputs_by_time[time]], axis=1)
    rnn_output, new_state = cell(step_input, prev_state)
    prediction = project_output(rnn_output)
    array_targets = array_targets.write(time, prediction)
    return time + 1, prediction, new_state, array_targets, array_outputs
def _make_pairs(index: int, x: tf.TensorArray) -> Tuple[int, tf.TensorArray]:
    """Make (target, context) pairs for a given target word.

    Context positions are taken from up to ``window_size`` words on each side
    of ``index``, clipped to ``[0, n)``. When ``randomly_offset`` is set, a
    random ``shift`` in ``[0, window_size)`` widens the left side and narrows
    the right side by the same amount.

    Args:
        index: The index of the target word in the sequence tensor.
        x: Collection holding the output values.

    Returns:
        The next iteration index, and ``x`` with a ``(num_contexts, 2)`` tensor
        of ``[target, context]`` rows written at ``index``.
    """
    shift = (tf.random.uniform((), maxval=window_size, dtype=tf.int32)
             if randomly_offset else 0)

    left_start = tf.maximum(0, index - window_size + shift)
    right_stop = tf.minimum(n, index + 1 + window_size - shift)
    context_positions = tf.concat(
        [tf.range(left_start, index), tf.range(index + 1, right_stop)], axis=0)

    context_ids = tf.gather(sequence, context_positions)
    target_ids = tf.fill(tf.shape(context_ids), sequence[index])
    pairs = tf.stack([target_ids, context_ids], axis=1)
    return index + 1, x.write(index, pairs)
def pad_prediction_tfarray(tfarray: tf.TensorArray, blank: int or tf.Tensor) -> tf.TensorArray:
    """Right-pad every 1-D element of ``tfarray`` to the longest element's length.

    Padding uses ``blank`` as the constant value. The target length comes from
    ``find_max_length_prediction_tfarray``.

    NOTE(review): the annotation ``int or tf.Tensor`` evaluates to just
    ``int`` at runtime; ``Union[int, tf.Tensor]`` is presumably what was meant.
    NOTE(review): every element is read twice (once to find the max length,
    once here). This presumably requires ``tfarray`` to be created with
    ``clear_after_read=False`` -- confirm at the call sites.

    :param tfarray: TensorArray of rank-1 predictions.
    :param blank: Fill value for the padding.
    :return: The TensorArray returned by the while loop, with padded elements.
    """
    with tf.name_scope("pad_prediction_tfarray"):
        start = tf.constant(0, dtype=tf.int32)
        num_items = tfarray.size()
        target_len = find_max_length_prediction_tfarray(tfarray)

        def _not_done(idx, _):
            return tf.less(idx, num_items)

        def _pad_one(idx, ta):
            item = ta.read(idx)
            missing = target_len - tf.shape(item)[0]
            padded = tf.pad(item, paddings=[[0, missing]], mode="CONSTANT",
                            constant_values=blank)
            return idx + 1, ta.write(idx, padded)

        _, padded_array = tf.while_loop(_not_done, _pad_one,
                                        loop_vars=[start, tfarray],
                                        swap_memory=False)
        return padded_array
def condition(time, all_outputs: tf.TensorArray, inputs, states):
    """while_loop condition: keep decoding until every answer has ended or time runs out.

    An answer counts as ended when its argmax label sequence contains
    ``ANSWER_MAX``. Decoding also stops once ``time`` reaches ``ANSWER_MAX``.

    NOTE(review): ``ANSWER_MAX`` is used both as the end-token id and as the
    maximum length. A sibling condition uses a separate END_WORD_INDEX. Confirm
    that this is intended.
    """
    def has_end_word(t):
        return tf.reduce_any(tf.equal(t, ANSWER_MAX))

    def check_outputs_ends():
        labels = tf.arg_max(all_outputs.stack(), 2)
        labels = tf.Print(labels, [labels], "Output Labels: ")
        # stack() puts the write index (time) on axis 0. Transpose to
        # batch-major so map_fn, which maps over axis 0, visits one answer
        # at a time.
        per_answer = tf.transpose(labels, (1, 0))
        return tf.reduce_all(tf.map_fn(has_end_word, per_answer, dtype=tf.bool))

    # stack() on an empty TensorArray raises, so guard it with tf.cond.
    nothing_written = tf.equal(all_outputs.size(), 0)
    all_ends = tf.cond(nothing_written,
                       lambda: tf.constant(False, tf.bool),
                       check_outputs_ends)
    return tf.logical_and(tf.logical_not(all_ends), tf.less(time, ANSWER_MAX))
def body_fn(time, all_outputs: tf.TensorArray, inputs, state: LSTMStateTuple):
    """One attention-decoder step inside the ``body_fn`` variable scope.

    The steps are:
    1. Pick this step's input embedding.
    2. Attend over ``features`` with ``state.h``.
    3. Optionally gate the context with a sigmoid "selector".
    4. Run ``rnn_cell`` on ``[input, context]`` and decode logits.
    5. Write the logits at ``time``.

    With teacher forcing (``use_generated_inputs`` false), ``inputs`` is the
    full sequence and is passed through unchanged as the next loop value.
    Otherwise the next input is the embedding of a token sampled from the
    logits.

    :return: ``(time + 1, all_outputs, next_inputs, nxt_state)``.
    """
    with tf.variable_scope("body_fn"):
        if not use_generated_inputs:
            next_inputs = inputs
            inputs = inputs[:, time, :]

        # context: (batch, feature_size); attention weights are not used yet.
        # todo: alpha regularization
        context, _alpha = self._attention_layer(features, state.h)

        if self.selector:
            with tf.variable_scope("selector"):
                beta = fully_connected(state.h,
                                       num_outputs=1,
                                       activation_fn=tf.nn.sigmoid,
                                       weights_initializer=self.weight_initializer,
                                       biases_initializer=self.const_initializer)
                context = tf.multiply(beta, context, name="selected_context")

        decoder_input = tf.concat(values=[inputs, context], axis=1, name="decoder_input")
        output, nxt_state = rnn_cell(decoder_input, state=state)
        logits = self._decode_rnn_outputs(output, context, inputs, dropout=dropout)
        all_outputs = all_outputs.write(time, logits)

        if use_generated_inputs:
            # todo: with hard attention the policy gradient should use the
            # sampled token's log-likelihood, so the sample must be cached.
            next_inputs = self._word_embedding(self._sampler(logits), reuse=True)

        return time + 1, all_outputs, next_inputs, nxt_state
def __stack_and_pad(tensor_array: tf.TensorArray, length: int):
    """Stack ``tensor_array`` and append zero rows on axis 0 up to ``length`` rows.

    The zeros match the array's dtype and the trailing dimensions of the
    stacked tensor.
    """
    stacked = tensor_array.stack()
    dims = tf.shape(stacked)
    pad_dims = tf.concat([[length - dims[0]], dims[1:]], axis=0)
    padding = tf.zeros(pad_dims, dtype=tensor_array.dtype, name="tensor_array_padding")
    return tf.concat([stacked, padding], axis=0)
def _peel_element_from_iblt(
    self, iblt: tf.Tensor, iblt_values: tf.Tensor,
    out_strings: tf.TensorArray, out_counts: tf.TensorArray,
    out_tensor_values: tf.TensorArray
) -> Tuple[tf.Tensor, tf.Tensor, tf.TensorArray, tf.TensorArray, tf.TensorArray]:
    """Peels an element from IBLT and adds new peelable elements to queue.

    Dequeues one ``(repetition, index)`` position from ``self.q``. The entry
    there is decoded and removed from ``iblt``, and its value is decoded and
    removed from ``iblt_values``. If the decoded string is non-empty, the
    string, its count and its value are appended to the three output
    TensorArrays at the same new index. Afterwards, every repetition's hash
    position for this element that ``self._is_peelable`` accepts is enqueued.

    Returns:
      The updated ``(iblt, iblt_values, out_strings, out_counts,
      out_tensor_values)``.
    """
    repetition, index = self.q.dequeue()
    iblt, hash_indices, data_string, count = self._decode_and_remove(
        iblt, repetition, index)
    tensor_value = self._decode_value(iblt_values, repetition, index)
    iblt_values = self._remove_value(iblt_values, hash_indices, tensor_value)
    # Only record non-empty strings; the three outputs share one index so
    # they stay aligned.
    if tf.strings.length(data_string) > 0:
        index = out_counts.size()
        out_counts = out_counts.write(index, count)
        out_strings = out_strings.write(index, data_string)
        out_tensor_values = out_tensor_values.write(index, tensor_value)
    # Removing this element may have made other positions peelable.
    for r in tf.range(self.repetitions, dtype=self._dtype):
        if self._is_peelable(iblt, r, hash_indices[r]):
            self.q.enqueue((r, hash_indices[r]))
    return iblt, iblt_values, out_strings, out_counts, out_tensor_values
def body(i, arr: tf.TensorArray):
    """Length-masked attention for element ``i``, normalised by key norms.

    Dot-product energies between ``a_ta[i]`` and the first
    ``b_lengths_ta[i]`` rows of ``b_ta[i]`` are divided by those rows' L2
    norms and softmax-normalised. The key axis is then zero-padded to ``bl``
    columns.

    :return: ``(i + 1, arr)`` with the padded attention written at ``i``.
    """
    queries = a_ta.read(i)
    keys_full = b_ta.read(i)
    key_len = b_lengths_ta.read(i)

    keys = keys_full[:key_len, :]                          # (bl_i, d)
    key_norms = tf.norm(keys, axis=-1, ord=2)
    scores = tf.tensordot(queries, keys, axes=[(1, ), (1, )])  # (al, bl_i)
    scores = scores / tf.expand_dims(key_norms, 0)

    weights = tf.nn.softmax(scores, axis=-1)
    padded = tf.pad(weights, [[0, 0], [0, bl - key_len]])
    return i + tf.constant(1, dtype=tf.int32), arr.write(i, padded)
def initialise_tensor_arrays(height, width, num_diag, units, directions=4):
    """Create the index and activation TensorArrays for diagonal iteration.

    :param height: Grid height; ``height * width`` activation slots are allocated.
    :param width: Grid width.
    :param num_diag: Number of diagonals; size of both index arrays.
    :param units: Number of units per activation.
    :param directions: Size of the last axis of each activation element.
        It was previously hard-coded to 4; the default keeps that behaviour.
    :return: ``(linear_inds_ta, multi_inds_ta, activations_ta)``.
    """
    # Index arrays hold variable-length diagonals (infer_shape=False) and
    # are read more than once (clear_after_read=False).
    linear_inds_ta = TensorArray(dtype=tf.int32,
                                 size=num_diag,
                                 element_shape=tf.TensorShape([None]),
                                 clear_after_read=False,
                                 name='linear_inds',
                                 infer_shape=False)
    multi_inds_ta = TensorArray(dtype=tf.int32,
                                size=num_diag,
                                element_shape=tf.TensorShape([None, 2]),
                                clear_after_read=False,
                                name='mult_inds',
                                infer_shape=False)
    activations_ta = TensorArray(dtype=tf.float32,
                                 size=height * width,
                                 element_shape=tf.TensorShape([None, units, directions]))
    return linear_inds_ta, multi_inds_ta, activations_ta
def body(step, batch_states_ta: tf.TensorArray, batch_outputs_ta: tf.TensorArray, batch_outputs_counts_ta: tf.TensorArray, batch_step_counts_ta: tf.TensorArray):
    """Decode batch element ``step`` and write its results into the four arrays.

    The element's initial hidden vector comes from its row of
    ``initial_hidden_vector_input`` passed through fully-connected layers.
    Its truth data (step count, padded outputs, output counts) is gathered
    at ``step`` and handed to ``self.__step_while_loop``, whose results are
    written at index ``step``.

    :return: ``(step + 1, batch_states_ta, batch_outputs_ta,
        batch_outputs_counts_ta, batch_step_counts_ta)`` after the writes.
    """
    with tf.variable_scope("initial_hidden_vector"):
        initial_input = tf.gather(
            initial_hidden_vector_input, step,
            name="current_initial_hidden_vector_input")
        hidden_vector = self.__create_fully_connected_layers(
            initial_input, [self.hidden_vector_size])

    with tf.variable_scope("step_while_loop"):
        step_count = tf.gather(truth_padded_data.step_counts, step,
                               name="current_step_count")
        outputs_padded = tf.gather(truth_padded_data.outputs_padded, step,
                                   name="current_outputs_padded")
        outputs_counts = tf.gather(truth_padded_data.outputs_counts, step,
                                   name="current_outputs_counts")
        states, outputs, outputs_counts, step_count = self.__step_while_loop(
            step_count, outputs_padded, outputs_counts, hidden_vector)

    return (
        step + 1,
        batch_states_ta.write(step, states, "write_batch_states"),
        batch_outputs_ta.write(step, outputs, "write_batch_outputs"),
        batch_outputs_counts_ta.write(step, outputs_counts, "write_batch_outputs_counts"),
        batch_step_counts_ta.write(step, step_count, "write_step_counts"),
    )
def loop_fn(time, prev_output, prev_state, array_targets: tf.TensorArray, array_outputs: tf.TensorArray):
    """Main decoder loop step.

    Shape annotations below were recorded by the original author for one
    configuration; treat them as examples, not guarantees.

    :param time: Day number; index into per-step inputs and the buffers.
    :param prev_output: Output (prediction) from the previous step, e.g. (?, 1).
    :param prev_state: RNN state tensor from the previous step, e.g. (1, ?, 267).
    :param array_targets: Predictions buffer; one element is written per step.
    :param array_outputs: Raw decoder outputs buffer (for regularization losses);
        written only when ``return_raw_outputs`` is truthy.
    :return: ``(time + 1, projected_output, state, array_targets, array_outputs)``.
    """
    # Per-step features, e.g. (?, 23).
    step_features = inputs_by_time[time]

    # Build this step's input from the previous prediction, the features
    # and, if present, the attention vector for this step (e.g. (?, 64)):
    # (?,1) + (?,23) + (?,64) -> (?,88); without attention -> (?,24).
    pieces = [prev_output, step_features]
    if attn_features is not None:
        pieces.append(attn_features[:, time, :])
    step_input = tf.concat(pieces, axis=1)

    with tf.variable_scope('decoder_12'):
        rnn_output, new_state = cell(step_input, prev_state)

    # Project the RNN output to a prediction, e.g. (?, 1).
    prediction = project_output(rnn_output)

    if return_raw_outputs:
        array_outputs = array_outputs.write(time, rnn_output)
    array_targets = array_targets.write(time, prediction)

    return time + 1, prediction, new_state, array_targets, array_outputs
def _serve_tfrecord(sequence_example_protos):
    """Parse serialized SequenceExample protos, run the model and mask padding.

    Each proto is parsed with ``tfrecord_parse_fn`` inside a ``tf.while_loop``.
    Every feature named in ``inputs`` is cast to float32 and collected into a
    TensorArray, then stacked into a batch tensor. Prediction entries where
    ``features_dict["mask"] == 0`` are replaced with ``-inf``.

    :param sequence_example_protos: 1-D batch of serialized protos.
    :return: The model's predictions dict with padded positions masked.
    """
    input_size = tf.shape(sequence_example_protos)[0]
    features_dict = {
        feature: TensorArray(dtype=tf.float32, size=input_size)
        for feature in inputs
    }
    start = tf.constant(0)

    def loop_condition(i, sequence_example_protos, features_dict):
        return tf.less(i, input_size)

    def loop_body(i, sequence_example_protos, features_dict):
        # TODO: switch tfrecord_parse_fn from parse_single_sequence_example to
        # parse_sequence_example so the whole batch is parsed in one call.
        parsed, _labels = tfrecord_parse_fn(sequence_example_protos[i])
        for name, value in parsed.items():
            features_dict[name] = features_dict[name].write(
                i, tf.cast(value, tf.float32))
        return i + 1, sequence_example_protos, features_dict

    _, _, features_dict = tf.while_loop(
        cond=loop_condition,
        body=loop_body,
        loop_vars=[start, sequence_example_protos, features_dict],
    )

    features_dict = {name: ta.stack() for name, ta in features_dict.items()}
    predictions = self.model(inputs=features_dict)

    is_padding = tf.equal(features_dict["mask"], 0)
    for key in predictions:
        predictions[key] = tf.where(is_padding, tf.constant(-np.inf), predictions[key])
    return predictions
def skip_gram_pairs_from_word(
        word_index: int, skip_grams_array: tf.TensorArray) -> Tuple[int, tf.TensorArray]:
    """
    Generate positive skip-gram (target, context) pairs for one word position.

    A window size is drawn uniformly from ``[1, max_window_size]``. Every
    word within that distance of ``word_index`` (clipped to the sequence
    bounds, excluding the word itself) becomes a context.

    Parameters
    ----------
    word_index : int
        Position of the target word in ``word_indices``.
    skip_grams_array : tf.TensorArray
        TensorArray collecting generated target/context pairs.

    Returns
    -------
    next_word_index : int
        ``word_index + 1``.
    next_skip_grams_array : tf.TensorArray
        ``skip_grams_array`` with a ``(num_contexts, 2)`` tensor written at
        ``word_index``.
    """
    target = word_indices[word_index]
    window = tf.random.uniform([], minval=1, maxval=max_window_size + 1,
                               dtype=tf.int32)

    left_start = tf.maximum(word_index - window, 0)
    right_stop = tf.minimum(word_index + 1 + window, tf.size(word_indices))
    context_positions = tf.concat(
        [tf.range(left_start, word_index), tf.range(word_index + 1, right_stop)],
        axis=0)

    contexts = tf.gather(word_indices, context_positions)
    targets = tf.fill(tf.shape(contexts), target)
    pairs = tf.stack([targets, contexts], axis=1)
    return word_index + 1, skip_grams_array.write(word_index, pairs)
def find_max_length_prediction_tfarray(tfarray: tf.TensorArray) -> tf.Tensor:
    """Return the largest axis-0 size among the elements of ``tfarray`` (int32).

    Returns 0 when the array is empty, because the loop body never runs.

    NOTE(review): elements are read here and may be read again by the caller.
    That presumably requires ``clear_after_read=False`` on the array -- confirm.
    """
    with tf.name_scope("find_max_length_prediction_tfarray"):
        start = tf.constant(0, dtype=tf.int32)
        count = tfarray.size()
        initial_best = tf.constant(0, dtype=tf.int32)

        def _not_done(idx, _):
            return tf.less(idx, count)

        def _track_max(idx, best):
            current = tf.shape(tfarray.read(idx))[0]
            return idx + 1, tf.maximum(current, best)

        _, best = tf.while_loop(_not_done, _track_max,
                                loop_vars=[start, initial_best],
                                swap_memory=False)
        return best
def _serve_tfrecord(protos):
    """Parse serialized protos, run ``model`` and optionally post-process.

    Each proto is parsed with ``tfrecord_parse_fn`` inside a ``tf.while_loop``.
    Every feature named in ``inputs`` is collected into a TensorArray of dtype
    ``dtype_map[feature]`` and then stacked into a batch tensor.

    :param protos: 1-D batch of serialized protos.
    :return: ``model(inputs=features)``, passed through ``postprocessing_fn``
        together with the features when that hook is set.
    """
    input_size = tf.shape(protos)[0]
    features_dict = {
        feature: TensorArray(dtype=dtype_map[feature], size=input_size)
        for feature in inputs
    }
    start = tf.constant(0)

    def loop_condition(i, protos, features_dict):
        return tf.less(i, input_size)

    def loop_body(i, protos, features_dict):
        parsed, _labels = tfrecord_parse_fn(protos[i])
        for name, value in parsed.items():
            features_dict[name] = features_dict[name].write(i, value)
        return i + 1, protos, features_dict

    _, _, features_dict = tf.while_loop(
        cond=loop_condition,
        body=loop_body,
        loop_vars=[start, protos, features_dict],
    )

    features_dict = {name: ta.stack() for name, ta in features_dict.items()}
    predictions = model(inputs=features_dict)

    if postprocessing_fn:
        predictions = postprocessing_fn(predictions, features_dict)
    return predictions
def condition(time, all_outputs: tf.TensorArray, caps, states):
    """while_loop condition for caption generation.

    Continue while ``time < max_length`` and not every generated sequence
    contains ``END_WORD_INDEX`` in its argmax word ids.
    """
    def has_end_word(t):
        return tf.reduce_any(tf.equal(t, END_WORD_INDEX))

    def check_all_ends():
        # stack() puts the write index (presumably time) on axis 0; transpose
        # so map_fn iterates over sequences instead.
        word_ids = tf.argmax(all_outputs.stack(), axis=2)
        per_sequence = tf.transpose(word_ids, [1, 0])
        ended = tf.map_fn(has_end_word, per_sequence, dtype=tf.bool)
        return tf.reduce_all(ended)

    with tf.variable_scope("cond_fn"):
        # Avoid stack() on an empty TensorArray during the first iteration.
        nothing_written = tf.equal(all_outputs.size(), 0)
        gen_ends = tf.cond(nothing_written,
                           lambda: tf.constant(False, tf.bool),
                           check_all_ends)
        return tf.logical_and(tf.logical_not(gen_ends), tf.less(time, max_length))
def loop_body(i, array: tf.TensorArray):
    """Write ``self.single_convolution(inputs, i)`` at index ``i`` and advance ``i``."""
    step_output = self.single_convolution(inputs, i)
    return i + 1, array.write(i, step_output)
def body(i, arr: tf.TensorArray):
    """Softmax the first ``b_lengths_ta[i]`` columns of ``energy_ta[i]`` and zero-pad to ``bl``.

    :return: ``(i + 1, arr)`` with the padded attention written at ``i``.
    """
    scores = energy_ta.read(i)
    valid_len = b_lengths_ta.read(i)
    weights = tf.nn.softmax(scores[:, :valid_len], axis=-1)
    padded = tf.pad(weights, [[0, 0], [0, bl - valid_len]])
    return i + tf.constant(1, dtype=tf.int32), arr.write(i, padded)
def fast_MD_dynamic(input_data, units):
    """Run a diagonal LSTM over a square grid, one diagonal per loop iteration.

    Args:
        input_data: tensor whose static shape unpacks as
            (batch, height, width, input_size, directions). ``height`` must
            equal ``width`` (asserted).
        units: number of units in the cell.

    Returns:
        Activations reshaped from (height*width, batch, units, directions) to
        (batch, height, width, units, directions). This assumes the indices
        from ``get_single_diagonal_indices`` are row-major -- confirm.

    TODO calculate indices once and reuse
    """
    _, height, width, inp_size, directions = input_data.get_shape().as_list()
    batch_size = tf.shape(input_data)[0]
    # make input height, width, batch, inp, direction
    input_data_transposed = tf.transpose(input_data, (1, 2, 0, 3, 4))
    # needs to be square for current implementation
    assert height == width

    # construct diagonal lstm cell
    # cell(inp, acti, cell) = acti, cell
    cell = diagonal_lstm(units, inp_size)

    # initial values
    num_diag = 2 * (height - 1) + 1
    zeros = tf.stack([batch_size, 2, units, directions])
    current_activations = tf.fill(zeros, 0.0)
    initial_state = tf.fill([batch_size, 1, units, directions], 0.0)
    current_states = tf.tile(initial_state, [1, 2, 1, 1])
    diagonal = tf.constant(0)

    # will ultimately store our activations
    # when stacked will be h*w, b, u, d.
    # The element shape uses `directions` rather than a literal 4 so that it
    # agrees with the scattered values and the final set_shape below.
    activations_ta = TensorArray(dtype=tf.float32,
                                 size=height * width,
                                 element_shape=tf.TensorShape([None, units, directions]))

    def pad_with_initial(tensor):
        """pads for edge activations/cells"""
        added_bot = tf.concat([tensor, initial_state], axis=1)
        added_all = tf.concat([initial_state, added_bot], axis=1)
        return added_all

    def body(activations_ta, current_activations, current_states, diagonal):
        """
        process diagonal 0, 1, 2, ...
        """
        # Get the diagonal values of the input
        # b x d x inp_size x direction
        input_diagonal = get_diagonal_values(diagonal, input_data_transposed)

        # need to pad acti/cell except in first iteration; padding is only
        # added while diagonal < height.
        not_first_acti = tf.cond(
            diagonal < height,
            lambda: tf.pad(current_activations, [[0, 0], [1, 1], [0, 0], [0, 0]]),
            lambda: current_activations)
        current_activations = tf.cond(tf.equal(diagonal, 0),
                                      lambda: current_activations,
                                      lambda: not_first_acti)

        not_first_cell = tf.cond(diagonal < height,
                                 lambda: pad_with_initial(current_states),
                                 lambda: current_states)
        current_states = tf.cond(tf.equal(diagonal, 0),
                                 lambda: current_states,
                                 lambda: not_first_cell)

        # work out new activations
        current_activations, current_states = cell(input_diagonal,
                                                   current_activations,
                                                   current_states)

        # batch x diagonal x unit x direction
        current_states.set_shape([None, None, units, directions])
        current_activations.set_shape([None, None, units, directions])

        # get indices to place into activations
        indices = get_single_diagonal_indices(height, width, diagonal)

        # we transpose so that correct values from current activations go in the correct place
        # scatter works by using the first index
        # thus activations contains
        # batch x units x direction
        activations_ta = activations_ta.scatter(
            indices, tf.transpose(current_activations, (1, 0, 2, 3)))

        diagonal += 1
        return activations_ta, current_activations, current_states, diagonal

    def cond(activations_ta, current_activations, current_states, diagonal):
        return diagonal < num_diag

    acti_shape = tf.TensorShape([None, None, units, directions])
    cell_shape = tf.TensorShape([None, None, units, directions])
    diag_shape = tf.TensorShape([])
    ta_shape = tf.TensorShape(None)

    returned = tf.while_loop(
        cond=cond,
        body=body,
        loop_vars=[
            activations_ta, current_activations, current_states, diagonal
        ],
        name='looooop',
        shape_invariants=[ta_shape, acti_shape, cell_shape, diag_shape],
        swap_memory=True)

    activations = returned[0].stack()
    activations.set_shape([height * width, None, units, directions])
    # (h*w, b, u, d) -> (b, h*w, u, d) -> `height` chunks of (b, w, u, d)
    # -> (b, h, w, u, d)
    activations = tf.transpose(activations, (1, 0, 2, 3))
    activations = tf.split(activations, num_or_size_splits=height, axis=1)
    activations = tf.stack(activations, 1)
    return activations
def body(step: int, stack_1, stack_2, states_ta: tf.TensorArray, outputs_ta: tf.TensorArray, outputs_counts_ta: tf.TensorArray, return_value: tf.Tensor):
    """One iteration of the state-stack decoder.

    The iteration does the following:
    - Pops the top ``(state, hidden_vector)`` from the stack.
    - Summarises the remaining hidden vectors.
    - Dispatches on ``state`` to ``create_guess_layers``.
    - Zero-pads the resulting choice to ``self.max_outputs``.
    - Pushes/updates the stack via ``self.__update_state_stack``.
    - Writes the state, the padded choice and the state's output count at ``step``.

    When training, the stack is updated from the ground-truth choice
    (``truth_outputs_padded[step]``) instead of the predicted one.

    :return: ``(step + 1, stack_1, stack_2, states_ta, outputs_ta,
        outputs_counts_ta, return_value)`` -- the stack tuple is flattened
        back into two loop variables.
    """
    # Rebuild `stack` tuple
    # TODO: Find way to avoid this by putting tuples into `tf.while_loop` arguments
    stack = stack_1, stack_2

    # Get the state and hidden vector we have to deal with for this iteration
    state, hidden_vector, stack = state_stack.pop(stack)

    # Get the summary for all hidden vectors excluding this one
    hidden_vector_summary = state_stack.get_hidden_vector_summary(stack)

    all_states = self.object_type.get_all_states()

    # Get the number of outputs for padding
    # TODO: Doing this twice, once here, and once when calling create_guess_layers. Fix this
    num_outputs = tf.case(pred_fn_pairs=[
        (tf.equal(state, i),
         lambda i=i: tf.constant(self.object_type.get_all_states()[i].num_outputs))
        for i in range(len(all_states))
    ], default=lambda: tf.constant(0))

    # Call `create_guess_layers(...)` depending on what state we're in.
    # The default branch yields 0/0 = NaN as the choice, presumably as a
    # poison value for an unknown state -- confirm.
    next_hidden_vector, current_choice = tf.case(
        pred_fn_pairs=[
            (tf.equal(state, i),
             lambda i=i: create_guess_layers(hidden_vector_summary, return_value, hidden_vector, i))
            for i in range(len(all_states))
        ],
        default=lambda: (hidden_vector,
                         tf.constant(0, dtype=tf.float32) / tf.constant(0, tf.float32)))

    # Zero pad the current choice
    current_choice = tf.concat([
        current_choice,
        tf.zeros([self.max_outputs - tf.shape(current_choice)[0]])
    ], axis=0, name="current_choice_zero_padded")

    # Reshape the hidden vector so we know what size it is
    next_hidden_vector = tf.reshape(next_hidden_vector, [self.hidden_vector_size],
                                    name="next_hidden_vector_reshaped")

    if self.training:
        # If we're training, the choice we send to the update_state_stack_fn should be determined by the truth
        stack_update_choice = tf.gather(truth_outputs_padded, step, name="choice_from_input")
    else:
        # Otherwise, the choice should be what we outputted
        stack_update_choice = current_choice

    # Update the state stack
    stack, return_value = self.__update_state_stack(
        state, stack, next_hidden_vector, stack_update_choice)

    # `(*stack)` is a SyntaxError; flatten the stack tuple with PEP 448
    # unpacking inside a tuple display instead.
    return (
        step + 1,
        *stack,
        states_ta.write(step, state, "write_state"),
        outputs_ta.write(step, current_choice, "write_outputs"),
        outputs_counts_ta.write(step, num_outputs, "write_outputs_count"),
        return_value,
    )
def fallback_stack_tensor_array(self, arr: tf.TensorArray):
    """Stack every element of ``arr`` by gathering indices ``0 .. size-1``.

    Equivalent in result to ``arr.stack()``, but implemented with ``gather``.
    """
    every_index = tf.range(arr.size())
    return arr.gather(every_index)