class OutputSequenceTransformation(object):
    """
    Transforms a representation vector into a sequence of outputs
    """
    def __init__(self, input_width, state_size, num_words):
        self._input_width = input_width
        self._state_size = state_size
        self._num_words = num_words

        self._seq_gru = BaseGRULayer(input_width,
                                     state_size,
                                     name="output_seq_gru")
        self._transform_stack = LayerStack(state_size,
                                           num_words,
                                           activation=T.nnet.softmax,
                                           name="output_seq_transf")

    @property
    def params(self):
        return self._seq_gru.params + self._transform_stack.params

    def process(self, input_vector, seq_len):
        """
        Convert an input vector into a sequence of categorical distributions

        Params:
            input_vector: Vector of shape (n_batch, input_width)
            seq_len: How many outputs to produce

        Returns: Sequence distribution of shape (n_batch, seq_len, num_words)
        """
        n_batch = input_vector.shape[0]
        outputs_info = [self._seq_gru.initial_state(n_batch)]
        scan_step = lambda state, ipt: self._seq_gru.step(ipt, state)
        all_out, _ = theano.scan(scan_step,
                                 non_sequences=[input_vector],
                                 n_steps=seq_len,
                                 outputs_info=outputs_info)

        # all_out is of shape (seq_len, n_batch, state_size). Squash and apply layer
        flat_out = all_out.reshape([-1, self._state_size])
        flat_final = self._transform_stack.process(flat_out)
        final = flat_final.reshape([seq_len, n_batch,
                                    self._num_words]).dimshuffle([1, 0, 2])

        return final

    def snap_to_best(self, answer):
        """
        Convert output of process to the "best" answer, i.e. the answer with highest probability.
        """
        return categorical_best(answer)
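
# Hedged usage sketch (not from the original listing): compile the sequence
# decoder above into a callable. The widths and variable names here are
# illustrative assumptions.
def _demo_output_sequence():
    seq_xform = OutputSequenceTransformation(input_width=64,
                                             state_size=128,
                                             num_words=1000)
    repr_vect = T.matrix("repr_vect")   # (n_batch, input_width)
    seq_len = T.iscalar("seq_len")      # number of output steps to produce
    word_dists = seq_xform.process(repr_vect, seq_len)  # (n_batch, seq_len, num_words)
    best_words = seq_xform.snap_to_best(word_dists)
    return theano.function([repr_vect, seq_len], [word_dists, best_words])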
class NewNodesVoteTransformation(object):
    """
    Transforms a graph state by adding nodes, conditioned on an input vector
    """
    def __init__(self, input_width, proposal_width, graph_spec):
        """
        Params:
            input_width: Integer giving size of input
            proposal_width: Size of internal proposal
            graph_spec: Instance of GraphStateSpec giving graph spec
        """
        self._input_width = input_width
        self._graph_spec = graph_spec
        self._proposal_width = proposal_width

        self._proposer_gru = BaseGRULayer(input_width,
                                          proposal_width,
                                          name="newnodes_proposer")

        self._proposer_stack = LayerStack(proposal_width,
                                          1 + graph_spec.num_node_ids,
                                          [proposal_width],
                                          bias_shift=3.0,
                                          name="newnodes_proposer_post")
        isize = 2 * graph_spec.num_node_ids + graph_spec.node_state_size
        self._vote_stack = LayerStack(isize,
                                      1, [isize],
                                      activation=T.nnet.sigmoid,
                                      bias_shift=-3.0,
                                      name="newnodes_vote")

    @property
    def params(self):
        return self._proposer_gru.params + self._proposer_stack.params + self._vote_stack.params

    @property
    def num_dropout_masks(self):
        return self._proposer_gru.num_dropout_masks

    def get_dropout_masks(self, srng, keep_frac):
        return self._proposer_gru.get_dropout_masks(srng, keep_frac)

    def get_candidates(self,
                       gstate,
                       input_vector,
                       max_candidates,
                       dropout_masks=None):
        """
        Get the current candidate new nodes. This is accomplished as follows:
          1. The proposer network, conditioned on the input vector, proposes multiple candidate nodes,
                along with a confidence
          2. Every existing node, conditioned on its own state and the candidate, votes on whether or not
                to accept this node
          3. A new node is created for each candidate node, with an existence strength given by
                confidence * [product of all votes], and an initial id as proposed
        This method directly returns these new nodes for comparison

        Params:
            gstate: A GraphState giving the current state
            input_vector: A tensor of the form (n_batch, input_width)
            max_candidates: Integer, limit on the number of candidates to produce

        Returns:
            new_strengths: A tensor of the form (n_batch, new_node_idx)
            new_ids: A tensor of the form (n_batch, new_node_idx, num_node_ids)
        """
        n_batch = gstate.n_batch
        n_nodes = gstate.n_nodes
        outputs_info = [self._proposer_gru.initial_state(n_batch)]
        proposer_step = lambda st, ipt, *dm: self._proposer_gru.step(
            ipt, st, dm if dropout_masks is not None else None)
        raw_proposal_acts, _ = theano.scan(
            proposer_step,
            n_steps=max_candidates,
            non_sequences=[input_vector] +
            (dropout_masks if dropout_masks is not None else []),
            outputs_info=outputs_info)

        # raw_proposal_acts is of shape (candidate, n_batch, proposal_width)
        flat_raw_acts = raw_proposal_acts.reshape([-1, self._proposal_width])
        flat_processed_acts = self._proposer_stack.process(flat_raw_acts)
        candidate_strengths = T.nnet.sigmoid(
            flat_processed_acts[:, 0]).reshape([max_candidates, n_batch])
        candidate_ids = T.nnet.softmax(flat_processed_acts[:, 1:]).reshape(
            [max_candidates, n_batch, self._graph_spec.num_node_ids])

        # Votes will be of shape (candidate, n_batch, n_nodes)
        # To generate this we want to assemble (candidate, n_batch, n_nodes, input_stuff),
        # squash to (parallel, input_stuff), do voting op, then unsquash
        candidate_id_part = T.shape_padaxis(candidate_ids, 2)
        node_id_part = T.shape_padaxis(gstate.node_ids, 0)
        node_state_part = T.shape_padaxis(gstate.node_states, 0)
        full_vote_input = broadcast_concat(
            [node_id_part, node_state_part, candidate_id_part], 3)
        flat_vote_input = full_vote_input.reshape(
            [-1, full_vote_input.shape[-1]])
        vote_result = self._vote_stack.process(flat_vote_input)
        final_votes_no = vote_result.reshape(
            [max_candidates, n_batch, n_nodes])
        weighted_votes_yes = 1 - final_votes_no * T.shape_padleft(
            gstate.node_strengths)
        # Add in the strength vote
        all_votes = T.concatenate(
            [T.shape_padright(candidate_strengths), weighted_votes_yes], 2)
        # Take the product -> (candidate, n_batch)
        chosen_strengths = T.prod(all_votes, 2)

        new_strengths = chosen_strengths.dimshuffle([1, 0])
        new_ids = candidate_ids.dimshuffle([1, 0, 2])
        return new_strengths, new_ids

    def process(self,
                gstate,
                input_vector,
                max_candidates,
                dropout_masks=None):
        """
        Process an input vector and update the state accordingly.
        """
        new_strengths, new_ids = self.get_candidates(gstate, input_vector,
                                                     max_candidates,
                                                     dropout_masks)
        new_gstate = gstate.with_additional_nodes(new_strengths, new_ids)
        return new_gstate
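
# Worked numeric sketch (not from the original code) of the existence-strength
# rule implemented in get_candidates above: a candidate's strength is its own
# confidence times one factor per existing node, where each factor is
# 1 - (no_vote * node_strength), so weak or absent nodes cannot veto it.
def _demo_vote_strength():
    confidence = 0.9                            # proposer confidence for one candidate
    votes_no = np.array([0.8, 0.1, 0.5])        # per-node "reject" votes (sigmoid outputs)
    node_strengths = np.array([1.0, 1.0, 0.0])  # third node barely exists
    factors = 1.0 - votes_no * node_strengths   # [0.2, 0.9, 1.0]
    return confidence * np.prod(factors)        # 0.9 * 0.2 * 0.9 * 1.0 = 0.162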
class PropagationTransformation(object):
    """
    Transforms a graph state by propagating info across the graph
    """
    def __init__(self, transfer_size, graph_spec, transfer_activation=identity, dropout_keep=1):
        """
        Params:
            transfer_size: Integer size of the vector transferred along each edge type
            graph_spec: Instance of GraphStateSpec giving graph spec
            transfer_activation: Activation function to use during transfer
        """
        self._transfer_size = transfer_size
        self._transfer_activation = transfer_activation
        self._graph_spec = graph_spec
        self._process_input_size = graph_spec.num_node_ids + graph_spec.node_state_size

        self._transfer_stack = LayerStack(self._process_input_size, 2 * graph_spec.num_edge_types * transfer_size, activation=self._transfer_activation, name="propagation_transfer", dropout_keep=dropout_keep, dropout_input=False, dropout_output=True)
        self._propagation_gru = BaseGRULayer(graph_spec.num_node_ids + self._transfer_size, graph_spec.node_state_size, name="propagation", dropout_keep=dropout_keep, dropout_input=False, dropout_output=True)

    @property
    def params(self):
        return self._propagation_gru.params + self._transfer_stack.params

    def dropout_masks(self, srng, state_mask=None):
        return self._transfer_stack.dropout_masks(srng) + self._propagation_gru.dropout_masks(srng, use_output=state_mask)

    def split_dropout_masks(self, dropout_masks):
        transfer_used, dropout_masks = self._transfer_stack.split_dropout_masks(dropout_masks)
        gru_used, dropout_masks = self._propagation_gru.split_dropout_masks(dropout_masks)
        return (transfer_used+gru_used), dropout_masks

    def process(self, gstate, dropout_masks=Ellipsis):
        """
        Process a graph state.
          1. Data is transferred from each node to every other node along both forward and backward edges.
                This data is processed with a Wx+b style update, and an optional transformation is applied
          2. Nodes sum the transferred data, weighted by the existence of the other node and the edge.
          3. Nodes perform a GRU update with this input

        Params:
            gstate: A GraphState giving the current state
        """
        if dropout_masks is Ellipsis:
            dropout_masks = None
            append_masks = False
        else:
            append_masks = True

        node_obs = T.concatenate([gstate.node_ids, gstate.node_states],2)
        flat_node_obs = node_obs.reshape([-1, self._process_input_size])
        transformed, dropout_masks = self._transfer_stack.process(flat_node_obs,dropout_masks)
        transformed = transformed.reshape([gstate.n_batch, gstate.n_nodes, 2*self._graph_spec.num_edge_types, self._transfer_size])
        scaled_transformed = transformed * T.shape_padright(T.shape_padright(gstate.node_strengths))
        # scaled_transformed is of shape (n_batch, n_nodes, 2*num_edge_types, transfer_size)
        # We want to multiply through by edge strengths, which are of shape
        # (n_batch, n_nodes, n_nodes, num_edge_types), both fwd and backward
        edge_strength_scale = T.concatenate([gstate.edge_strengths, gstate.edge_strengths.swapaxes(1,2)], 3)
        # edge_strength_scale is of (n_batch, n_nodes, n_nodes, 2*num_edge_types)
        intermed = T.shape_padaxis(scaled_transformed, 2) * T.shape_padright(edge_strength_scale)
        # intermed is of shape (n_batch, n_nodes "source", n_nodes "dest", 2*num_edge_types, transfer_size)
        # now reduce along the "source" and "edge_types" dimensions to get dest activations
        # of shape (n_batch, n_nodes, transfer_size)
        reduced_result = T.sum(T.sum(intermed, 3), 1)

        # now add information from the current node id
        full_input = T.concatenate([gstate.node_ids, reduced_result], 2)

        # we flatten to apply GRU
        flat_input = full_input.reshape([-1, self._graph_spec.num_node_ids + self._transfer_size])
        flat_state = gstate.node_states.reshape([-1, self._graph_spec.node_state_size])
        new_flat_state, dropout_masks = self._propagation_gru.step(flat_input, flat_state, dropout_masks)

        new_node_states = new_flat_state.reshape(gstate.node_states.shape)

        new_gstate = gstate.with_updates(node_states=new_node_states)
        if append_masks:
            return new_gstate, dropout_masks
        else:
            return new_gstate

    def process_multiple(self, gstate, iterations, dropout_masks=Ellipsis):
        """
        Run multiple propagation steps.

        Params:
            gstate: A GraphState giving the current state
            iterations: An integer. How many steps to propagate
        """
        if dropout_masks is Ellipsis:
            dropout_masks = None
            append_masks = False
        else:
            append_masks = True

        def _scan_step(cur_node_states, node_strengths, node_ids, edge_strengths, *dmasks):
            curstate = GraphState(node_strengths, node_ids, cur_node_states, edge_strengths)
            newstate, _ = self.process(curstate, dmasks if dropout_masks is not None else None)
            return newstate.node_states

        outputs_info = [gstate.node_states]
        used_dropout_masks, dropout_masks = self.split_dropout_masks(dropout_masks)
        all_node_states, _ = theano.scan(_scan_step, n_steps=iterations, non_sequences=[gstate.node_strengths, gstate.node_ids, gstate.edge_strengths] + used_dropout_masks, outputs_info=outputs_info)

        final_gstate = gstate.with_updates(node_states=all_node_states[-1,:,:,:])
        if append_masks:
            return final_gstate, dropout_masks
        else:
            return final_gstate
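
# Numeric sketch (not part of the original code) of the edge-weighted message
# aggregation performed in process above: each destination node sums the
# transformed source messages, weighted by forward and backward edge strengths.
# Shapes mirror the intermediate tensors named in the comments above.
def _demo_propagation_aggregate(n_batch=2, n_nodes=3, num_edge_types=2, transfer_size=4):
    rng = np.random.RandomState(0)
    # (n_batch, n_nodes, 2*num_edge_types, transfer_size): per-source messages,
    # already scaled by source node strength
    scaled_transformed = rng.rand(n_batch, n_nodes, 2 * num_edge_types, transfer_size)
    # (n_batch, source, dest, 2*num_edge_types): forward and reversed edge strengths
    edge_strength_scale = rng.rand(n_batch, n_nodes, n_nodes, 2 * num_edge_types)
    # Sum over source nodes and edge types -> (n_batch, n_nodes "dest", transfer_size)
    return np.einsum("bset,bsde->bdt", scaled_transformed, edge_strength_scale)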
class NewNodesInformTransformation(object):
    """
    Transforms a graph state by adding nodes, conditioned on an input vector
    """
    def __init__(self, input_width, inform_width, proposal_width, graph_spec, use_old_aggregate=False, dropout_keep=1):
        """
        Params:
            input_width: Integer giving size of input
            inform_width: Size of internal aggregate
            proposal_width: Size of internal proposal
            graph_spec: Instance of GraphStateSpec giving graph spec
            use_old_aggregate: Use the old aggregation mode
        """
        self._input_width = input_width
        self._graph_spec = graph_spec
        self._proposal_width = proposal_width
        self._inform_width = inform_width

        aggregate_type = AggregateRepresentationTransformationSoftmax \
                                                if use_old_aggregate \
                                                else AggregateRepresentationTransformation

        self._inform_aggregate = aggregate_type(inform_width, graph_spec, dropout_keep, dropout_output=True)
        self._proposer_gru = BaseGRULayer(input_width+inform_width, proposal_width, name="newnodes_proposer", dropout_keep=dropout_keep, dropout_input=False, dropout_output=True)
        self._proposer_stack = LayerStack(proposal_width, 1+graph_spec.num_node_ids, [proposal_width], bias_shift=3.0, name="newnodes_proposer_post", dropout_keep=dropout_keep, dropout_input=False)

    @property
    def params(self):
        return self._proposer_gru.params + self._proposer_stack.params + self._inform_aggregate.params

    def dropout_masks(self, srng):
        return self._inform_aggregate.dropout_masks(srng) + self._proposer_gru.dropout_masks(srng) + self._proposer_stack.dropout_masks(srng)

    def get_candidates(self, gstate, input_vector, max_candidates, dropout_masks=Ellipsis):
        """
        Get the current candidate new nodes. This is accomplished as follows:
          1. Using the aggregate transformation, we gather information from the nodes (which should
                have performed a state update already)
          2. The proposer network, conditioned on the input and this aggregate, proposes multiple
                candidate nodes, along with a confidence
          3. A new node is created for each candidate node, with an existence strength given by
                confidence, and an initial id as proposed
        This method directly returns these new nodes for comparison

        Params:
            gstate: A GraphState giving the current state
            input_vector: A tensor of the form (n_batch, input_width)
            max_candidates: Integer, limit on the number of candidates to produce

        Returns:
            new_strengths: A tensor of the form (n_batch, new_node_idx)
            new_ids: A tensor of the form (n_batch, new_node_idx, num_node_ids)
        """
        if dropout_masks is Ellipsis:
            dropout_masks = None
            append_masks = False
        else:
            append_masks = True

        n_batch = gstate.n_batch
        n_nodes = gstate.n_nodes

        aggregated_repr, dropout_masks = self._inform_aggregate.process(gstate, dropout_masks)
        # aggregated_repr is of shape (n_batch, inform_width)
        
        full_input = T.concatenate([input_vector, aggregated_repr],1)

        outputs_info = [self._proposer_gru.initial_state(n_batch)]
        gru_dropout_masks, dropout_masks = self._proposer_gru.split_dropout_masks(dropout_masks)
        proposer_step = lambda st,ipt,*dm: self._proposer_gru.step(ipt, st, dm if dropout_masks is not None else None)[0]
        raw_proposal_acts, _ = theano.scan(proposer_step, n_steps=max_candidates, non_sequences=[full_input]+gru_dropout_masks, outputs_info=outputs_info)

        # raw_proposal_acts is of shape (candidate, n_batch, proposal_width)
        flat_raw_acts = raw_proposal_acts.reshape([-1, self._proposal_width])
        flat_processed_acts, dropout_masks = self._proposer_stack.process(flat_raw_acts, dropout_masks)
        candidate_strengths = T.nnet.sigmoid(flat_processed_acts[:,0]).reshape([max_candidates, n_batch])
        candidate_ids = T.nnet.softmax(flat_processed_acts[:,1:]).reshape([max_candidates, n_batch, self._graph_spec.num_node_ids])

        new_strengths = candidate_strengths.dimshuffle([1,0])
        new_ids = candidate_ids.dimshuffle([1,0,2])
        if append_masks:
            return new_strengths, new_ids, dropout_masks
        else:
            return new_strengths, new_ids

    def process(self, gstate, input_vector, max_candidates, dropout_masks=Ellipsis):
        """
        Process an input vector and update the state accordingly.
        """
        if dropout_masks is Ellipsis:
            dropout_masks = None
            append_masks = False
        else:
            append_masks = True
        new_strengths, new_ids, dropout_masks = self.get_candidates(gstate, input_vector, max_candidates, dropout_masks)
        new_gstate = gstate.with_additional_nodes(new_strengths, new_ids)
        if append_masks:
            return new_gstate, dropout_masks
        else:
            return new_gstate
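
# Hedged usage sketch (not from the original listing) of the inform/propose
# pipeline above. The GraphState constructor arguments follow the order used
# in PropagationTransformation.process_multiple; the sizes and variable names
# are illustrative assumptions.
def _demo_new_nodes_inform(graph_spec):
    xform = NewNodesInformTransformation(input_width=32,
                                         inform_width=16,
                                         proposal_width=20,
                                         graph_spec=graph_spec)
    gstate = GraphState(T.matrix("node_strengths"),   # (n_batch, n_nodes)
                        T.tensor3("node_ids"),        # (n_batch, n_nodes, num_node_ids)
                        T.tensor3("node_states"),     # (n_batch, n_nodes, node_state_size)
                        T.tensor4("edge_strengths"))  # (n_batch, n_nodes, n_nodes, num_edge_types)
    input_vector = T.matrix("input_vector")           # (n_batch, input_width)
    # With the default dropout_masks=Ellipsis, process returns only the new state
    return xform.process(gstate, input_vector, max_candidates=3)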
class NodeStateUpdateTransformation(object):
    """
    Transforms a graph state by updating node states, conditioned on an input vector
    """
    def __init__(self, input_width, graph_spec, dropout_keep=1):
        """
        Params:
            input_width: Integer giving size of input
            graph_spec: Instance of GraphStateSpec giving graph spec
        """
        self._input_width = input_width
        self._graph_spec = graph_spec

        self._update_gru = BaseGRULayer(input_width + graph_spec.num_node_ids,
                                        graph_spec.node_state_size,
                                        name="nodestateupdate",
                                        dropout_keep=dropout_keep,
                                        dropout_input=False,
                                        dropout_output=True)

    @property
    def params(self):
        return self._update_gru.params

    def dropout_masks(self, srng, state_mask=None):
        return self._update_gru.dropout_masks(srng, use_output=state_mask)

    def process(self, gstate, input_vector, dropout_masks=Ellipsis):
        """
        Process an input vector and update the state accordingly. Each node runs a GRU step
        with previous state from the node state and input from the vector.

        Params:
            gstate: A GraphState giving the current state
            input_vector: A tensor of the form (n_batch, input_width)
        """

        # gstate.node_states is of shape (n_batch, n_nodes, node_state_width)
        # input_vector should be broadcasted to match this
        if dropout_masks is Ellipsis:
            dropout_masks = None
            append_masks = False
        else:
            append_masks = True
        prepped_input_vector = T.tile(T.shape_padaxis(input_vector, 1),
                                      [1, gstate.n_nodes, 1])
        full_input = T.concatenate([gstate.node_ids, prepped_input_vector], 2)

        # we flatten to apply GRU
        flat_input = full_input.reshape(
            [-1, self._input_width + self._graph_spec.num_node_ids])
        flat_state = gstate.node_states.reshape(
            [-1, self._graph_spec.node_state_size])
        new_flat_state, dropout_masks = self._update_gru.step(
            flat_input, flat_state, dropout_masks)

        new_node_states = new_flat_state.reshape(gstate.node_states.shape)

        new_gstate = gstate.with_updates(node_states=new_node_states)
        if append_masks:
            return new_gstate, dropout_masks
        else:
            return new_gstate
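
# Numeric sketch (not from the original code) of how the per-graph input
# vector above is broadcast to every node before the flattened GRU step.
def _demo_broadcast_input(n_batch=2, n_nodes=3, num_node_ids=4, input_width=5):
    rng = np.random.RandomState(0)
    input_vector = rng.rand(n_batch, input_width)
    node_ids = rng.rand(n_batch, n_nodes, num_node_ids)
    prepped = np.tile(input_vector[:, None, :], [1, n_nodes, 1])  # (n_batch, n_nodes, input_width)
    full_input = np.concatenate([node_ids, prepped], axis=2)
    flat_input = full_input.reshape([-1, num_node_ids + input_width])  # rows fed to the GRU
    return flat_input.shape  # (n_batch * n_nodes, num_node_ids + input_width)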
class InputSequenceDirectTransformation(object):
    """
    Transforms an input sequence into a representation vector
    """
    def __init__(self, num_words, num_node_ids, word_node_mapping,
                 output_width):
        """
            num_words: Number of words in the input sequence
            word_node_mapping: Mapping of word idx to node idx for direct mapping
        """
        self._num_words = num_words
        self._num_node_ids = num_node_ids
        self._word_node_mapping = word_node_mapping
        self._output_width = output_width

        self._word_node_matrix = np.zeros([num_words, num_node_ids],
                                          np.float32)
        for word, node in word_node_mapping.items():
            self._word_node_matrix[word, node] = 1.0

        self._gru = BaseGRULayer(num_words,
                                 output_width,
                                 name="input_sequence")

    @property
    def params(self):
        return self._gru.params

    def process(self, inputs):
        """
        Process a set of inputs and return the final state

        Params:
            inputs: Int tensor of shape (n_batch, input_len) giving the input word indices

        Returns: repr_vect, node_vects
            repr_vect: The final representation vector, of shape (n_batch, output_width)
            node_vects: Direct-access vects for each node id, of shape (n_batch, num_node_ids, output_width)
        """
        n_batch, input_len = inputs.shape
        valseq = inputs.dimshuffle([1, 0])
        one_hot_vals = T.extra_ops.to_one_hot(inputs.flatten(), self._num_words)\
                    .reshape([n_batch, input_len, self._num_words])
        one_hot_valseq = one_hot_vals.dimshuffle([1, 0, 2])

        def scan_fn(idx_ipt, onehot_ipt, last_accum, last_state):
            # last_accum stores accumulated outputs per word type
            # and is of shape (n_batch, word_idx, output_width)
            gru_state = self._gru.step(onehot_ipt, last_state)
            new_accum = T.inc_subtensor(
                last_accum[T.arange(n_batch), idx_ipt, :], gru_state)
            return new_accum, gru_state

        outputs_info = [
            T.zeros([n_batch, self._num_words, self._output_width]),
            self._gru.initial_state(n_batch)
        ]
        (all_accum,
         all_out), _ = theano.scan(scan_fn,
                                   sequences=[valseq, one_hot_valseq],
                                   outputs_info=outputs_info)

        # all_out is of shape (input_len, n_batch, self.output_width). We want last timestep
        repr_vect = all_out[-1, :, :]

        final_accum = all_accum[-1, :, :, :]
        # Now we also want to extract and accumulate the outputs that directly map to each word
        # We can do this by multiplying the final accum's second dimension (word_idx) through by
        # the word_node_matrix
        resh_flat_final_accum = final_accum.dimshuffle([0, 2, 1]).reshape(
            [-1, self._num_words])
        resh_flat_node_mat = T.dot(resh_flat_final_accum,
                                   self._word_node_matrix)
        node_vects = resh_flat_node_mat.reshape(
            [n_batch, self._output_width,
             self._num_node_ids]).dimshuffle([0, 2, 1])

        return repr_vect, node_vects
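
# Worked example (not from the original code) of the word -> node matrix built
# in __init__ above. The vocabulary sizes and mapping are illustrative
# assumptions.
def _demo_word_node_matrix():
    num_words, num_node_ids = 5, 3
    word_node_mapping = {1: 0, 4: 2}  # word idx -> node idx
    mat = np.zeros([num_words, num_node_ids], np.float32)
    for word, node in word_node_mapping.items():
        mat[word, node] = 1.0
    # mat[1] == [1, 0, 0] and mat[4] == [0, 0, 1]; every other row stays zero,
    # so multiplying accumulated per-word outputs by mat routes them to nodes.
    return mat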
class DirectReferenceUpdateTransformation(object):
    """
    Transforms a graph state by updating node states, conditioned on a direct reference accumulation
    """
    def __init__(self, input_width, graph_spec, dropout_keep=1):
        """
        Params:
            input_width: Integer giving size of input
            graph_spec: Instance of GraphStateSpec giving graph spec
        """
        self._input_width = input_width
        self._graph_spec = graph_spec

        self._update_gru = BaseGRULayer(input_width + graph_spec.num_node_ids,
                                        graph_spec.node_state_size,
                                        name="nodestateupdate",
                                        dropout_keep=dropout_keep)

    @property
    def params(self):
        return self._update_gru.params

    def dropout_masks(self, srng, state_mask=None):
        return self._update_gru.dropout_masks(srng, use_output=state_mask)

    def process(self, gstate, ref_matrix, dropout_masks=Ellipsis):
        """
        Process a direct ref matrix and update the state accordingly. Each node runs a GRU step
        with previous state from the node state and input from the matrix.

        Params:
            gstate: A GraphState giving the current state
            ref_matrix: A tensor of the form (n_batch, num_node_ids, input_width)
        """
        if dropout_masks is Ellipsis:
            dropout_masks = None
            append_masks = False
        else:
            append_masks = True

        # To process the input, we need to map from node id to node index
        # We can do this using the gstate.node_ids, of shape (n_batch, n_nodes, num_node_ids)
        prepped_input_vector = T.batched_dot(gstate.node_ids, ref_matrix)

        # prepped_input_vector is of shape (n_batch, n_nodes, input_width)
        # gstate.node_states is of shape (n_batch, n_nodes, node_state_width)
        # so they match nicely
        full_input = T.concatenate([gstate.node_ids, prepped_input_vector], 2)

        # we flatten to apply GRU
        flat_input = full_input.reshape(
            [-1, self._input_width + self._graph_spec.num_node_ids])
        flat_state = gstate.node_states.reshape(
            [-1, self._graph_spec.node_state_size])
        new_flat_state, dropout_masks = self._update_gru.step(
            flat_input, flat_state, dropout_masks)

        new_node_states = new_flat_state.reshape(gstate.node_states.shape)

        new_gstate = gstate.with_updates(node_states=new_node_states)
        if append_masks:
            return new_gstate, dropout_masks
        else:
            return new_gstate
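
# Numeric sketch (not from the original code) of the batched_dot used in
# process above: with one-hot node_ids, each node simply picks up the
# ref_matrix row for its own id.
def _demo_direct_reference():
    n_batch, n_nodes, num_node_ids, input_width = 1, 2, 3, 4
    rng = np.random.RandomState(0)
    node_ids = np.zeros([n_batch, n_nodes, num_node_ids], np.float32)
    node_ids[0, 0, 2] = 1.0  # node 0 has id 2
    node_ids[0, 1, 0] = 1.0  # node 1 has id 0
    ref_matrix = rng.rand(n_batch, num_node_ids, input_width)
    # Equivalent of T.batched_dot(node_ids, ref_matrix)
    prepped = np.einsum("bnk,bki->bni", node_ids, ref_matrix)
    assert np.allclose(prepped[0, 0], ref_matrix[0, 2])
    assert np.allclose(prepped[0, 1], ref_matrix[0, 0])
    return prepped  # (n_batch, n_nodes, input_width)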