Code example #1
    def __init__(self,
                 embedding_dim,
                 state_dim,
                 use_local_attention=False,
                 window_size=10,
                 **kwargs):
        super(SentenceEncoder, self).__init__(**kwargs)

        self.embedding_dim = embedding_dim
        self.state_dim = state_dim
        self.rnn = GRU(activation=Tanh(),
                       dim=state_dim,
                       attended_dim=embedding_dim)
        self.input_fork = Fork(
            [name for name in self.rnn.apply.sequences if name != 'mask'],
            prototype=Linear(),
            name='input_fork')
        self.energy_computer = SumMatchFunction_posTag(
            name="wordAtt_energy_comp")
        self.attention = SequenceContentAttention_withExInput(
            state_names=['states'],
            state_dims=[state_dim],
            attended_dim=embedding_dim,
            match_dim=state_dim,
            posTag_dim=self.state_dim,
            energy_computer=self.energy_computer,
            use_local_attention=use_local_attention,
            window_size=window_size,
            name="word_attention")

        self.children = [self.rnn, self.input_fork, self.attention]
Code example #2
    def __init__(self,
                 index2word,
                 emb_size,
                 class_inv_freq,
                 hidden_size,
                 bidirectional,
                 learning_rate,
                 model_save_path,
                 pretrained_path=""):
        self.softmax = nn.Softmax(dim=1)
        self.model = GRU(index2word,
                         emb_size,
                         hidden_size,
                         len(class_inv_freq),
                         bidirectional,
                         pretrained_path=pretrained_path)
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()
        #self.criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_inv_freq))
        self.model_save_path = model_save_path
Code example #3
    def __init__(self, n_in, n_hidden, n_out, batch_size, use_gpu):
        super(GRUModel, self).__init__()
        self.n_in = n_in
        self.n_hidden = n_hidden
        self.n_out = n_out
        self.batch_size = batch_size
        self.use_gpu = use_gpu

        self.gru_layer = GRU(self.n_in, self.n_hidden, batch_size,
                             self.use_gpu)
        self.clf = nn.Linear(self.n_hidden, self.n_out)
        self.loss = nn.CrossEntropyLoss()
Code example #4
    def __init__(self,
                 context_version="classic",
                 cut_gradient=False,
                 aggregate="sum",
                 discount=1,
                 return_coefficients=False,
                 W_regularizer=None,
                 W_context_regularizer=None,
                 u_regularizer=None,
                 b_regularizer=None,
                 W_constraint=None,
                 W_context_constraint=None,
                 u_constraint=None,
                 b_constraint=None,
                 bias=True,
                 **kwargs):

        self.context_version = context_version
        self.cut_gradient = cut_gradient
        self.aggregate = aggregate
        self.discount = discount
        self.supports_masking = True
        self.return_coefficients = return_coefficients
        self.init = initializers.get('glorot_uniform')

        self.W_regularizer = regularizers.get(W_regularizer)
        self.W_context_regularizer = regularizers.get(W_context_regularizer)
        self.u_regularizer = regularizers.get(u_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.W_context_constraint = constraints.get(W_context_constraint)
        self.u_constraint = constraints.get(u_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias

        self.GRU = GRU(self, 100, 100)

        super(RecurrentContext, self).__init__(**kwargs)
Code example #5
def main():
    # set up checkpoint, log, and visualization locations
    utils.safe_mkdir_depths('../checkpoints/' + config.PROJECT_NAME + '/' + config.MODEL_NAME + '/')
    utils.safe_mkdir_depths('../log/' + config.PROJECT_NAME + '/' + config.MODEL_NAME + '/')
    logging.basicConfig(filename=config.LOG_PATH, level=logging.DEBUG)
    utils.safe_mkdir_depths('../visualization/' + config.PROJECT_NAME + '/' + config.MODEL_NAME + '/')

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices={'train', 'inference', 'transfer'},
                        default='train', help="mode; defaults to 'train' if not specified")
    args = parser.parse_args()

    if config.MODEL_NAME == 'GRU':
        compute_graph = GRU(config.MODEL_NAME)
    elif config.MODEL_NAME == 'LSTM':
        compute_graph = LSTM(config.MODEL_NAME)
    elif config.MODEL_NAME == 'CNN':
        compute_graph = CNN(config.MODEL_NAME)

    #compute_graph.vocab_size = config.VOCAB_SIZE

    if args.mode == 'train':
        local_dest = config.TRAIN_DATA_PATH + config.TRAIN_DATA_NAME
        local_dest_label = config.TRAIN_DATA_PATH + config.TRAIN_LABEL_NAME
        validation_dest = config.VALIDATION_DATA_PATH + config.VALIDATION_DATA_NAME
        validation_dest_label = config.VALIDATION_DATA_PATH + config.VALIDATION_LABEL

        if config.PRETRAIN_EMBD_TAG:  # whether to start from a pretrained embedding
            embd_dest = config.PRETRAIN_EMBD_PATH
            data.get_pretrain_embedding(compute_graph, embd_dest)

        iterator, training_init_op = data.get_data(local_dest, local_dest_label)
        next_element = iterator.get_next()
        _, validation_init_op = data.get_data(validation_dest, validation_dest_label, iterator)
        run_process.train(compute_graph, next_element, training_init_op, validation_init_op, config.EPOCH_NUM)

    elif args.mode == 'inference':
        local_dest = config.TEST_DATA_PATH + config.TEST_DATA_NAME
        local_dest_label = None
        if hasattr(config, 'TEST_LABEL_NAME'):
            local_dest_label = config.TEST_DATA_PATH + config.TEST_LABEL_NAME

        if config.PRETRAIN_EMBD_TAG:  # whether to start from a pretrained embedding
            embd_dest = config.PRETRAIN_EMBD_PATH
            data.get_pretrain_embedding(compute_graph, embd_dest)

        iterator, inference_init_op = data.get_data(local_dest, local_dest_label)
        next_element = iterator.get_next()
        run_process.inference(compute_graph, next_element, inference_init_op)

    elif args.mode == 'transfer':
        run_process.test_restore()
Code example #6
    def __init__(self,
                 vocab_size,
                 topicWord_size,
                 embedding_dim,
                 state_dim,
                 topical_dim,
                 representation_dim,
                 match_function='SumMacthFunction',
                 use_doubly_stochastic=False,
                 lambda_ds=0.001,
                 use_local_attention=False,
                 window_size=10,
                 use_step_decay_cost=False,
                 use_concentration_cost=False,
                 lambda_ct=10,
                 use_stablilizer=False,
                 lambda_st=50,
                 theano_seed=None,
                 **kwargs):
        super(Decoder, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.topicWord_size = topicWord_size
        self.embedding_dim = embedding_dim
        self.state_dim = state_dim
        self.representation_dim = representation_dim
        self.theano_seed = theano_seed

        # Initialize gru with special initial state
        self.transition = GRU(attended_dim=state_dim,
                              dim=state_dim,
                              activation=Tanh(),
                              name='decoder')

        self.energy_computer = globals()[match_function](name='energy_comp')

        # Initialize the attention mechanism
        self.attention = SequenceContentAttention(
            state_names=self.transition.apply.states,
            attended_dim=representation_dim,
            match_dim=state_dim,
            energy_computer=self.energy_computer,
            use_local_attention=use_local_attention,
            window_size=window_size,
            name="attention")

        self.topical_attention = SequenceContentAttention(
            state_names=self.transition.apply.states,
            attended_dim=topical_dim,
            match_dim=state_dim,
            energy_computer=self.energy_computer,
            use_local_attention=use_local_attention,
            window_size=window_size,
            name="topical_attention"
        )  #not sure whether the match dim would be correct.

        # Initialize the readout; note that SoftmaxEmitter emits -1 for
        # initial outputs, which is used by LookupFeedbackWMT15
        readout = Readout(source_names=[
            'states', 'feedback', self.attention.take_glimpses.outputs[0]
        ],
                          readout_dim=self.vocab_size,
                          emitter=SoftmaxEmitter(initial_output=-1,
                                                 theano_seed=theano_seed),
                          feedback_brick=LookupFeedbackWMT15(
                              vocab_size, embedding_dim),
                          post_merge=InitializableFeedforwardSequence([
                              Bias(dim=state_dim, name='maxout_bias').apply,
                              Maxout(num_pieces=2, name='maxout').apply,
                              Linear(input_dim=state_dim // 2,
                                     output_dim=embedding_dim,
                                     use_bias=False,
                                     name='softmax0').apply,
                              Linear(input_dim=embedding_dim,
                                     name='softmax1').apply
                          ]),
                          merged_dim=state_dim,
                          name='readout')

        # Calculate the readout of the topic word:
        # no specific feedback brick, so use the trivial feedback brick;
        # no post_merge and merge, use Bias and Linear
        topicWordReadout = Readout(source_names=[
            'states', 'feedback', self.attention.take_glimpses.outputs[0]
        ],
                                   readout_dim=self.topicWord_size,
                                   emitter=SoftmaxEmitter(
                                       initial_output=-1,
                                       theano_seed=theano_seed),
                                   name='twReadout')

        # Build sequence generator accordingly
        self.sequence_generator = SequenceGenerator(
            readout=readout,
            topicWordReadout=topicWordReadout,
            topic_vector_names=['topicSumVector'],
            transition=self.transition,
            attention=self.attention,
            topical_attention=self.topical_attention,
            q_dim=self.state_dim,
            #q_name='topic_embedding',
            topical_name='topic_embedding',
            content_name='content_embedding',
            use_step_decay_cost=use_step_decay_cost,
            use_doubly_stochastic=use_doubly_stochastic,
            lambda_ds=lambda_ds,
            use_concentration_cost=use_concentration_cost,
            lambda_ct=lambda_ct,
            use_stablilizer=use_stablilizer,
            lambda_st=lambda_st,
            fork=Fork([
                name
                for name in self.transition.apply.sequences if name != 'mask'
            ],
                      prototype=Linear()))

        self.children = [self.sequence_generator]
Code example #7
class RecurrentContext(Layer):
    """
    Attention operation, with a context/query vector, for temporal data.
    Supports masking, using a windowed context vector to assist the attention.
    # Input shape
        4D tensor with shape: `(samples, sentence, steps, features)`.
    # Output shape
        3D tensor with shape: `(samples, sentence, features)`.

    How to use:
    Just put it on top of a TimeDistributed RNN layer (GRU/LSTM/SimpleRNN) with
    return_sequences=True, so that this layer receives a 4D input.
    The dimensions are inferred based on the output shape of the RNN.

    Note: The layer has been tested with Keras 2.0.6

    Example:
        model.add(TimeDistributed(LSTM(64, return_sequences=True)))
        model.add(RecurrentContext())
        # next add a Dense layer (for classification/regression) or whatever...
    """
    def __init__(self,
                 context_version="classic",
                 cut_gradient=False,
                 aggregate="sum",
                 discount=1,
                 return_coefficients=False,
                 W_regularizer=None,
                 W_context_regularizer=None,
                 u_regularizer=None,
                 b_regularizer=None,
                 W_constraint=None,
                 W_context_constraint=None,
                 u_constraint=None,
                 b_constraint=None,
                 bias=True,
                 **kwargs):

        self.context_version = context_version
        self.cut_gradient = cut_gradient
        self.aggregate = aggregate
        self.discount = discount
        self.supports_masking = True
        self.return_coefficients = return_coefficients
        self.init = initializers.get('glorot_uniform')

        self.W_regularizer = regularizers.get(W_regularizer)
        self.W_context_regularizer = regularizers.get(W_context_regularizer)
        self.u_regularizer = regularizers.get(u_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.W_context_constraint = constraints.get(W_context_constraint)
        self.u_constraint = constraints.get(u_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias

        self.GRU = GRU(self, 100, 100)

        super(RecurrentContext, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 4

        self.W = self.add_weight((
            input_shape[-1],
            input_shape[-1],
        ),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)

        self.W_context = self.add_weight(
            (
                input_shape[-1],
                input_shape[-1],
            ),
            initializer=self.init,
            name='{}_W_context'.format(self.name),
            regularizer=self.W_context_regularizer,
            constraint=self.W_context_constraint)

        if self.bias:
            self.b = self.add_weight((input_shape[-1], ),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)

        self.u = self.add_weight((input_shape[-1], ),
                                 initializer=self.init,
                                 name='{}_u'.format(self.name),
                                 regularizer=self.u_regularizer,
                                 constraint=self.u_constraint)

        super(RecurrentContext, self).build(input_shape)

    def compute_mask(self, input, input_mask=None):
        # do not pass the mask to the next layers
        return None

    def call(self, x, mask=None):

        #self.GRU.reset_h()

        def compute_att(res, x_):
            context = res[0]

            uit = dot_product(x_, self.W)
            c = dot_product(context, self.W_context)

            if self.bias:
                uit += self.b

            uit = K.tanh(tf.add(uit, K.expand_dims(c, 1)))

            ait = dot_product(uit, self.u)
            a = K.exp(ait)
            # apply mask after the exp. will be re-normalized next
            if mask is not None:
                # Cast the mask to floatX to avoid float64 upcasting in theano
                #a *= K.cast(mask, K.floatx())
                pass
            # in some cases especially in the early stages of training the sum may be almost zero
            # and this results in NaN's. A workaround is to add a very small positive number ε to the sum.
            a /= K.cast(
                K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())
            a_ = K.expand_dims(a)
            weighted_input = x_ * a_
            attended = K.sum(weighted_input, axis=1)

            context = self.GRU.forward_pass(context, attended)

            return [context, a]

        x_t = tf.transpose(x, [1, 0, 2, 3])

        output, weights = tf.scan(compute_att,
                                  x_t,
                                  initializer=[
                                      K.zeros_like(x_t[0, :, 0, :]),
                                      K.zeros_like(x_t[0, :, :, 0])
                                  ])

        output = tf.transpose(output, [1, 0, 2])
        weights = tf.transpose(weights, [1, 0, 2])

        if self.return_coefficients:
            return [output, weights]

        else:
            return [output]

    def compute_output_shape(self, input_shape):
        if self.return_coefficients:
            return [(input_shape[0], input_shape[1], input_shape[-1]),
                    (input_shape[0], input_shape[1], input_shape[-1], 1)]
        else:
            return [(input_shape[0], input_shape[1], input_shape[-1])]
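
Following the docstring above, the sketch below shows one way to wire the RecurrentContext layer into a model. It is a hypothetical minimal example, not code from the original project: it assumes Keras 2.0.x with the TensorFlow backend, that RecurrentContext (with the custom GRU cell and dot_product helper it relies on) is importable, and that the word vectors are already embedded. The 100-dimensional feature size matches the layer's hard-coded internal GRU(self, 100, 100); n_sentences and n_steps are placeholder lengths. Note that keras.layers.GRU below encodes the words of each sentence and is unrelated to the custom GRU cell that RecurrentContext builds internally.

# Hypothetical usage sketch (not from the original repository).
# RecurrentContext expects a 4D input (samples, sentences, steps, features),
# so the word-level GRU is wrapped in TimeDistributed.
from keras.layers import GRU, Input, TimeDistributed
from keras.models import Model

n_sentences, n_steps = 10, 30  # placeholder sequence lengths

# pre-embedded word vectors: (samples, sentences, steps, 100)
words = Input(shape=(n_sentences, n_steps, 100))

# word-level GRU applied to every sentence, keeping all timesteps:
# (samples, sentences, steps, 100)
word_states = TimeDistributed(GRU(100, return_sequences=True))(words)

# windowed-context attention over the words of each sentence:
# (samples, sentences, 100)
sentence_vectors = RecurrentContext()(word_states)

model = Model(inputs=words, outputs=sentence_vectors)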
Code example #8
File: Main.py  Project: Alpaca-Man/NLP-Newcomer
hidden_size = 128  # hidden layer dimension
num_layers = 1  # number of GRU layers
dropout = 0  # dropout rate
bidirectional = False  # whether the GRU is bidirectional
lr = 1e-3  # learning rate
epoch = 100  # number of training epochs

# Load and preprocess the dataset
dataProcessor = DataProcessor(fileName, seq_len)
corpusContents, corpusLabels = dataProcessor.preTreatMent()

# GRU
model = GRU(dataProcessor.voc_size,
            emb_size,
            hidden_size,
            len(dataProcessor.label_types),
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional).to(device)
criterion = nn.CrossEntropyLoss().to(device)  # loss function
optimizer = optim.Adam(model.parameters(), lr=lr)  # optimizer

kf = KFold(n_splits=splits, shuffle=shuffle, random_state=None)
fold = 0

# keep the best fold
bestFold = 1
bestScores = 0

for train, test in kf.split(corpusContents):
    fold += 1
Code example #9
vocab = list(set(word_list))  # vocabulary (no duplicates)
word2idx = {w: i for i, w in enumerate(vocab)}  # word-to-index mapping
vocab_size = len(vocab)  # number of words

# Wrap the data into a dataset and loader
trainInput, trainTarget = make_data(sentences, word2idx, labels)  # turns the dataset into arrays
trainInput, trainTarget = torch.LongTensor(trainInput).to(
    device), torch.LongTensor(trainTarget).to(device)
trainDataSet = Data.TensorDataset(trainInput, trainTarget)
trainDataLoader = Data.DataLoader(trainDataSet, batch_size, shuffle=True)  # shuffled

# GRU
model = GRU(vocab_size,
            emb_size,
            hidden_size,
            num_classes,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional).to(device)
criterion = nn.CrossEntropyLoss().to(device)  # loss function
optimizer = optim.Adam(model.parameters(), lr=lr)  # optimizer

# Training
train(model, epoch, trainDataLoader, criterion, optimizer)

# Test set and its preprocessing
test_sentences = ["i hate me", "you love me"]
test_labels = [0, 1]
testInput, testTarget = make_data(test_sentences, word2idx, test_labels)
testInput = torch.LongTensor(testInput).to(device)
testTarget = torch.LongTensor(testTarget).to(device)
Code example #10
else:
    device = torch.device("cpu")

input_size = len(idx_to_word)

# hidden state size
hidden_size = 128

# size of the output tensor
output_size = 11

# number of stacked GRU layers
num_layers = 3

# the deep learning model to use
model = GRU(input_size, hidden_size, output_size, batch_size, device,
            num_layers)

# loss function
criterion = nn.CrossEntropyLoss()

# Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# number of training iterations
n_iter = 1000

# print frequency
print_every = n_iter // 10

# plot frequency
plot_every = n_iter // 100
Code example #11
class RecurrentNet:
    def __init__(self,
                 index2word,
                 emb_size,
                 class_inv_freq,
                 hidden_size,
                 bidirectional,
                 learning_rate,
                 model_save_path,
                 pretrained_path=""):
        self.softmax = nn.Softmax(dim=1)
        self.model = GRU(index2word,
                         emb_size,
                         hidden_size,
                         len(class_inv_freq),
                         bidirectional,
                         pretrained_path=pretrained_path)
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.criterion = nn.CrossEntropyLoss()
        #self.criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_inv_freq))
        self.model_save_path = model_save_path

    def fit(self, train_data, dev_data, num_epochs):
        self.trainEpoch(train_data, dev_data, num_epochs)

    def predict(self, test_data):
        self.model = torch.load(self.model_save_path)
        return self.evaluate(test_data, test=True)

    def train(self, train_data):
        self.model.train()
        iter_loss = 0
        for b in range(train_data.getNumBatches()):
            batch_size = train_data.getBatch(b)[0].shape[0]
            self.optimizer.zero_grad()
            x_b = torch.from_numpy(train_data.getBatch(b)[0]).long()
            y_b = torch.from_numpy(train_data.getBatch(b)[1][:, 0]).long()

            x_var = Variable(x_b)
            y_var = Variable(y_b)

            h = self.model.init_hidden(batch_size)

            y_prob, h = self.model(x_var, h)

            loss = self.criterion(y_prob, y_var)
            loss.backward()

            self.optimizer.step()
            iter_loss += loss.item()  # .item() instead of the deprecated loss.data[0]
        return iter_loss

    def evaluate(self, data, test=False):
        self.model.eval()
        y_pred = torch.zeros(data.num_instances, 1)
        y_test = np.zeros((data.num_instances, 1))
        count = 0
        for b in range(data.getNumBatches()):
            batch_size = data.getBatch(b)[0].shape[0]
            x_b = torch.from_numpy(data.getBatch(b)[0]).long()
            y_b = torch.from_numpy(data.getBatch(b)[1][:, 0]).long()

            x_var = Variable(x_b)
            y_var = Variable(y_b)

            h = self.model.init_hidden(batch_size)
            out, h = self.model(x_var, h)
            y_prob = self.softmax(out)
            _, preds = torch.max(y_prob.data, dim=1)
            y_test[count:count + batch_size, :] = data.getBatch(b)[1]
            y_pred[count:count + batch_size, :] = preds
            count = count + batch_size
        if test:
            return y_test, y_pred.numpy()
        else:
            corrects = (y_test == y_pred.numpy()).sum()
            return 1.0 * corrects / data.num_instances

    def trainEpoch(self, train_data, dev_data, num_epochs):
        best_acc = 0.0
        best_epoch = None
        for epoch in range(num_epochs):
            print('Epoch {}'.format(epoch), end=', ')
            train_data.shuffleBatches()
            iter_loss = self.train(train_data)
            print('training Loss: {:.3}'.format(iter_loss), end=', ')
            train_acc = self.evaluate(train_data)
            print('train accuracy: {}'.format(train_acc), end=', ')
            dev_acc = self.evaluate(dev_data)
            print('dev accuracy: {}'.format(dev_acc))
            if dev_acc >= best_acc:
                best_acc = dev_acc
                best_epoch = epoch
                torch.save(self.model, self.model_save_path)
        print("Saved the best model. (epoch " + str(best_epoch) + ")")
Code example #12
class SentenceEncoder(Initializable):
    """Encoder of RNNsearch model."""
    def __init__(self,
                 embedding_dim,
                 state_dim,
                 use_local_attention=False,
                 window_size=10,
                 **kwargs):
        super(SentenceEncoder, self).__init__(**kwargs)

        self.embedding_dim = embedding_dim
        self.state_dim = state_dim
        self.rnn = GRU(activation=Tanh(),
                       dim=state_dim,
                       attended_dim=embedding_dim)
        self.input_fork = Fork(
            [name for name in self.rnn.apply.sequences if name != 'mask'],
            prototype=Linear(),
            name='input_fork')
        self.energy_computer = SumMatchFunction_posTag(
            name="wordAtt_energy_comp")
        self.attention = SequenceContentAttention_withExInput(
            state_names=['states'],
            state_dims=[state_dim],
            attended_dim=embedding_dim,
            match_dim=state_dim,
            posTag_dim=self.state_dim,
            energy_computer=self.energy_computer,
            use_local_attention=use_local_attention,
            window_size=window_size,
            name="word_attention")

        self.children = [self.rnn, self.input_fork, self.attention]

    def _push_allocation_config(self):

        self.input_fork.input_dim = self.embedding_dim
        self.input_fork.output_dims = [
            self.rnn.get_dim(name) for name in self.input_fork.output_names
        ]
        self.attention.state_dims = [self.state_dim]
        self.attention.state_dim = self.state_dim

    @recurrent(sequences=[
        'attended', 'preprocessed_attended', 'attended_mask', 'mask'
    ],
               states=['states'],
               outputs=['states'],
               contexts=['decoder_states'])
    def do_apply(self,
                 attended,
                 preprocessed_attended,
                 attended_mask,
                 decoder_states,
                 states,
                 mask=None):
        current_glimpses = self.attention.take_glimpses(
            attended, states, preprocessed_attended, attended_mask, states,
            **{'states': decoder_states})
        inputs = merge(
            self.input_fork.apply(current_glimpses[0], as_dict=True),
            {'states': states})

        next_states = self.rnn.apply(iterate=False, **inputs)
        if mask is not None:
            next_states = (mask[:, None] * next_states +
                           (1 - mask[:, None]) * states)
        return next_states

    @application(inputs=[
        'attended', 'preprocessed_attended', 'attended_mask', 'decoder_states',
        'mask'
    ],
                 outputs=['cxt_representation'])
    def apply(self,
              attended,
              preprocessed_attended,
              attended_mask,
              decoder_states,
              mask=None):
        # Time as first dimension
        mask = mask.T
        cxt_representation = self.do_apply(attended,
                                           preprocessed_attended,
                                           attended_mask,
                                           decoder_states,
                                           mask=mask)

        return cxt_representation

    def get_dim(self, name):
        if name == 'mask':
            return 0
        if name in ['states']:
            return self.state_dim

    @application(inputs=['attended'], outputs=['preprocessed_attended'])
    def preprocess(self, attended):
        """Preprocess the sequence for computing attention weights.

        Parameters
        ----------
        attended : :class:`~tensor.TensorVariable`
            The attended sequence, time is the 1-st dimension.

        """
        return self.attention.preprocess(attended)

    @application(outputs=do_apply.states)
    def initial_states(self, batch_size, *args, **kwargs):
        attended = kwargs['attended']
        initial_state = attended[0, 0, :, -self.state_dim:]
        return initial_state