class OutputSequenceTransformation(object):
    """
    Transforms a representation vector into a sequence of outputs
    """
    def __init__(self, input_width, state_size, num_words):
        self._input_width = input_width
        self._state_size = state_size
        self._num_words = num_words

        self._seq_gru = BaseGRULayer(input_width, state_size, name="output_seq_gru")
        self._transform_stack = LayerStack(state_size, num_words, activation=T.nnet.softmax, name="output_seq_transf")

    @property
    def params(self):
        return self._seq_gru.params + self._transform_stack.params

    def process(self, input_vector, seq_len):
        """
        Convert an input vector into a sequence of categorical distributions

        Params:
            input_vector: Vector of shape (n_batch, input_width)
            seq_len: How many outputs to produce

        Returns: Sequence distribution of shape (n_batch, seq_len, num_words)
        """
        n_batch = input_vector.shape[0]
        outputs_info = [self._seq_gru.initial_state(n_batch)]
        scan_step = lambda state, ipt: self._seq_gru.step(ipt, state)
        all_out, _ = theano.scan(scan_step, non_sequences=[input_vector], n_steps=seq_len, outputs_info=outputs_info)

        # all_out is of shape (seq_len, n_batch, state_size). Squash and apply layer
        flat_out = all_out.reshape([-1, self._state_size])
        flat_final = self._transform_stack.process(flat_out)
        final = flat_final.reshape([seq_len, n_batch, self._num_words]).dimshuffle([1, 0, 2])

        return final

    def snap_to_best(self, answer):
        """
        Convert output of process to the "best" answer, i.e. the answer with highest probability.
        """
        return categorical_best(answer)
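# A minimal usage sketch for OutputSequenceTransformation (not part of the original
# module). The layer sizes and the compiled function below are illustrative
# placeholders, and the sketch assumes the module-level imports (theano, T) used by
# the classes in this file.
def _example_output_sequence_usage():
    transform = OutputSequenceTransformation(input_width=20, state_size=30, num_words=15)
    input_vector = T.matrix("input_vector")  # symbolic (n_batch, input_width)
    # Build a distribution over 5 output words per batch element, then snap each
    # distribution to its highest-probability answer
    distrib = transform.process(input_vector, seq_len=5)
    best = transform.snap_to_best(distrib)
    return theano.function([input_vector], best)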
class NewNodesVoteTransformation(object):
    """
    Transforms a graph state by adding nodes, conditioned on an input vector
    """
    def __init__(self, input_width, proposal_width, graph_spec):
        """
        Params:
            input_width: Integer giving size of input
            proposal_width: Size of internal proposal
            graph_spec: Instance of GraphStateSpec giving graph spec
        """
        self._input_width = input_width
        self._graph_spec = graph_spec
        self._proposal_width = proposal_width

        self._proposer_gru = BaseGRULayer(input_width, proposal_width, name="newnodes_proposer")
        self._proposer_stack = LayerStack(proposal_width, 1 + graph_spec.num_node_ids, [proposal_width], bias_shift=3.0, name="newnodes_proposer_post")

        isize = 2 * graph_spec.num_node_ids + graph_spec.node_state_size
        self._vote_stack = LayerStack(isize, 1, [isize], activation=T.nnet.sigmoid, bias_shift=-3.0, name="newnodes_vote")

    @property
    def params(self):
        return self._proposer_gru.params + self._proposer_stack.params + self._vote_stack.params

    @property
    def num_dropout_masks(self):
        return self._proposer_gru.num_dropout_masks

    def get_dropout_masks(self, srng, keep_frac):
        return self._proposer_gru.get_dropout_masks(srng, keep_frac)

    def get_candidates(self, gstate, input_vector, max_candidates, dropout_masks=None):
        """
        Get the current candidate new nodes. This is accomplished as follows:
          1. The proposer network, conditioned on the input vector, proposes multiple
             candidate nodes, along with a confidence
          2. Every existing node, conditioned on its own state and the candidate, votes
             on whether or not to accept this node
          3. A new node is created for each candidate node, with an existence strength
             given by confidence * [product of all votes], and an initial state as proposed
        This method directly returns these new nodes for comparison

        Params:
            gstate: A GraphState giving the current state
            input_vector: A tensor of the form (n_batch, input_width)
            max_candidates: Integer, limit on the number of candidates to produce

        Returns:
            new_strengths: A tensor of the form (n_batch, new_node_idx)
            new_ids: A tensor of the form (n_batch, new_node_idx, num_node_ids)
        """
        n_batch = gstate.n_batch
        n_nodes = gstate.n_nodes

        outputs_info = [self._proposer_gru.initial_state(n_batch)]
        proposer_step = lambda st, ipt, *dm: self._proposer_gru.step(ipt, st, dm if dropout_masks is not None else None)
        raw_proposal_acts, _ = theano.scan(proposer_step, n_steps=max_candidates,
                                           non_sequences=[input_vector] + (dropout_masks if dropout_masks is not None else []),
                                           outputs_info=outputs_info)

        # raw_proposal_acts is of shape (candidate, n_batch, proposal_width)
        flat_raw_acts = raw_proposal_acts.reshape([-1, self._proposal_width])
        flat_processed_acts = self._proposer_stack.process(flat_raw_acts)
        candidate_strengths = T.nnet.sigmoid(flat_processed_acts[:, 0]).reshape([max_candidates, n_batch])
        candidate_ids = T.nnet.softmax(flat_processed_acts[:, 1:]).reshape([max_candidates, n_batch, self._graph_spec.num_node_ids])

        # Votes will be of shape (candidate, n_batch, n_nodes)
        # To generate this we want to assemble (candidate, n_batch, n_nodes, input_stuff),
        # squash to (parallel, input_stuff), do voting op, then unsquash
        candidate_id_part = T.shape_padaxis(candidate_ids, 2)
        node_id_part = T.shape_padaxis(gstate.node_ids, 0)
        node_state_part = T.shape_padaxis(gstate.node_states, 0)
        full_vote_input = broadcast_concat([node_id_part, node_state_part, candidate_id_part], 3)

        flat_vote_input = full_vote_input.reshape([-1, full_vote_input.shape[-1]])
        vote_result = self._vote_stack.process(flat_vote_input)
        final_votes_no = vote_result.reshape([max_candidates, n_batch, n_nodes])
        weighted_votes_yes = 1 - final_votes_no * T.shape_padleft(gstate.node_strengths)

        # Add in the strength vote
        all_votes = T.concatenate([T.shape_padright(candidate_strengths), weighted_votes_yes], 2)
        # Take the product -> (candidate, n_batch)
        chosen_strengths = T.prod(all_votes, 2)

        new_strengths = chosen_strengths.dimshuffle([1, 0])
        new_ids = candidate_ids.dimshuffle([1, 0, 2])

        return new_strengths, new_ids

    def process(self, gstate, input_vector, max_candidates, dropout_masks=None):
        """
        Process an input vector and update the state accordingly.
        """
        new_strengths, new_ids = self.get_candidates(gstate, input_vector, max_candidates, dropout_masks)
        new_gstate = gstate.with_additional_nodes(new_strengths, new_ids)
        return new_gstate
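# A minimal sketch (not from the original code) of how NewNodesVoteTransformation
# might be wired up. `graph_spec` is assumed to be a GraphStateSpec and `gstate` a
# symbolic GraphState built elsewhere in the model; only the calls on the
# transformation itself are shown, with illustrative sizes.
def _example_new_nodes_vote_usage(graph_spec, gstate, input_vector, max_candidates=2):
    transform = NewNodesVoteTransformation(input_width=20, proposal_width=10, graph_spec=graph_spec)
    # Inspect the proposed nodes directly...
    new_strengths, new_ids = transform.get_candidates(gstate, input_vector, max_candidates)
    # ...or fold them into the graph state in one step
    new_gstate = transform.process(gstate, input_vector, max_candidates)
    return new_strengths, new_ids, new_gstate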
class PropagationTransformation( object ):
    """
    Transforms a graph state by propagating info across the graph
    """
    def __init__(self, transfer_size, graph_spec, transfer_activation=identity, dropout_keep=1):
        """
        Params:
            transfer_size: Integer, how much to transfer
            graph_spec: Instance of GraphStateSpec giving graph spec
            transfer_activation: Activation function to use during transfer
        """
        self._transfer_size = transfer_size
        self._transfer_activation = transfer_activation
        self._graph_spec = graph_spec
        self._process_input_size = graph_spec.num_node_ids + graph_spec.node_state_size

        self._transfer_stack = LayerStack(self._process_input_size, 2 * graph_spec.num_edge_types * transfer_size, activation=self._transfer_activation, name="propagation_transfer", dropout_keep=dropout_keep, dropout_input=False, dropout_output=True)
        self._propagation_gru = BaseGRULayer(graph_spec.num_node_ids + self._transfer_size, graph_spec.node_state_size, name="propagation", dropout_keep=dropout_keep, dropout_input=False, dropout_output=True)

    @property
    def params(self):
        return self._propagation_gru.params + self._transfer_stack.params

    def dropout_masks(self, srng, state_mask=None):
        return self._transfer_stack.dropout_masks(srng) + self._propagation_gru.dropout_masks(srng, use_output=state_mask)

    def split_dropout_masks(self, dropout_masks):
        transfer_used, dropout_masks = self._transfer_stack.split_dropout_masks(dropout_masks)
        gru_used, dropout_masks = self._propagation_gru.split_dropout_masks(dropout_masks)
        return (transfer_used + gru_used), dropout_masks

    def process(self, gstate, dropout_masks=Ellipsis):
        """
        Process a graph state.
          1. Data is transferred from each node to each other node along both forward and
             backward edges. This data is processed with a Wx+b style update, and an
             optional transformation is applied
          2. Nodes sum the transferred data, weighted by the existence of the other node
             and the edge
          3. Nodes perform a GRU update with this input

        Params:
            gstate: A GraphState giving the current state
        """
        if dropout_masks is Ellipsis:
            dropout_masks = None
            append_masks = False
        else:
            append_masks = True

        node_obs = T.concatenate([gstate.node_ids, gstate.node_states], 2)
        flat_node_obs = node_obs.reshape([-1, self._process_input_size])
        transformed, dropout_masks = self._transfer_stack.process(flat_node_obs, dropout_masks)
        transformed = transformed.reshape([gstate.n_batch, gstate.n_nodes, 2*self._graph_spec.num_edge_types, self._transfer_size])
        scaled_transformed = transformed * T.shape_padright(T.shape_padright(gstate.node_strengths))
        # scaled_transformed is of shape (n_batch, n_nodes, 2*num_edge_types, transfer_size)
        # We want to multiply through by edge strengths, which are of shape
        # (n_batch, n_nodes, n_nodes, num_edge_types), both fwd and backward
        edge_strength_scale = T.concatenate([gstate.edge_strengths, gstate.edge_strengths.swapaxes(1,2)], 3)
        # edge_strength_scale is of (n_batch, n_nodes, n_nodes, 2*num_edge_types)
        intermed = T.shape_padaxis(scaled_transformed, 2) * T.shape_padright(edge_strength_scale)
        # intermed is of shape (n_batch, n_nodes "source", n_nodes "dest", 2*num_edge_types, transfer_size)
        # now reduce along the "source" and "edge_types" dimensions to get dest activations
        # of shape (n_batch, n_nodes, transfer_size)
        reduced_result = T.sum(T.sum(intermed, 3), 1)

        # now add information from current node id
        full_input = T.concatenate([gstate.node_ids, reduced_result], 2)

        # we flatten to apply GRU
        flat_input = full_input.reshape([-1, self._graph_spec.num_node_ids + self._transfer_size])
        flat_state = gstate.node_states.reshape([-1, self._graph_spec.node_state_size])
        new_flat_state, dropout_masks = self._propagation_gru.step(flat_input, flat_state, dropout_masks)

        new_node_states = new_flat_state.reshape(gstate.node_states.shape)
        new_gstate = gstate.with_updates(node_states=new_node_states)
        if append_masks:
            return new_gstate, dropout_masks
        else:
            return new_gstate

    def process_multiple(self, gstate, iterations, dropout_masks=Ellipsis):
        """
        Run multiple propagation steps.

        Params:
            gstate: A GraphState giving the current state
            iterations: An integer. How many steps to propagate
        """
        if dropout_masks is Ellipsis:
            dropout_masks = None
            append_masks = False
        else:
            append_masks = True

        def _scan_step(cur_node_states, node_strengths, node_ids, edge_strengths, *dmasks):
            curstate = GraphState(node_strengths, node_ids, cur_node_states, edge_strengths)
            newstate, _ = self.process(curstate, dmasks if dropout_masks is not None else None)
            return newstate.node_states

        outputs_info = [gstate.node_states]
        used_dropout_masks, dropout_masks = self.split_dropout_masks(dropout_masks)
        all_node_states, _ = theano.scan(_scan_step, n_steps=iterations, non_sequences=[gstate.node_strengths, gstate.node_ids, gstate.edge_strengths] + used_dropout_masks, outputs_info=outputs_info)

        final_gstate = gstate.with_updates(node_states=all_node_states[-1,:,:,:])
        if append_masks:
            return final_gstate, dropout_masks
        else:
            return final_gstate
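# A minimal sketch (not from the original code) showing the two entry points of
# PropagationTransformation. `graph_spec` and `gstate` are assumed to come from the
# surrounding model; the transfer size, activation, and iteration count are
# illustrative placeholders.
def _example_propagation_usage(graph_spec, gstate):
    transform = PropagationTransformation(transfer_size=30, graph_spec=graph_spec, transfer_activation=T.tanh)
    one_step = transform.process(gstate)             # a single round of message passing
    several = transform.process_multiple(gstate, 5)  # five rounds, run via theano.scan
    return one_step, several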
class NewNodesInformTransformation( object ):
    """
    Transforms a graph state by adding nodes, conditioned on an input vector
    """
    def __init__(self, input_width, inform_width, proposal_width, graph_spec, use_old_aggregate=False, dropout_keep=1):
        """
        Params:
            input_width: Integer giving size of input
            inform_width: Size of internal aggregate
            proposal_width: Size of internal proposal
            graph_spec: Instance of GraphStateSpec giving graph spec
            use_old_aggregate: Use the old aggregation mode
        """
        self._input_width = input_width
        self._graph_spec = graph_spec
        self._proposal_width = proposal_width
        self._inform_width = inform_width

        aggregate_type = AggregateRepresentationTransformationSoftmax \
                            if use_old_aggregate \
                            else AggregateRepresentationTransformation

        self._inform_aggregate = aggregate_type(inform_width, graph_spec, dropout_keep, dropout_output=True)
        self._proposer_gru = BaseGRULayer(input_width+inform_width, proposal_width, name="newnodes_proposer", dropout_keep=dropout_keep, dropout_input=False, dropout_output=True)
        self._proposer_stack = LayerStack(proposal_width, 1+graph_spec.num_node_ids, [proposal_width], bias_shift=3.0, name="newnodes_proposer_post", dropout_keep=dropout_keep, dropout_input=False)

    @property
    def params(self):
        return self._proposer_gru.params + self._proposer_stack.params + self._inform_aggregate.params

    def dropout_masks(self, srng):
        return self._inform_aggregate.dropout_masks(srng) + self._proposer_gru.dropout_masks(srng) + self._proposer_stack.dropout_masks(srng)

    def get_candidates(self, gstate, input_vector, max_candidates, dropout_masks=Ellipsis):
        """
        Get the current candidate new nodes. This is accomplished as follows:
          1. Using the aggregate transformation, we gather information from nodes
             (which should have performed a state update already)
          2. The proposer network, conditioned on the input and the aggregated info,
             proposes multiple candidate nodes, along with a confidence
          3. A new node is created for each candidate node, with an existence strength
             given by confidence, and an initial id as proposed
        This method directly returns these new nodes for comparison

        Params:
            gstate: A GraphState giving the current state
            input_vector: A tensor of the form (n_batch, input_width)
            max_candidates: Integer, limit on the number of candidates to produce

        Returns:
            new_strengths: A tensor of the form (n_batch, new_node_idx)
            new_ids: A tensor of the form (n_batch, new_node_idx, num_node_ids)
        """
        if dropout_masks is Ellipsis:
            dropout_masks = None
            append_masks = False
        else:
            append_masks = True

        n_batch = gstate.n_batch
        n_nodes = gstate.n_nodes

        aggregated_repr, dropout_masks = self._inform_aggregate.process(gstate, dropout_masks)
        # aggregated_repr is of shape (n_batch, inform_width)

        full_input = T.concatenate([input_vector, aggregated_repr], 1)

        outputs_info = [self._proposer_gru.initial_state(n_batch)]
        gru_dropout_masks, dropout_masks = self._proposer_gru.split_dropout_masks(dropout_masks)
        proposer_step = lambda st, ipt, *dm: self._proposer_gru.step(ipt, st, dm if dropout_masks is not None else None)[0]
        raw_proposal_acts, _ = theano.scan(proposer_step, n_steps=max_candidates, non_sequences=[full_input]+gru_dropout_masks, outputs_info=outputs_info)

        # raw_proposal_acts is of shape (candidate, n_batch, proposal_width)
        flat_raw_acts = raw_proposal_acts.reshape([-1, self._proposal_width])
        flat_processed_acts, dropout_masks = self._proposer_stack.process(flat_raw_acts, dropout_masks)
        candidate_strengths = T.nnet.sigmoid(flat_processed_acts[:,0]).reshape([max_candidates, n_batch])
        candidate_ids = T.nnet.softmax(flat_processed_acts[:,1:]).reshape([max_candidates, n_batch, self._graph_spec.num_node_ids])

        new_strengths = candidate_strengths.dimshuffle([1,0])
        new_ids = candidate_ids.dimshuffle([1,0,2])

        if append_masks:
            return new_strengths, new_ids, dropout_masks
        else:
            return new_strengths, new_ids

    def process(self, gstate, input_vector, max_candidates, dropout_masks=Ellipsis):
        """
        Process an input vector and update the state accordingly.
        """
        if dropout_masks is Ellipsis:
            dropout_masks = None
            append_masks = False
        else:
            append_masks = True
        new_strengths, new_ids, dropout_masks = self.get_candidates(gstate, input_vector, max_candidates, dropout_masks)
        new_gstate = gstate.with_additional_nodes(new_strengths, new_ids)
        if append_masks:
            return new_gstate, dropout_masks
        else:
            return new_gstate
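# A minimal sketch (not from the original code) for NewNodesInformTransformation,
# illustrating the dropout-mask threading convention these classes share: passing masks
# explicitly makes process() return the leftover masks alongside the new state, while
# the default (Ellipsis) returns only the state. `graph_spec`, `gstate`, `input_vector`,
# and `srng` (a random-streams object) are assumed to come from the surrounding model;
# all sizes are illustrative.
def _example_new_nodes_inform_usage(graph_spec, gstate, input_vector, srng):
    transform = NewNodesInformTransformation(input_width=20, inform_width=20,
                                             proposal_width=10, graph_spec=graph_spec,
                                             dropout_keep=0.8)
    masks = transform.dropout_masks(srng)
    # Training-time call: masks in, leftover masks out
    new_gstate_train, leftover_masks = transform.process(gstate, input_vector, 2, masks)
    # Test-time call: no masks, just the updated graph state
    new_gstate_test = transform.process(gstate, input_vector, 2)
    return new_gstate_train, new_gstate_test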
class NodeStateUpdateTransformation(object):
    """
    Transforms a graph state by updating node states, conditioned on an input vector
    """
    def __init__(self, input_width, graph_spec, dropout_keep=1):
        """
        Params:
            input_width: Integer giving size of input
            graph_spec: Instance of GraphStateSpec giving graph spec
        """
        self._input_width = input_width
        self._graph_spec = graph_spec

        self._update_gru = BaseGRULayer(input_width + graph_spec.num_node_ids, graph_spec.node_state_size, name="nodestateupdate", dropout_keep=dropout_keep, dropout_input=False, dropout_output=True)

    @property
    def params(self):
        return self._update_gru.params

    def dropout_masks(self, srng, state_mask=None):
        return self._update_gru.dropout_masks(srng, use_output=state_mask)

    def process(self, gstate, input_vector, dropout_masks=Ellipsis):
        """
        Process an input vector and update the state accordingly. Each node runs a GRU step
        with its previous state taken from the node state and its input taken from the vector.

        Params:
            gstate: A GraphState giving the current state
            input_vector: A tensor of the form (n_batch, input_width)
        """
        # gstate.node_states is of shape (n_batch, n_nodes, node_state_width)
        # input_vector should be broadcasted to match this
        if dropout_masks is Ellipsis:
            dropout_masks = None
            append_masks = False
        else:
            append_masks = True

        prepped_input_vector = T.tile(T.shape_padaxis(input_vector, 1), [1, gstate.n_nodes, 1])
        full_input = T.concatenate([gstate.node_ids, prepped_input_vector], 2)

        # we flatten to apply GRU
        flat_input = full_input.reshape([-1, self._input_width + self._graph_spec.num_node_ids])
        flat_state = gstate.node_states.reshape([-1, self._graph_spec.node_state_size])
        new_flat_state, dropout_masks = self._update_gru.step(flat_input, flat_state, dropout_masks)

        new_node_states = new_flat_state.reshape(gstate.node_states.shape)

        new_gstate = gstate.with_updates(node_states=new_node_states)
        if append_masks:
            return new_gstate, dropout_masks
        else:
            return new_gstate
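# A minimal sketch (not from the original code) for NodeStateUpdateTransformation.
# `graph_spec` and `gstate` are assumed to exist already; `input_vector` is a symbolic
# (n_batch, input_width) tensor, for example the representation produced by an
# input-sequence transformation. The width is an illustrative placeholder.
def _example_node_state_update_usage(graph_spec, gstate, input_vector):
    transform = NodeStateUpdateTransformation(input_width=20, graph_spec=graph_spec)
    # Every node runs one GRU step conditioned on its own id and the shared input vector
    new_gstate = transform.process(gstate, input_vector)
    return new_gstate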
class InputSequenceDirectTransformation(object):
    """
    Transforms an input sequence into a representation vector
    """
    def __init__(self, num_words, num_node_ids, word_node_mapping, output_width):
        """
        num_words: Number of words in the input sequence
        word_node_mapping: Mapping of word idx to node idx for direct mapping
        """
        self._num_words = num_words
        self._num_node_ids = num_node_ids
        self._word_node_mapping = word_node_mapping
        self._output_width = output_width

        self._word_node_matrix = np.zeros([num_words, num_node_ids], np.float32)
        for word, node in word_node_mapping.items():
            self._word_node_matrix[word, node] = 1.0

        self._gru = BaseGRULayer(num_words, output_width, name="input_sequence")

    @property
    def params(self):
        return self._gru.params

    def process(self, inputs):
        """
        Process a set of inputs and return the final state

        Params:
            input_words: List of input indices. Should be an int tensor of shape (n_batch, input_len)

        Returns: repr_vect, node_vects
            repr_vect: The final representation vector, of shape (n_batch, output_width)
            node_vects: Direct-access vects for each node id, of shape (n_batch, num_node_ids, output_width)
        """
        n_batch, input_len = inputs.shape
        valseq = inputs.dimshuffle([1, 0])
        one_hot_vals = T.extra_ops.to_one_hot(inputs.flatten(), self._num_words).reshape([n_batch, input_len, self._num_words])
        one_hot_valseq = one_hot_vals.dimshuffle([1, 0, 2])

        def scan_fn(idx_ipt, onehot_ipt, last_accum, last_state):
            # last_accum stores accumulated outputs per word type
            # and is of shape (n_batch, word_idx, output_width)
            gru_state = self._gru.step(onehot_ipt, last_state)
            new_accum = T.inc_subtensor(last_accum[T.arange(n_batch), idx_ipt, :], gru_state)
            return new_accum, gru_state

        outputs_info = [T.zeros([n_batch, self._num_words, self._output_width]), self._gru.initial_state(n_batch)]
        (all_accum, all_out), _ = theano.scan(scan_fn, sequences=[valseq, one_hot_valseq], outputs_info=outputs_info)

        # all_out is of shape (input_len, n_batch, output_width). We want the last timestep
        repr_vect = all_out[-1, :, :]

        final_accum = all_accum[-1, :, :, :]
        # Now we also want to extract and accumulate the outputs that directly map to each word
        # We can do this by multiplying the final accum's second dimension (word_idx) through by
        # the word_node_matrix
        resh_flat_final_accum = final_accum.dimshuffle([0, 2, 1]).reshape([-1, self._num_words])
        resh_flat_node_mat = T.dot(resh_flat_final_accum, self._word_node_matrix)
        node_vects = resh_flat_node_mat.reshape([n_batch, self._output_width, self._num_node_ids]).dimshuffle([0, 2, 1])

        return repr_vect, node_vects
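# A minimal usage sketch for InputSequenceDirectTransformation (not part of the
# original module). The vocabulary size, node count, and word->node mapping are
# illustrative placeholders; `inputs` must be an integer matrix of word indices.
def _example_input_sequence_usage():
    word_node_mapping = {0: 0, 1: 1}  # hypothetical: words 0 and 1 map directly to nodes 0 and 1
    transform = InputSequenceDirectTransformation(num_words=15, num_node_ids=10,
                                                  word_node_mapping=word_node_mapping,
                                                  output_width=30)
    inputs = T.imatrix("inputs")  # symbolic (n_batch, input_len) word indices
    repr_vect, node_vects = transform.process(inputs)
    return theano.function([inputs], [repr_vect, node_vects])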
class DirectReferenceUpdateTransformation(object):
    """
    Transforms a graph state by updating node states, conditioned on a direct reference accumulation
    """
    def __init__(self, input_width, graph_spec, dropout_keep=1):
        """
        Params:
            input_width: Integer giving size of input
            graph_spec: Instance of GraphStateSpec giving graph spec
        """
        self._input_width = input_width
        self._graph_spec = graph_spec

        self._update_gru = BaseGRULayer(input_width + graph_spec.num_node_ids, graph_spec.node_state_size, name="nodestateupdate", dropout_keep=dropout_keep)

    @property
    def params(self):
        return self._update_gru.params

    def dropout_masks(self, srng, state_mask=None):
        return self._update_gru.dropout_masks(srng, use_output=state_mask)

    def process(self, gstate, ref_matrix, dropout_masks=Ellipsis):
        """
        Process a direct-reference matrix and update the state accordingly. Each node runs a GRU step
        with its previous state taken from the node state and its input taken from the matrix.

        Params:
            gstate: A GraphState giving the current state
            ref_matrix: A tensor of the form (n_batch, num_node_ids, input_width)
        """
        if dropout_masks is Ellipsis:
            dropout_masks = None
            append_masks = False
        else:
            append_masks = True

        # To process the input, we need to map from node id to node index
        # We can do this using gstate.node_ids, of shape (n_batch, n_nodes, num_node_ids)
        prepped_input_vector = T.batched_dot(gstate.node_ids, ref_matrix)
        # prepped_input_vector is of shape (n_batch, n_nodes, input_width)
        # gstate.node_states is of shape (n_batch, n_nodes, node_state_width)
        # so they match nicely
        full_input = T.concatenate([gstate.node_ids, prepped_input_vector], 2)

        # we flatten to apply GRU
        flat_input = full_input.reshape([-1, self._input_width + self._graph_spec.num_node_ids])
        flat_state = gstate.node_states.reshape([-1, self._graph_spec.node_state_size])
        new_flat_state, dropout_masks = self._update_gru.step(flat_input, flat_state, dropout_masks)

        new_node_states = new_flat_state.reshape(gstate.node_states.shape)

        new_gstate = gstate.with_updates(node_states=new_node_states)
        if append_masks:
            return new_gstate, dropout_masks
        else:
            return new_gstate
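# A minimal sketch (not from the original code) for DirectReferenceUpdateTransformation.
# The ref_matrix is assumed (not confirmed by this file) to be the per-node-id
# accumulation produced by InputSequenceDirectTransformation.process (its `node_vects`
# output), which has the matching (n_batch, num_node_ids, input_width) shape.
def _example_direct_reference_update_usage(graph_spec, gstate, node_vects):
    transform = DirectReferenceUpdateTransformation(input_width=30, graph_spec=graph_spec)
    # Each node reads the accumulated input for its own id and runs a GRU update
    new_gstate = transform.process(gstate, node_vects)
    return new_gstate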