Example #1
    def _pop_op(everything, accum, everything_max = None,
            everything_min = None, word = None, aword = None,
            one_step=False, use_noise=True):

        rval = proj_h[0](accum[0], one_step=one_step, use_noise=use_noise)
        for si in xrange(1,state['decoder_stack']):
            rval += proj_h[si](accum[si], one_step=one_step, use_noise=use_noise)
        if state['mult_out']:
            rval = rval * everything
        else:
            rval = rval + everything

        if aword and state['avg_word']:
            wcode = aword
            if one_step:
                if state['mult_out']:
                    rval = rval * wcode
                else:
                    rval = rval + wcode
            else:
                if not isinstance(wcode, TT.TensorVariable):
                    wcode = wcode.out
                shape = wcode.shape
                rshape = rval.shape
                rval = rval.reshape([rshape[0]/shape[0], shape[0], rshape[1]])
                wcode = wcode.dimshuffle('x', 0, 1)
                if state['mult_out']:
                    rval = rval * wcode
                else:
                    rval = rval + wcode
                rval = rval.reshape(rshape)
        if word and state['bigram']:
            if one_step:
                if state['mult_out']:
                    rval *= proj_word(emb_t(word, use_noise=use_noise), 
                            one_step=one_step, use_noise=use_noise)
                else:
                    rval += proj_word(emb_t(word, use_noise=use_noise), 
                            one_step=one_step, use_noise=use_noise)
            else:
                if isinstance(word, TT.TensorVariable):
                    shape = word.shape
                    ndim = word.ndim
                else:
                    shape = word.shape
                    ndim = word.out.ndim
                pword = proj_word(emb_t(word, use_noise=use_noise), 
                        one_step=one_step, use_noise=use_noise)
                shape_pword = pword.shape
                if ndim == 1:
                    pword = Shift()(pword.reshape([shape[0], 1, outdim]))
                else:
                    pword = Shift()(pword.reshape([shape[0], shape[1], outdim]))
                if state['mult_out']:
                    rval *= pword.reshape(shape_pword)
                else:
                    rval += pword.reshape(shape_pword)
        if state['deep_out']:
            rval = drop_layer(act_layer(rval), use_noise=use_noise)
        return rval
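
For orientation, a toy numpy sketch (hypothetical shapes and stand-ins for the GroundHog layers, not part of the snippet) of the core combination `_pop_op` performs before the optional average-word and bigram terms: the projected decoder states of each stacked layer are summed, then merged with the projected source representation either multiplicatively or additively depending on state['mult_out'].

import numpy as np

n_words, outdim, decoder_stack = 6, 4, 2
mult_out = False  # plays the role of state['mult_out']

# Stand-ins for proj_h[si](accum[si]): one projected hidden state per decoder layer.
projected = [np.random.rand(n_words, outdim) for _ in range(decoder_stack)]
everything = np.random.rand(n_words, outdim)  # stand-in for the projected source context

rval = projected[0]
for si in range(1, decoder_stack):
    rval = rval + projected[si]
rval = rval * everything if mult_out else rval + everything
print(rval.shape)  # (6, 4)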
Example #2
    def build_decoder(self, c, y,
            c_mask=None,
            y_mask=None,
            step_num=None,
            mode=EVALUATION,
            given_init_states=None,
            T=1):
        """Create the computational graph of the RNN Decoder.
        A graph denoting the relation from the context vector c to the target cost/samples.
        It works in two modes: evaluation and sampling/beam search.
        In evaluation mode (for training pairs), it computes the graph from c to the cost of predicting y given x.
        In sampling/beam-search mode, it computes the graph from c to target samples y.

        :param c:
            representations produced by an encoder.
            (n_samples, dim) matrix if mode == sampling or
            (max_seq_len, batch_size, dim) matrix if mode == evaluation

        :param c_mask:
            if mode == evaluation a 0/1 matrix identifying valid positions in c

        :param y:
            if mode == evaluation
                target sequences, matrix of word indices of shape (max_seq_len, batch_size),
                where each column is a sequence
            if mode != evaluation
                a vector of previous words of shape (n_samples,)

        :param y_mask:
            if mode == evaluation a 0/1 matrix determining lengths
                of the target sequences, must be None otherwise

        :param mode:
            chooses one of three modes: evaluation, sampling and beam_search,
                where evaluation means both input and output sequences are given,
                and sampling means only the input is given and the target is predicted

        :param given_init_states:
            for sampling and beam_search. A list of hidden state
                matrices, one per layer; each matrix is (n_samples, dim)

        :param T:
            sampling temperature
        """

        # Check parameter consistency
        if mode == Decoder.EVALUATION:
            assert not given_init_states
        else:
            assert not y_mask
            assert given_init_states
            if mode == Decoder.BEAM_SEARCH:
                assert T == 1

        # For log-likelihood evaluation the representation
        # is replicated for convenience. In case a backward RNN is used,
        # this is not done.
        # Shape if mode == evaluation
        #   (max_seq_len, batch_size, dim)
        # Shape if mode != evaluation
        #   (n_samples, dim)
        if not self.state['search']:
            if mode == Decoder.EVALUATION:
                c = PadLayer(y.shape[0])(c)
            else:
                assert step_num
                c_pos = TT.minimum(step_num, c.shape[0] - 1)

        # Low rank embeddings of all the input words.
        # Shape if mode == evaluation
        #   (n_words, rank_n_approx),
        # Shape if mode != evaluation
        #   (n_samples, rank_n_approx)
        approx_embeddings = self.approx_embedder(y)

        # Low rank embeddings are projected to contribute
        # to input, reset and update signals.
        # All the shapes if mode == evaluation:
        #   (n_words, dim)
        # where: n_words = max_seq_len * batch_size
        # All the shapes if mode != evaluation:
        #   (n_samples, dim)
        input_signals = []
        reset_signals = []
        update_signals = []
        for level in range(self.num_levels):
            # Contributions directly from input words.
            input_signals.append(self.input_embedders[level](approx_embeddings))
            update_signals.append(self.update_embedders[level](approx_embeddings))
            reset_signals.append(self.reset_embedders[level](approx_embeddings))

            # Contributions from the encoded source sentence.
            if not self.state['search']:
                current_c = c if mode == Decoder.EVALUATION else c[c_pos]
                input_signals[level] += self.decode_inputers[level](current_c)
                update_signals[level] += self.decode_updaters[level](current_c)
                reset_signals[level] += self.decode_reseters[level](current_c)

        # Hidden layers' initial states.
        # Shapes if mode == evaluation:
        #   (batch_size, dim)
        # Shape if mode != evaluation:
        #   (n_samples, dim)
        init_states = given_init_states
        if not init_states:
            init_states = []
            for level in range(self.num_levels):
                init_c = c[0, :, -self.state['dim']:]
                init_states.append(self.initializers[level](init_c))

        # Hidden layers' states.
        # Shapes if mode == evaluation:
        #  (seq_len, batch_size, dim)
        # Shapes if mode != evaluation:
        #  (n_samples, dim)
        hidden_layers = []
        contexts = []
        # Default value for alignment must be something computable
        alignment = TT.zeros((1,))
        for level in range(self.num_levels):
            if level > 0:
                input_signals[level] += self.inputers[level](hidden_layers[level - 1])
                update_signals[level] += self.updaters[level](hidden_layers[level - 1])
                reset_signals[level] += self.reseters[level](hidden_layers[level - 1])
            add_kwargs = (dict(state_before=init_states[level])
                        if mode != Decoder.EVALUATION
                        else dict(init_state=init_states[level],
                            batch_size=y.shape[1] if y.ndim == 2 else 1,
                            nsteps=y.shape[0]))
            if self.state['search']:
                add_kwargs['c'] = c
                add_kwargs['c_mask'] = c_mask
                add_kwargs['return_alignment'] = self.compute_alignment
                if mode != Decoder.EVALUATION:
                    add_kwargs['step_num'] = step_num
            result = self.transitions[level](
                    input_signals[level],
                    mask=y_mask,
                    gater_below=none_if_zero(update_signals[level]),
                    reseter_below=none_if_zero(reset_signals[level]),
                    one_step=mode != Decoder.EVALUATION,
                    use_noise=mode == Decoder.EVALUATION,
                    **add_kwargs)
            if self.state['search']:
                if self.compute_alignment:
                    #This implicitly wraps each element of result.out with a Layer to keep track of the parameters.
                    #It is equivalent to h=result[0], ctx=result[1] etc. 
                    h, ctx, alignment = result
                    if mode == Decoder.EVALUATION:
                        alignment = alignment.out
                else:
                    #This implicitly wraps each element of result.out with a Layer to keep track of the parameters.
                    #It is equivalent to h=result[0], ctx=result[1]
                    h, ctx = result
            else:
                h = result
                if mode == Decoder.EVALUATION:
                    ctx = c
                else:
                    ctx = ReplicateLayer(given_init_states[0].shape[0])(c[c_pos]).out
            hidden_layers.append(h)
            contexts.append(ctx)

        # In hidden_layers we do not have the initial state, but we need it.
        # Instead we have the last state, which we do not need.
        # So we discard the last one and prepend the initial one.
        if mode == Decoder.EVALUATION:
            for level in range(self.num_levels):
                hidden_layers[level].out = TT.concatenate([
                    TT.shape_padleft(init_states[level].out),
                        hidden_layers[level].out])[:-1]

        # The output representation to be fed in softmax.
        # Shape if mode == evaluation
        #   (n_words, dim_r)
        # Shape if mode != evaluation
        #   (n_samples, dim_r)
        # ... where dim_r depends on 'deep_out' option.
        readout = self.repr_readout(contexts[0])
        for level in range(self.num_levels):
            if mode != Decoder.EVALUATION:
                read_from = init_states[level]
            else:
                read_from = hidden_layers[level]
            read_from_var = read_from if type(read_from) == theano.tensor.TensorVariable else read_from.out
            if read_from_var.ndim == 3:
                read_from_var = read_from_var[:,:,:self.state['dim']]
            else:
                read_from_var = read_from_var[:,:self.state['dim']]
            if type(read_from) != theano.tensor.TensorVariable:
                read_from.out = read_from_var
            else:
                read_from = read_from_var
            readout += self.hidden_readouts[level](read_from)
        if self.state['bigram']:
            if mode != Decoder.EVALUATION:
                check_first_word = (y > 0
                    if self.state['check_first_word']
                    else TT.ones((y.shape[0]), dtype="float32"))
                # padright is necessary as we want to multiply each row with a certain scalar
                readout += TT.shape_padright(check_first_word) * self.prev_word_readout(approx_embeddings).out
            else:
                if y.ndim == 1:
                    readout += Shift()(self.prev_word_readout(approx_embeddings).reshape(
                        (y.shape[0], 1, self.state['dim'])))
                else:
                    # This place needs explanation. When prev_word_readout is applied to
                    # approx_embeddings, the resulting shape is
                    # (n_batches * sequence_length, repr_dimensionality). We first
                    # reshape it into a 3D tensor to shift it forward in time, then
                    # reshape it back (see the numpy sketch after this example).
                    readout += Shift()(self.prev_word_readout(approx_embeddings).reshape(
                        (y.shape[0], y.shape[1], self.state['dim']))).reshape(
                                readout.out.shape)
        for fun in self.output_nonlinearities:
            readout = fun(readout)

        if mode == Decoder.SAMPLING:
            sample = self.output_layer.get_sample(
                    state_below=readout,
                    temp=T)
            # Current SoftmaxLayer.get_cost is stupid,
            # that's why we have to reshape a lot.
            self.output_layer.get_cost(
                    state_below=readout.out,
                    temp=T,
                    target=sample)
            log_prob = self.output_layer.cost_per_sample
            return [sample] + [log_prob] + hidden_layers
        elif mode == Decoder.BEAM_SEARCH:
            return self.output_layer(
                    state_below=readout.out,
                    temp=T).out
        elif mode == Decoder.EVALUATION:
            # Return the cost expression w.r.t. x and y from the output layer.
            # See Layer.train at /Groundhog/layers/basic.py.
            # The actual output_layer for evaluation is the CostLayer/SoftmaxLayer at costlayers.py,
            # where the cost expression is constructed in the SoftmaxLayer class.
            # reg is an extra regularization expression for the cost.
            return (self.output_layer.train(
                    state_below=readout,
                    target=y,
                    mask=y_mask,
                    reg=None),
                    alignment
                    )
        else:
            raise Exception("Unknown mode for build_decoder")
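
The reshape/Shift/reshape pattern in the bigram readout above can be hard to picture. Here is a minimal numpy sketch, with hypothetical sizes, of shifting a flattened (seq_len * batch, dim) matrix one step forward in time, which is roughly what the Shift() layer is used for here (the first time step is later overwritten with the initial state).

import numpy as np

seq_len, batch, dim = 4, 2, 3
flat = np.arange(seq_len * batch * dim, dtype='float32').reshape(seq_len * batch, dim)

# 1. View the flat (seq_len * batch, dim) matrix as a (seq_len, batch, dim) tensor.
cube = flat.reshape(seq_len, batch, dim)

# 2. Shift one step forward in time: position t now holds what was at t - 1,
#    and the first time step becomes zeros.
shifted = np.zeros_like(cube)
shifted[1:] = cube[:-1]

# 3. Flatten back so it lines up with the (seq_len * batch, dim) readout again.
flat_shifted = shifted.reshape(seq_len * batch, dim)
print(flat_shifted.shape)  # (8, 3)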
Example #3
    def _pop_op(everything,
                accum,
                everything_max=None,
                everything_min=None,
                word=None,
                aword=None,
                one_step=False,
                use_noise=True):

        rval = proj_h[0](accum[0], one_step=one_step, use_noise=use_noise)
        for si in xrange(1, state['decoder_stack']):
            rval += proj_h[si](accum[si],
                               one_step=one_step,
                               use_noise=use_noise)
        if state['mult_out']:
            rval = rval * everything
        else:
            rval = rval + everything

        if aword and state['avg_word']:
            wcode = aword
            if one_step:
                if state['mult_out']:
                    rval = rval * wcode
                else:
                    rval = rval + wcode
            else:
                if not isinstance(wcode, TT.TensorVariable):
                    wcode = wcode.out
                shape = wcode.shape
                rshape = rval.shape
                rval = rval.reshape(
                    [rshape[0] / shape[0], shape[0], rshape[1]])
                wcode = wcode.dimshuffle('x', 0, 1)
                if state['mult_out']:
                    rval = rval * wcode
                else:
                    rval = rval + wcode
                rval = rval.reshape(rshape)
        if word and state['bigram']:
            if one_step:
                if state['mult_out']:
                    rval *= proj_word(emb_t(word, use_noise=use_noise),
                                      one_step=one_step,
                                      use_noise=use_noise)
                else:
                    rval += proj_word(emb_t(word, use_noise=use_noise),
                                      one_step=one_step,
                                      use_noise=use_noise)
            else:
                if isinstance(word, TT.TensorVariable):
                    shape = word.shape
                    ndim = word.ndim
                else:
                    shape = word.shape
                    ndim = word.out.ndim
                pword = proj_word(emb_t(word, use_noise=use_noise),
                                  one_step=one_step,
                                  use_noise=use_noise)
                shape_pword = pword.shape
                if ndim == 1:
                    pword = Shift()(pword.reshape([shape[0], 1, outdim]))
                else:
                    pword = Shift()(pword.reshape([shape[0], shape[1],
                                                   outdim]))
                if state['mult_out']:
                    rval *= pword.reshape(shape_pword)
                else:
                    rval += pword.reshape(shape_pword)
        if state['deep_out']:
            rval = drop_layer(act_layer(rval), use_noise=use_noise)
        return rval
Example #4
def jobman(state, channel):
    # load dataset
    state['null_sym_source'] = 15000
    state['null_sym_target'] = 15000
    state['n_sym_source'] = state['null_sym_source'] + 1
    state['n_sym_target'] = state['null_sym_target'] + 1

    state['nouts'] = state['n_sym_target']
    state['nins'] = state['n_sym_source']
    rng = numpy.random.RandomState(state['seed'])
    if state['loopIters'] > 0:
        train_data, valid_data, test_data = get_data(state)
    else:
        train_data = None
        valid_data = None
        test_data = None

    ########### Training graph #####################
    ## 1. Inputs
    if state['bs'] == 1:
        x = TT.lvector('x')
        x_mask = TT.vector('x_mask')
        y = TT.lvector('y')
        y0 = y
        y_mask = TT.vector('y_mask')
    else:
        x = TT.lmatrix('x')
        x_mask = TT.matrix('x_mask')
        y = TT.lmatrix('y')
        y0 = y
        y_mask = TT.matrix('y_mask')

    # 2. Layers and Operators
    bs = state['bs']

    embdim = state['dim_mlp']

    # Source Sentence
    emb = MultiLayer(rng,
                     n_in=state['nins'],
                     n_hids=[state['rank_n_approx']],
                     activation=[state['rank_n_activ']],
                     init_fn=state['weight_init_fn'],
                     weight_noise=state['weight_noise'],
                     scale=state['weight_scale'],
                     name='emb')

    emb_words = []
    if state['rec_gating']:
        gater_words = []
    if state['rec_reseting']:
        reseter_words = []
    for si in xrange(state['encoder_stack']):
        emb_words.append(
            MultiLayer(rng,
                       n_in=state['rank_n_approx'],
                       n_hids=[embdim],
                       activation=['lambda x:x'],
                       init_fn=state['weight_init_fn'],
                       weight_noise=state['weight_noise'],
                       scale=state['weight_scale'],
                       name='emb_words_%d' % si))
        if state['rec_gating']:
            gater_words.append(
                MultiLayer(rng,
                           n_in=state['rank_n_approx'],
                           n_hids=[state['dim']],
                           activation=['lambda x:x'],
                           init_fn=state['weight_init_fn'],
                           weight_noise=state['weight_noise'],
                           scale=state['weight_scale'],
                           learn_bias=False,
                           name='gater_words_%d' % si))
        if state['rec_reseting']:
            reseter_words.append(
                MultiLayer(rng,
                           n_in=state['rank_n_approx'],
                           n_hids=[state['dim']],
                           activation=['lambda x:x'],
                           init_fn=state['weight_init_fn'],
                           weight_noise=state['weight_noise'],
                           scale=state['weight_scale'],
                           learn_bias=False,
                           name='reseter_words_%d' % si))

    add_rec_step = []
    rec_proj = []
    if state['rec_gating']:
        rec_proj_gater = []
    if state['rec_reseting']:
        rec_proj_reseter = []
    for si in xrange(state['encoder_stack']):
        if si > 0:
            rec_proj.append(
                MultiLayer(rng,
                           n_in=state['dim'],
                           n_hids=[embdim],
                           activation=['lambda x:x'],
                           init_fn=state['rec_weight_init_fn'],
                           weight_noise=state['weight_noise'],
                           scale=state['rec_weight_scale'],
                           name='rec_proj_%d' % si))
            if state['rec_gating']:
                rec_proj_gater.append(
                    MultiLayer(rng,
                               n_in=state['dim'],
                               n_hids=[state['dim']],
                               activation=['lambda x:x'],
                               init_fn=state['weight_init_fn'],
                               weight_noise=state['weight_noise'],
                               scale=state['weight_scale'],
                               learn_bias=False,
                               name='rec_proj_gater_%d' % si))
            if state['rec_reseting']:
                rec_proj_reseter.append(
                    MultiLayer(rng,
                               n_in=state['dim'],
                               n_hids=[state['dim']],
                               activation=['lambda x:x'],
                               init_fn=state['weight_init_fn'],
                               weight_noise=state['weight_noise'],
                               scale=state['weight_scale'],
                               learn_bias=False,
                               name='rec_proj_reseter_%d' % si))

        add_rec_step.append(
            eval(state['rec_layer'])(rng,
                                     n_hids=state['dim'],
                                     activation=state['activ'],
                                     bias_scale=state['bias'],
                                     scale=state['rec_weight_scale'],
                                     init_fn=state['rec_weight_init_fn'],
                                     weight_noise=state['weight_noise_rec'],
                                     dropout=state['dropout_rec'],
                                     gating=state['rec_gating'],
                                     gater_activation=state['rec_gater'],
                                     reseting=state['rec_reseting'],
                                     reseter_activation=state['rec_reseter'],
                                     name='add_h_%d' % si))

    def _add_op(words_embeddings,
                words_mask=None,
                prev_val=None,
                si=0,
                state_below=None,
                gater_below=None,
                reseter_below=None,
                one_step=False,
                bs=1,
                init_state=None,
                use_noise=True):
        seqlen = words_embeddings.out.shape[0] // bs
        rval = words_embeddings
        gater = None
        reseter = None
        if state['rec_gating']:
            gater = gater_below
        if state['rec_reseting']:
            reseter = reseter_below
        if si > 0:
            rval += rec_proj[si - 1](state_below,
                                     one_step=one_step,
                                     use_noise=use_noise)
            if state['rec_gating']:
                projg = rec_proj_gater[si - 1](state_below,
                                               one_step=one_step,
                                               use_noise=use_noise)
                if gater: gater += projg
                else: gater = projg
            if state['rec_reseting']:
                projg = rec_proj_reseter[si - 1](state_below,
                                                 one_step=one_step,
                                                 use_noise=use_noise)
                if reseter: reseter += projg
                else: reseter = projg

        if not one_step:
            rval = add_rec_step[si](rval,
                                    nsteps=seqlen,
                                    batch_size=bs,
                                    mask=words_mask,
                                    gater_below=gater,
                                    reseter_below=reseter,
                                    one_step=one_step,
                                    init_state=init_state,
                                    use_noise=use_noise)
        else:
            rval = add_rec_step[si](rval,
                                    mask=words_mask,
                                    state_before=prev_val,
                                    gater_below=gater,
                                    reseter_below=reseter,
                                    one_step=one_step,
                                    init_state=init_state,
                                    use_noise=use_noise)
        return rval

    add_op = Operator(_add_op)

    # Target Sentence
    emb_t = MultiLayer(rng,
                       n_in=state['nouts'],
                       n_hids=[state['rank_n_approx']],
                       activation=[state['rank_n_activ']],
                       init_fn=state['weight_init_fn'],
                       weight_noise=state['weight_noise'],
                       scale=state['weight_scale'],
                       name='emb_t')

    emb_words_t = []
    if state['rec_gating']:
        gater_words_t = []
    if state['rec_reseting']:
        reseter_words_t = []
    for si in xrange(state['decoder_stack']):
        emb_words_t.append(
            MultiLayer(rng,
                       n_in=state['rank_n_approx'],
                       n_hids=[embdim],
                       activation=['lambda x:x'],
                       init_fn=state['weight_init_fn'],
                       weight_noise=state['weight_noise'],
                       scale=state['weight_scale'],
                       name='emb_words_t_%d' % si))
        if state['rec_gating']:
            gater_words_t.append(
                MultiLayer(rng,
                           n_in=state['rank_n_approx'],
                           n_hids=[state['dim']],
                           activation=['lambda x:x'],
                           init_fn=state['weight_init_fn'],
                           weight_noise=state['weight_noise'],
                           scale=state['weight_scale'],
                           learn_bias=False,
                           name='gater_words_t_%d' % si))
        if state['rec_reseting']:
            reseter_words_t.append(
                MultiLayer(rng,
                           n_in=state['rank_n_approx'],
                           n_hids=[state['dim']],
                           activation=['lambda x:x'],
                           init_fn=state['weight_init_fn'],
                           weight_noise=state['weight_noise'],
                           scale=state['weight_scale'],
                           learn_bias=False,
                           name='reseter_words_t_%d' % si))

    proj_everything_t = []
    if state['rec_gating']:
        gater_everything_t = []
    if state['rec_reseting']:
        reseter_everything_t = []
    for si in xrange(state['decoder_stack']):
        proj_everything_t.append(
            MultiLayer(rng,
                       n_in=state['dim'],
                       n_hids=[embdim],
                       activation=['lambda x:x'],
                       init_fn=state['weight_init_fn'],
                       weight_noise=state['weight_noise'],
                       scale=state['weight_scale'],
                       name='proj_everything_t_%d' % si,
                       learn_bias=False))
        if state['rec_gating']:
            gater_everything_t.append(
                MultiLayer(rng,
                           n_in=state['dim'],
                           n_hids=[state['dim']],
                           activation=['lambda x:x'],
                           init_fn=state['weight_init_fn'],
                           weight_noise=state['weight_noise'],
                           scale=state['weight_scale'],
                           name='gater_everything_t_%d' % si,
                           learn_bias=False))
        if state['rec_reseting']:
            reseter_everything_t.append(
                MultiLayer(rng,
                           n_in=state['dim'],
                           n_hids=[state['dim']],
                           activation=['lambda x:x'],
                           init_fn=state['weight_init_fn'],
                           weight_noise=state['weight_noise'],
                           scale=state['weight_scale'],
                           name='reseter_everything_t_%d' % si,
                           learn_bias=False))

    add_rec_step_t = []
    rec_proj_t = []
    if state['rec_gating']:
        rec_proj_t_gater = []
    if state['rec_reseting']:
        rec_proj_t_reseter = []
    for si in xrange(state['decoder_stack']):
        if si > 0:
            rec_proj_t.append(
                MultiLayer(rng,
                           n_in=state['dim'],
                           n_hids=[embdim],
                           activation=['lambda x:x'],
                           init_fn=state['rec_weight_init_fn'],
                           weight_noise=state['weight_noise'],
                           scale=state['rec_weight_scale'],
                           name='rec_proj_%d' % si))
            if state['rec_gating']:
                rec_proj_t_gater.append(
                    MultiLayer(rng,
                               n_in=state['dim'],
                               n_hids=[state['dim']],
                               activation=['lambda x:x'],
                               init_fn=state['weight_init_fn'],
                               weight_noise=state['weight_noise'],
                               scale=state['weight_scale'],
                               learn_bias=False,
                               name='rec_proj_t_gater_%d' % si))
            if state['rec_reseting']:
                rec_proj_t_reseter.append(
                    MultiLayer(rng,
                               n_in=state['dim'],
                               n_hids=[state['dim']],
                               activation=['lambda x:x'],
                               init_fn=state['weight_init_fn'],
                               weight_noise=state['weight_noise'],
                               scale=state['weight_scale'],
                               learn_bias=False,
                               name='rec_proj_t_reseter_%d' % si))

        add_rec_step_t.append(
            eval(state['rec_layer'])(rng,
                                     n_hids=state['dim'],
                                     activation=state['activ'],
                                     bias_scale=state['bias'],
                                     scale=state['rec_weight_scale'],
                                     init_fn=state['rec_weight_init_fn'],
                                     weight_noise=state['weight_noise_rec'],
                                     dropout=state['dropout_rec'],
                                     gating=state['rec_gating'],
                                     gater_activation=state['rec_gater'],
                                     reseting=state['rec_reseting'],
                                     reseter_activation=state['rec_reseter'],
                                     name='add_h_t_%d' % si))

    if state['encoder_stack'] > 1:
        encoder_proj = []
        for si in xrange(state['encoder_stack']):
            encoder_proj.append(
                MultiLayer(rng,
                           n_in=state['dim'],
                           n_hids=[state['dim'] * state['maxout_part']],
                           activation=['lambda x: x'],
                           init_fn=state['weight_init_fn'],
                           weight_noise=state['weight_noise'],
                           scale=state['weight_scale'],
                           name='encoder_proj_%d' % si,
                           learn_bias=(si == 0)))

        encoder_act_layer = UnaryOp(activation=eval(state['unary_activ']),
                                    indim=indim,
                                    pieces=pieces,
                                    rng=rng)

    def _add_t_op(words_embeddings,
                  everything=None,
                  words_mask=None,
                  prev_val=None,
                  one_step=False,
                  bs=1,
                  init_state=None,
                  use_noise=True,
                  gater_below=None,
                  reseter_below=None,
                  si=0,
                  state_below=None):
        seqlen = words_embeddings.out.shape[0] // bs

        rval = words_embeddings
        gater = None
        if state['rec_gating']:
            gater = gater_below
        reseter = None
        if state['rec_reseting']:
            reseter = reseter_below
        if si > 0:
            if isinstance(state_below, list):
                state_below = state_below[-1]
            rval += rec_proj_t[si - 1](state_below,
                                       one_step=one_step,
                                       use_noise=use_noise)
            if state['rec_gating']:
                projg = rec_proj_t_gater[si - 1](state_below,
                                                 one_step=one_step,
                                                 use_noise=use_noise)
                if gater: gater += projg
                else: gater = projg
            if state['rec_reseting']:
                projg = rec_proj_t_reseter[si - 1](state_below,
                                                   one_step=one_step,
                                                   use_noise=use_noise)
                if reseter: reseter += projg
                else: reseter = projg
        if everything:
            rval = rval + proj_everything_t[si](everything)
            if state['rec_gating']:
                everyg = gater_everything_t[si](everything,
                                                one_step=one_step,
                                                use_noise=use_noise)
                if gater: gater += everyg
                else: gater = everyg
            if state['rec_reseting']:
                everyg = reseter_everything_t[si](everything,
                                                  one_step=one_step,
                                                  use_noise=use_noise)
                if reseter: reseter += everyg
                else: reseter = everyg

        if not one_step:
            rval = add_rec_step_t[si](rval,
                                      nsteps=seqlen,
                                      batch_size=bs,
                                      mask=words_mask,
                                      one_step=one_step,
                                      init_state=init_state,
                                      gater_below=gater,
                                      reseter_below=reseter,
                                      use_noise=use_noise)
        else:
            rval = add_rec_step_t[si](rval,
                                      mask=words_mask,
                                      state_before=prev_val,
                                      one_step=one_step,
                                      gater_below=gater,
                                      reseter_below=reseter,
                                      use_noise=use_noise)
        return rval

    add_t_op = Operator(_add_t_op)

    outdim = state['dim_mlp']
    if not state['deep_out']:
        outdim = state['rank_n_approx']

    if state['bias_code']:
        bias_code = []
        for si in xrange(state['decoder_stack']):
            bias_code.append(
                MultiLayer(rng,
                           n_in=state['dim'],
                           n_hids=[state['dim']],
                           activation=[state['activ']],
                           bias_scale=[state['bias']],
                           scale=state['weight_scale'],
                           init_fn=state['weight_init_fn'],
                           weight_noise=state['weight_noise'],
                           name='bias_code_%d' % si))

    if state['avg_word']:
        word_code_nin = state['rank_n_approx']
        word_code = MultiLayer(rng,
                               n_in=word_code_nin,
                               n_hids=[outdim],
                               activation='lambda x:x',
                               bias_scale=[state['bias_mlp'] / 3],
                               scale=state['weight_scale'],
                               init_fn=state['weight_init_fn'],
                               weight_noise=state['weight_noise'],
                               learn_bias=False,
                               name='word_code')

    proj_code = MultiLayer(rng,
                           n_in=state['dim'],
                           n_hids=[outdim],
                           activation='lambda x: x',
                           bias_scale=[state['bias_mlp'] / 3],
                           scale=state['weight_scale'],
                           init_fn=state['weight_init_fn'],
                           weight_noise=state['weight_noise'],
                           learn_bias=False,
                           name='proj_code')

    proj_h = []
    for si in xrange(state['decoder_stack']):
        proj_h.append(
            MultiLayer(rng,
                       n_in=state['dim'],
                       n_hids=[outdim],
                       activation='lambda x: x',
                       bias_scale=[state['bias_mlp'] / 3],
                       scale=state['weight_scale'],
                       init_fn=state['weight_init_fn'],
                       weight_noise=state['weight_noise'],
                       name='proj_h_%d' % si))

    if state['bigram']:
        proj_word = MultiLayer(rng,
                               n_in=state['rank_n_approx'],
                               n_hids=[outdim],
                               activation=['lambda x:x'],
                               bias_scale=[state['bias_mlp'] / 3],
                               init_fn=state['weight_init_fn'],
                               weight_noise=state['weight_noise'],
                               scale=state['weight_scale'],
                               learn_bias=False,
                               name='emb_words_lm')

    if state['deep_out']:
        indim = 0
        pieces = 0
        act_layer = UnaryOp(activation=eval(state['unary_activ']))
        drop_layer = DropOp(rng=rng, dropout=state['dropout'])

    if state['deep_out']:
        indim = state['dim_mlp'] / state['maxout_part']
        rank_n_approx = state['rank_n_approx']
        rank_n_activ = state['rank_n_activ']
    else:
        indim = state['rank_n_approx']
        rank_n_approx = 0
        rank_n_activ = None
    output_layer = SoftmaxLayer(rng,
                                indim,
                                state['nouts'],
                                state['weight_scale'],
                                -1,
                                rank_n_approx=rank_n_approx,
                                rank_n_activ=rank_n_activ,
                                weight_noise=state['weight_noise'],
                                init_fn=state['weight_init_fn'],
                                name='out')

    def _pop_op(everything,
                accum,
                everything_max=None,
                everything_min=None,
                word=None,
                aword=None,
                one_step=False,
                use_noise=True):

        rval = proj_h[0](accum[0], one_step=one_step, use_noise=use_noise)
        for si in xrange(1, state['decoder_stack']):
            rval += proj_h[si](accum[si],
                               one_step=one_step,
                               use_noise=use_noise)
        if state['mult_out']:
            rval = rval * everything
        else:
            rval = rval + everything

        if aword and state['avg_word']:
            wcode = aword
            if one_step:
                if state['mult_out']:
                    rval = rval * wcode
                else:
                    rval = rval + wcode
            else:
                if not isinstance(wcode, TT.TensorVariable):
                    wcode = wcode.out
                shape = wcode.shape
                rshape = rval.shape
                rval = rval.reshape(
                    [rshape[0] / shape[0], shape[0], rshape[1]])
                wcode = wcode.dimshuffle('x', 0, 1)
                if state['mult_out']:
                    rval = rval * wcode
                else:
                    rval = rval + wcode
                rval = rval.reshape(rshape)
        if word and state['bigram']:
            if one_step:
                if state['mult_out']:
                    rval *= proj_word(emb_t(word, use_noise=use_noise),
                                      one_step=one_step,
                                      use_noise=use_noise)
                else:
                    rval += proj_word(emb_t(word, use_noise=use_noise),
                                      one_step=one_step,
                                      use_noise=use_noise)
            else:
                if isinstance(word, TT.TensorVariable):
                    shape = word.shape
                    ndim = word.ndim
                else:
                    shape = word.shape
                    ndim = word.out.ndim
                pword = proj_word(emb_t(word, use_noise=use_noise),
                                  one_step=one_step,
                                  use_noise=use_noise)
                shape_pword = pword.shape
                if ndim == 1:
                    pword = Shift()(pword.reshape([shape[0], 1, outdim]))
                else:
                    pword = Shift()(pword.reshape([shape[0], shape[1],
                                                   outdim]))
                if state['mult_out']:
                    rval *= pword.reshape(shape_pword)
                else:
                    rval += pword.reshape(shape_pword)
        if state['deep_out']:
            rval = drop_layer(act_layer(rval), use_noise=use_noise)
        return rval

    pop_op = Operator(_pop_op)

    # 3. Constructing the model
    gater_below = None
    if state['rec_gating']:
        gater_below = gater_words[0](emb(x))
    reseter_below = None
    if state['rec_reseting']:
        reseter_below = reseter_words[0](emb(x))
    encoder_acts = [
        add_op(emb_words[0](emb(x)),
               x_mask,
               bs=x_mask.shape[1],
               si=0,
               gater_below=gater_below,
               reseter_below=reseter_below)
    ]
    if state['encoder_stack'] > 1:
        everything = encoder_proj[0](last(encoder_acts[-1]))
    for si in xrange(1, state['encoder_stack']):
        gater_below = None
        if state['rec_gating']:
            gater_below = gater_words[si](emb(x))
        reseter_below = None
        if state['rec_reseting']:
            reseter_below = reseter_words[si](emb(x))
        encoder_acts.append(
            add_op(emb_words[si](emb(x)),
                   x_mask,
                   bs=x_mask.shape[1],
                   si=si,
                   state_below=encoder_acts[-1],
                   gater_below=gater_below,
                   reseter_below=reseter_below))
        if state['encoder_stack'] > 1:
            everything += encoder_proj[si](last(encoder_acts[-1]))

    if state['encoder_stack'] <= 1:
        encoder = encoder_acts[-1]
        everything = LastState(ntimes=True, n=y.shape[0])(encoder)
    else:
        everything = encoder_act_layer(everything)
        everything = everything.reshape(
            [1, everything.shape[0], everything.shape[1]])
        everything = LastState(ntimes=True, n=y.shape[0])(everything)

    if state['bias_code']:
        init_state = [bc(everything[-1]) for bc in bias_code]
    else:
        # bias_code layers are not built in this case; use empty initial states
        init_state = [None for si in xrange(state['decoder_stack'])]

    if state['avg_word']:
        shape = x.shape
        pword = emb(x).out.reshape(
            [shape[0], shape[1], state['rank_n_approx']])
        pword = pword * x_mask.dimshuffle(0, 1, 'x')
        aword = pword.sum(0) / TT.maximum(1., x_mask.sum(0).dimshuffle(0, 'x'))
        aword = word_code(aword, use_noise=False)
    else:
        aword = None

    gater_below = None
    if state['rec_gating']:
        gater_below = gater_words_t[0](emb_t(y0))
    reseter_below = None
    if state['rec_reseting']:
        reseter_below = reseter_words_t[0](emb_t(y0))
    has_said = [
        add_t_op(emb_words_t[0](emb_t(y0)),
                 everything,
                 y_mask,
                 bs=y_mask.shape[1],
                 gater_below=gater_below,
                 reseter_below=reseter_below,
                 init_state=init_state[0],
                 si=0)
    ]
    for si in xrange(1, state['decoder_stack']):
        gater_below = None
        if state['rec_gating']:
            gater_below = gater_words_t[si](emb_t(y0))
        reseter_below = None
        if state['rec_reseting']:
            reseter_below = reseter_words_t[si](emb_t(y0))
        has_said.append(
            add_t_op(emb_words_t[si](emb_t(y0)),
                     everything,
                     y_mask,
                     bs=y_mask.shape[1],
                     state_below=has_said[-1],
                     gater_below=gater_below,
                     reseter_below=reseter_below,
                     init_state=init_state[si],
                     si=si))

    if has_said[0].out.ndim < 3:
        for si in xrange(state['decoder_stack']):
            shape_hs = has_said[si].shape
            if y0.ndim == 1:
                shape = y0.shape
                has_said[si] = Shift()(has_said[si].reshape(
                    [shape[0], 1, state['dim_mlp']]))
            else:
                shape = y0.shape
                has_said[si] = Shift()(has_said[si].reshape(
                    [shape[0], shape[1], state['dim_mlp']]))
            has_said[si].out = TT.set_subtensor(has_said[si].out[0, :, :],
                                                init_state[si])
            has_said[si] = has_said[si].reshape(shape_hs)
    else:
        for si in xrange(state['decoder_stack']):
            has_said[si] = Shift()(has_said[si])
            has_said[si].out = TT.set_subtensor(has_said[si].out[0, :, :],
                                                init_state[si])

    model = pop_op(proj_code(everything), has_said, word=y0, aword=aword)

    nll = output_layer.train(
        state_below=model, target=y0, mask=y_mask, reg=None) / TT.cast(
            y.shape[0] * y.shape[1], 'float32')

    valid_fn = None
    noise_fn = None

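    # Sampling graph: rebuild the encoder on a single source sentence
    # (x is now a vector, not a minibatch).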
    x = TT.lvector(name='x')
    n_steps = TT.iscalar('nsteps')
    temp = TT.scalar('temp')
    gater_below = None
    if state['rec_gating']:
        gater_below = gater_words[0](emb(x))
    reseter_below = None
    if state['rec_reseting']:
        reseter_below = reseter_words[0](emb(x))
    encoder_acts = [
        add_op(emb_words[0](emb(x), use_noise=False),
               si=0,
               use_noise=False,
               gater_below=gater_below,
               reseter_below=reseter_below)
    ]
    if state['encoder_stack'] > 1:
        everything = encoder_proj[0](last(encoder_acts[-1]), use_noise=False)
    for si in xrange(1, state['encoder_stack']):
        gater_below = None
        if state['rec_gating']:
            gater_below = gater_words[si](emb(x))
        reseter_below = None
        if state['rec_reseting']:
            reseter_below = reseter_words[si](emb(x))
        encoder_acts.append(
            add_op(emb_words[si](emb(x), use_noise=False),
                   si=si,
                   state_below=encoder_acts[-1],
                   use_noise=False,
                   gater_below=gater_below,
                   reseter_below=reseter_below))
        if state['encoder_stack'] > 1:
            everything += encoder_proj[si](last(encoder_acts[-1]),
                                           use_noise=False)
    if state['encoder_stack'] <= 1:
        encoder = encoder_acts[-1]
        everything = last(encoder)
    else:
        everything = encoder_act_layer(everything)

    init_state = []
    for si in xrange(state['decoder_stack']):
        if state['bias_code']:
            init_state.append(
                TT.reshape(bias_code[si](everything, use_noise=False),
                           [1, state['dim']]))
        else:
            init_state.append(TT.alloc(numpy.float32(0), 1, state['dim']))

    if state['avg_word']:
        aword = emb(x, use_noise=False).out.mean(0)
        aword = word_code(aword, use_noise=False)
    else:
        aword = None

    def sample_fn(*args):
        aidx = 0
        word_tm1 = args[aidx]
        aidx += 1
        prob_tm1 = args[aidx]
        has_said_tm1 = []
        for si in xrange(state['decoder_stack']):
            aidx += 1
            has_said_tm1.append(args[aidx])
        aidx += 1
        ctx = args[aidx]
        awrd = None  # the averaged word code is only passed in when avg_word is enabled
        if state['avg_word']:
            aidx += 1
            awrd = args[aidx]

        val = pop_op(proj_code(ctx),
                     has_said_tm1,
                     word=word_tm1,
                     aword=awrd,
                     one_step=True,
                     use_noise=False)
        sample = output_layer.get_sample(state_below=val, temp=temp)
        logp = output_layer.get_cost(state_below=val.out.reshape(
            [1, TT.cast(output_layer.n_in, 'int64')]),
                                     temp=temp,
                                     target=sample.reshape([1, 1]),
                                     use_noise=False)
        gater_below = None
        if state['rec_gating']:
            gater_below = gater_words_t[0](emb_t(sample))
        reseter_below = None
        if state['rec_reseting']:
            reseter_below = reseter_words_t[0](emb_t(sample))
        has_said_t = [
            add_t_op(emb_words_t[0](emb_t(sample)),
                     ctx,
                     prev_val=has_said_tm1[0],
                     gater_below=gater_below,
                     reseter_below=reseter_below,
                     one_step=True,
                     use_noise=True,
                     si=0)
        ]
        for si in xrange(1, state['decoder_stack']):
            gater_below = None
            if state['rec_gating']:
                gater_below = gater_words_t[si](emb_t(sample))
            reseter_below = None
            if state['rec_reseting']:
                reseter_below = reseter_words_t[si](emb_t(sample))
            has_said_t.append(
                add_t_op(emb_words_t[si](emb_t(sample)),
                         ctx,
                         prev_val=has_said_tm1[si],
                         gater_below=gater_below,
                         reseter_below=reseter_below,
                         one_step=True,
                         use_noise=True,
                         si=si,
                         state_below=has_said_t[-1]))
        for si in xrange(state['decoder_stack']):
            if isinstance(has_said_t[si], list):
                has_said_t[si] = has_said_t[si][-1]
        rval = [sample, TT.cast(logp, 'float32')] + has_said_t
        return rval

    sampler_params = [everything]
    if state['avg_word']:
        sampler_params.append(aword)

    states = [TT.alloc(numpy.int64(0), n_steps)]
    states.append(TT.alloc(numpy.float32(0), n_steps))
    states += init_state

    outputs, updates = scan(sample_fn,
                            states=states,
                            params=sampler_params,
                            n_steps=n_steps,
                            name='sampler_scan')
    samples = outputs[0]
    probs = outputs[1]

    sample_fn = theano.function([n_steps, temp, x],
                                [samples, probs.sum()],
                                updates=updates,
                                profile=False,
                                name='sample_fn')

    model = LM_Model(cost_layer=nll,
                     weight_noise_amount=state['weight_noise_amount'],
                     valid_fn=valid_fn,
                     sample_fn=sample_fn,
                     clean_before_noise_fn=False,
                     noise_fn=noise_fn,
                     indx_word=state['indx_word_target'],
                     indx_word_src=state['indx_word'],
                     character_level=False,
                     rng=rng)

    if state['loopIters'] > 0: algo = SGD(model, state, train_data)
    else: algo = None

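    # Sampling hook: print a few training pairs and sample translations for them.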
    def hook_fn():
        if not hasattr(model, 'word_indxs'): model.load_dict()
        if not hasattr(model, 'word_indxs_src'):
            model.word_indxs_src = model.word_indxs
        old_offset = train_data.offset
        if state['sample_reset']: train_data.reset()
        ns = 0
        for sidx in xrange(state['sample_n']):
            while True:
                batch = train_data.next()
                if batch:
                    break
            x = batch['x']
            y = batch['y']
            #xbow = batch['x_bow']
            masks = batch['x_mask']
            if x.ndim > 1:
                for idx in xrange(x.shape[1]):
                    ns += 1
                    if ns > state['sample_max']:
                        break
                    print 'Input: ',
                    for k in xrange(x[:, idx].shape[0]):
                        print model.word_indxs_src[x[:, idx][k]],
                        if model.word_indxs_src[x[:, idx][k]] == '<eol>':
                            break
                    print ''
                    print 'Target: ',
                    for k in xrange(y[:, idx].shape[0]):
                        print model.word_indxs[y[:, idx][k]],
                        if model.word_indxs[y[:, idx][k]] == '<eol>':
                            break
                    print ''
                    senlen = len(x[:, idx])
                    if len(numpy.where(masks[:, idx] == 0)[0]) > 0:
                        senlen = numpy.where(masks[:, idx] == 0)[0][0]
                    if senlen < 1:
                        continue
                    xx = x[:senlen, idx]
                    #xx = xx.reshape([xx.shape[0], 1])
                    model.get_samples(state['seqlen'] + 1, 1, xx)
            else:
                ns += 1
                model.get_samples(state['seqlen'] + 1, 1, x)
            if ns > state['sample_max']:
                break
        train_data.offset = old_offset
        return

    main = MainLoop(train_data,
                    valid_data,
                    None,
                    model,
                    algo,
                    state,
                    channel,
                    reset=state['reset'],
                    hooks=hook_fn)
    if state['reload']: main.load()
    if state['loopIters'] > 0: main.main()

    if state['sampler_test']:
        # This is a test script: we only sample
        if not hasattr(model, 'word_indxs'): model.load_dict()
        if not hasattr(model, 'word_indxs_src'):
            model.word_indxs_src = model.word_indxs

        indx_word = pkl.load(open(state['word_indx'], 'rb'))

        try:
            while True:
                try:
                    seqin = raw_input('Input Sequence: ')
                    n_samples = int(raw_input('How many samples? '))
                    alpha = float(raw_input('Inverse Temperature? '))

                    seqin = seqin.lower()
                    seqin = seqin.split()

                    seqlen = len(seqin)
                    seq = numpy.zeros(seqlen + 1, dtype='int64')
                    for idx, sx in enumerate(seqin):
                        try:
                            seq[idx] = indx_word[sx]
                        except:
                            seq[idx] = indx_word[state['oov']]
                    seq[-1] = state['null_sym_source']

                except Exception:
                    print 'Something wrong with your input! Try again!'
                    continue

                sentences = []
                all_probs = []
                for sidx in xrange(n_samples):
                    #import ipdb; ipdb.set_trace()
                    [values, probs] = model.sample_fn(seqlen * 3, alpha, seq)
                    sen = []
                    for k in xrange(values.shape[0]):
                        if model.word_indxs[values[k]] == '<eol>':
                            break
                        sen.append(model.word_indxs[values[k]])
                    sentences.append(" ".join(sen))
                    all_probs.append(-probs)
                sprobs = numpy.argsort(all_probs)
                for pidx in sprobs:
                    print pidx, "(%f):" % (-all_probs[pidx]), sentences[pidx]
                print

        except KeyboardInterrupt:
            print 'Interrupted'
            pass
Exemplo n.º 5
0
def jobman(state, channel):
    # load dataset
    state['null_sym_source'] = 15000 
    state['null_sym_target'] = 15000
    state['n_sym_source'] = state['null_sym_source'] + 1
    state['n_sym_target'] = state['null_sym_target'] + 1

    state['nouts'] = state['n_sym_target']
    state['nins'] = state['n_sym_source']
    rng = numpy.random.RandomState(state['seed'])
    if state['loopIters'] > 0:
        train_data, valid_data, test_data = get_data(state)
    else:
        train_data = None
        valid_data = None
        test_data = None

    ########### Training graph #####################
    ## 1. Inputs
    if state['bs'] == 1:
        x = TT.lvector('x')
        x_mask = TT.vector('x_mask')
        y = TT.lvector('y')
        y0 = y
        y_mask = TT.vector('y_mask')
    else:
        x = TT.lmatrix('x')
        x_mask = TT.matrix('x_mask')
        y = TT.lmatrix('y')
        y0 = y
        y_mask = TT.matrix('y_mask')

    # 2. Layers and Operators
    bs = state['bs']

    embdim = state['dim_mlp']

    # Source Sentence
    emb = MultiLayer(
        rng,
        n_in=state['nins'],
        n_hids=[state['rank_n_approx']],
        activation=[state['rank_n_activ']],
        init_fn=state['weight_init_fn'],
        weight_noise=state['weight_noise'],
        scale=state['weight_scale'],
        name='emb')

    emb_words = []
    if state['rec_gating']:
        gater_words = []
    if state['rec_reseting']:
        reseter_words = []
    for si in xrange(state['encoder_stack']):
        emb_words.append(MultiLayer(
            rng,
            n_in=state['rank_n_approx'],
            n_hids=[embdim],
            activation=['lambda x:x'],
            init_fn=state['weight_init_fn'],
            weight_noise=state['weight_noise'],
            scale=state['weight_scale'],
            name='emb_words_%d'%si))
        if state['rec_gating']:
            gater_words.append(MultiLayer(
                rng,
                n_in=state['rank_n_approx'],
                n_hids=[state['dim']],
                activation=['lambda x:x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                learn_bias = False,
                name='gater_words_%d'%si))
        if state['rec_reseting']:
            reseter_words.append(MultiLayer(
                rng,
                n_in=state['rank_n_approx'],
                n_hids=[state['dim']],
                activation=['lambda x:x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                learn_bias = False,
                name='reseter_words_%d'%si))

    add_rec_step = []
    rec_proj = []
    if state['rec_gating']:
        rec_proj_gater = []
    if state['rec_reseting']:
        rec_proj_reseter = []
    for si in xrange(state['encoder_stack']):
        if si > 0:
            rec_proj.append(MultiLayer(
                rng,
                n_in=state['dim'],
                n_hids=[embdim],
                activation=['lambda x:x'],
                init_fn=state['rec_weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['rec_weight_scale'],
                name='rec_proj_%d'%si))
            if state['rec_gating']:
                rec_proj_gater.append(MultiLayer(
                    rng,
                    n_in=state['dim'],
                    n_hids=[state['dim']],
                    activation=['lambda x:x'],
                    init_fn=state['weight_init_fn'],
                    weight_noise=state['weight_noise'],
                    scale=state['weight_scale'],
                    learn_bias = False, 
                    name='rec_proj_gater_%d'%si))
            if state['rec_reseting']:
                rec_proj_reseter.append(MultiLayer(
                    rng,
                    n_in=state['dim'],
                    n_hids=[state['dim']],
                    activation=['lambda x:x'],
                    init_fn=state['weight_init_fn'],
                    weight_noise=state['weight_noise'],
                    scale=state['weight_scale'],
                    learn_bias = False, 
                    name='rec_proj_reseter_%d'%si))

        add_rec_step.append(eval(state['rec_layer'])(
                rng,
                n_hids=state['dim'],
                activation = state['activ'],
                bias_scale = state['bias'],
                scale=state['rec_weight_scale'],
                init_fn=state['rec_weight_init_fn'],
                weight_noise=state['weight_noise_rec'],
                dropout=state['dropout_rec'],
                gating=state['rec_gating'],
                gater_activation=state['rec_gater'],
                reseting=state['rec_reseting'],
                reseter_activation=state['rec_reseter'],
                name='add_h_%d'%si))

    def _add_op(words_embeddings, 
                words_mask=None,
                prev_val=None,
                si = 0, 
                state_below = None,
                gater_below = None,
                reseter_below = None,
                one_step=False, 
                bs=1, 
                init_state=None, 
                use_noise=True):
        seqlen = words_embeddings.out.shape[0]//bs
        rval = words_embeddings
        gater = None
        reseter = None
        if state['rec_gating']:
            gater = gater_below
        if state['rec_reseting']:
            reseter = reseter_below
        if si > 0:
            rval += rec_proj[si-1](state_below, one_step=one_step, 
                    use_noise=use_noise)
            if state['rec_gating']:
                projg = rec_proj_gater[si-1](state_below, one_step=one_step, 
                        use_noise = use_noise)
                if gater: gater += projg
                else: gater = projg
            if state['rec_reseting']:
                projg = rec_proj_reseter[si-1](state_below, one_step=one_step, 
                        use_noise = use_noise)
                if reseter: reseter += projg
                else: reseter = projg
            
        if not one_step:
            rval= add_rec_step[si](
                rval,
                nsteps=seqlen,
                batch_size=bs,
                mask=words_mask,
                gater_below = gater,
                reseter_below = reseter,
                one_step=one_step,
                init_state=init_state,
                use_noise = use_noise)
        else:
            rval= add_rec_step[si](
                rval,
                mask=words_mask,
                state_before=prev_val,
                gater_below = gater,
                reseter_below = reseter,
                one_step=one_step,
                init_state=init_state,
                use_noise = use_noise)
        return rval
    add_op = Operator(_add_op)

    # Target Sentence
    emb_t = MultiLayer(
        rng,
        n_in=state['nouts'],
        n_hids=[state['rank_n_approx']],
        activation=[state['rank_n_activ']],
        init_fn=state['weight_init_fn'],
        weight_noise=state['weight_noise'],
        scale=state['weight_scale'],
        name='emb_t')

    emb_words_t = []
    if state['rec_gating']:
        gater_words_t = []
    if state['rec_reseting']:
        reseter_words_t = []
    for si in xrange(state['decoder_stack']):
        emb_words_t.append(MultiLayer(
            rng,
            n_in=state['rank_n_approx'],
            n_hids=[embdim],
            activation=['lambda x:x'],
            init_fn=state['weight_init_fn'],
            weight_noise=state['weight_noise'],
            scale=state['weight_scale'],
            name='emb_words_t_%d'%si))
        if state['rec_gating']:
            gater_words_t.append(MultiLayer(
                rng,
                n_in=state['rank_n_approx'],
                n_hids=[state['dim']],
                activation=['lambda x:x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                learn_bias=False,
                name='gater_words_t_%d'%si))
        if state['rec_reseting']:
            reseter_words_t.append(MultiLayer(
                rng,
                n_in=state['rank_n_approx'],
                n_hids=[state['dim']],
                activation=['lambda x:x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                learn_bias=False,
                name='reseter_words_t_%d'%si))

    proj_everything_t = []
    if state['rec_gating']:
        gater_everything_t = []
    if state['rec_reseting']:
        reseter_everything_t = []
    for si in xrange(state['decoder_stack']):
        proj_everything_t.append(MultiLayer(
            rng,
            n_in=state['dim'],
            n_hids=[embdim],
            activation=['lambda x:x'],
            init_fn=state['weight_init_fn'],
            weight_noise=state['weight_noise'],
            scale=state['weight_scale'],
            name='proj_everything_t_%d'%si,
            learn_bias = False))
        if state['rec_gating']:
            gater_everything_t.append(MultiLayer(
                rng,
                n_in=state['dim'],
                n_hids=[state['dim']],
                activation=['lambda x:x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                name='gater_everything_t_%d'%si,
                learn_bias = False))
        if state['rec_reseting']:
            reseter_everything_t.append(MultiLayer(
                rng,
                n_in=state['dim'],
                n_hids=[state['dim']],
                activation=['lambda x:x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                name='reseter_everything_t_%d'%si,
                learn_bias = False))

    add_rec_step_t = []
    rec_proj_t = []
    if state['rec_gating']:
        rec_proj_t_gater = []
    if state['rec_reseting']:
        rec_proj_t_reseter = []
    for si in xrange(state['decoder_stack']):
        if si > 0:
            rec_proj_t.append(MultiLayer(
                rng,
                n_in=state['dim'],
                n_hids=[embdim],
                activation=['lambda x:x'],
                init_fn=state['rec_weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['rec_weight_scale'],
                name='rec_proj_%d'%si))
            if state['rec_gating']:
                rec_proj_t_gater.append(MultiLayer(
                    rng,
                    n_in=state['dim'],
                    n_hids=[state['dim']],
                    activation=['lambda x:x'],
                    init_fn=state['weight_init_fn'],
                    weight_noise=state['weight_noise'],
                    scale=state['weight_scale'],
                    learn_bias=False,
                    name='rec_proj_t_gater_%d'%si))
            if state['rec_reseting']:
                rec_proj_t_reseter.append(MultiLayer(
                    rng,
                    n_in=state['dim'],
                    n_hids=[state['dim']],
                    activation=['lambda x:x'],
                    init_fn=state['weight_init_fn'],
                    weight_noise=state['weight_noise'],
                    scale=state['weight_scale'],
                    learn_bias=False,
                    name='rec_proj_t_reseter_%d'%si))

        add_rec_step_t.append(eval(state['rec_layer'])(
                rng,
                n_hids=state['dim'],
                activation = state['activ'],
                bias_scale = state['bias'],
                scale=state['rec_weight_scale'],
                init_fn=state['rec_weight_init_fn'],
                weight_noise=state['weight_noise_rec'],
                dropout=state['dropout_rec'],
                gating=state['rec_gating'],
                gater_activation=state['rec_gater'],
                reseting=state['rec_reseting'],
                reseter_activation=state['rec_reseter'],
                name='add_h_t_%d'%si))

    if state['encoder_stack'] > 1:
        encoder_proj = []
        for si in xrange(state['encoder_stack']):
            encoder_proj.append(MultiLayer(
                rng,
                n_in=state['dim'],
                n_hids=[state['dim'] * state['maxout_part']],
                activation=['lambda x: x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                name='encoder_proj_%d'%si,
                learn_bias = (si == 0)))

        encoder_act_layer = UnaryOp(activation=eval(state['unary_activ']), 
                indim = indim, pieces = pieces, rng=rng)

    def _add_t_op(words_embeddings, everything = None, words_mask=None,
                prev_val=None,one_step=False, bs=1, 
                init_state=None, use_noise=True,
                gater_below = None,
                reseter_below = None,
                si = 0, state_below = None):
        seqlen = words_embeddings.out.shape[0]//bs

        rval = words_embeddings
        gater = None
        if state['rec_gating']:
            gater = gater_below
        reseter = None
        if state['rec_reseting']:
            reseter = reseter_below
        if si > 0:
            if isinstance(state_below, list):
                state_below = state_below[-1]
            rval += rec_proj_t[si-1](state_below, 
                    one_step=one_step, use_noise=use_noise)
            if state['rec_gating']:
                projg = rec_proj_t_gater[si-1](state_below, one_step=one_step, 
                        use_noise = use_noise)
                if gater: gater += projg
                else: gater = projg
            if state['rec_reseting']:
                projg = rec_proj_t_reseter[si-1](state_below, one_step=one_step, 
                        use_noise = use_noise)
                if reseter: reseter += projg
                else: reseter = projg
        if everything:
            rval = rval + proj_everything_t[si](everything)
            if state['rec_gating']:
                everyg = gater_everything_t[si](everything, one_step=one_step, use_noise=use_noise)
                if gater: gater += everyg
                else: gater = everyg
            if state['rec_reseting']:
                everyg = reseter_everything_t[si](everything, one_step=one_step, use_noise=use_noise)
                if reseter: reseter += everyg
                else: reseter = everyg

        if not one_step:
            rval = add_rec_step_t[si](
                rval,
                nsteps=seqlen,
                batch_size=bs,
                mask=words_mask,
                one_step=one_step,
                init_state=init_state,
                gater_below = gater,
                reseter_below = reseter,
                use_noise = use_noise)
        else:
            rval = add_rec_step_t[si](
                rval,
                mask=words_mask,
                state_before=prev_val,
                one_step=one_step,
                gater_below = gater,
                reseter_below = reseter,
                use_noise = use_noise)
        return rval
    add_t_op = Operator(_add_t_op)

    outdim = state['dim_mlp']
    if not state['deep_out']:
        outdim = state['rank_n_approx']

    if state['bias_code']:
        bias_code = []
        for si in xrange(state['decoder_stack']):
            bias_code.append(MultiLayer(
                rng,
                n_in=state['dim'],
                n_hids=[state['dim']],
                activation = [state['activ']],
                bias_scale = [state['bias']],
                scale=state['weight_scale'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                name='bias_code_%d'%si))

    if state['avg_word']:
        word_code_nin = state['rank_n_approx']
        word_code = MultiLayer(
            rng,
            n_in=word_code_nin,
            n_hids=[outdim],
            activation = 'lambda x:x',
            bias_scale = [state['bias_mlp']/3],
            scale=state['weight_scale'],
            init_fn=state['weight_init_fn'],
            weight_noise=state['weight_noise'],
            learn_bias = False,
            name='word_code')

    proj_code = MultiLayer(
        rng,
        n_in=state['dim'],
        n_hids=[outdim],
        activation = 'lambda x: x',
        bias_scale = [state['bias_mlp']/3],
        scale=state['weight_scale'],
        init_fn=state['weight_init_fn'],
        weight_noise=state['weight_noise'],
        learn_bias = False,
        name='proj_code')

    proj_h = []
    for si in xrange(state['decoder_stack']):
        proj_h.append(MultiLayer(
            rng,
            n_in=state['dim'],
            n_hids=[outdim],
            activation = 'lambda x: x',
            bias_scale = [state['bias_mlp']/3],
            scale=state['weight_scale'],
            init_fn=state['weight_init_fn'],
            weight_noise=state['weight_noise'],
            name='proj_h_%d'%si))

    if state['bigram']:
        proj_word = MultiLayer(
            rng,
            n_in=state['rank_n_approx'],
            n_hids=[outdim],
            activation=['lambda x:x'],
            bias_scale = [state['bias_mlp']/3],
            init_fn=state['weight_init_fn'],
            weight_noise=state['weight_noise'],
            scale=state['weight_scale'],
            learn_bias = False,
            name='emb_words_lm')

    if state['deep_out']:
        indim = 0
        pieces = 0
        act_layer = UnaryOp(activation=eval(state['unary_activ']))
        drop_layer = DropOp(rng=rng, dropout=state['dropout'])

    if state['deep_out']:
        indim = state['dim_mlp'] / state['maxout_part']
        rank_n_approx = state['rank_n_approx']
        rank_n_activ = state['rank_n_activ']
    else:
        indim = state['rank_n_approx']
        rank_n_approx = 0
        rank_n_activ = None
    output_layer = SoftmaxLayer(
            rng,
            indim, 
            state['nouts'],
            state['weight_scale'],
            -1, 
            rank_n_approx = rank_n_approx,
            rank_n_activ = rank_n_activ,
            weight_noise=state['weight_noise'],
            init_fn=state['weight_init_fn'],
            name='out')

    def _pop_op(everything, accum, everything_max = None,
            everything_min = None, word = None, aword = None,
            one_step=False, use_noise=True):

        rval = proj_h[0](accum[0], one_step=one_step, use_noise=use_noise)
        for si in xrange(1,state['decoder_stack']):
            rval += proj_h[si](accum[si], one_step=one_step, use_noise=use_noise)
        if state['mult_out']:
            rval = rval * everything
        else:
            rval = rval + everything

        if aword and state['avg_word']:
            wcode = aword
            if one_step:
                if state['mult_out']:
                    rval = rval * wcode
                else:
                    rval = rval + wcode
            else:
                if not isinstance(wcode, TT.TensorVariable):
                    wcode = wcode.out
                shape = wcode.shape
                rshape = rval.shape
                rval = rval.reshape([rshape[0]/shape[0], shape[0], rshape[1]])
                wcode = wcode.dimshuffle('x', 0, 1)
                if state['mult_out']:
                    rval = rval * wcode
                else:
                    rval = rval + wcode
                rval = rval.reshape(rshape)
        if word and state['bigram']:
            if one_step:
                if state['mult_out']:
                    rval *= proj_word(emb_t(word, use_noise=use_noise), 
                            one_step=one_step, use_noise=use_noise)
                else:
                    rval += proj_word(emb_t(word, use_noise=use_noise), 
                            one_step=one_step, use_noise=use_noise)
            else:
                if isinstance(word, TT.TensorVariable):
                    shape = word.shape
                    ndim = word.ndim
                else:
                    shape = word.shape
                    ndim = word.out.ndim
                pword = proj_word(emb_t(word, use_noise=use_noise), 
                        one_step=one_step, use_noise=use_noise)
                shape_pword = pword.shape
                if ndim == 1:
                    pword = Shift()(pword.reshape([shape[0], 1, outdim]))
                else:
                    pword = Shift()(pword.reshape([shape[0], shape[1], outdim]))
                if state['mult_out']:
                    rval *= pword.reshape(shape_pword)
                else:
                    rval += pword.reshape(shape_pword)
        if state['deep_out']:
            rval = drop_layer(act_layer(rval), use_noise=use_noise)
        return rval

    pop_op = Operator(_pop_op)
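    # The readout op combines (adds, or multiplies when state['mult_out'] is
    # set) the projections of the decoder hidden states with the projected
    # context ('everything'), optionally the averaged source-word code and the
    # embedding of the previous target word; with state['deep_out'] it also
    # applies the configured unary activation and dropout before the softmax.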

    # 3. Constructing the model
    gater_below = None
    if state['rec_gating']:
        gater_below = gater_words[0](emb(x))
    reseter_below = None
    if state['rec_reseting']:
        reseter_below = reseter_words[0](emb(x))
    encoder_acts = [add_op(emb_words[0](emb(x)), x_mask, 
        bs=x_mask.shape[1], 
        si=0, gater_below=gater_below, reseter_below=reseter_below)]
    if state['encoder_stack'] > 1:
        everything = encoder_proj[0](last(encoder_acts[-1]))
    for si in xrange(1,state['encoder_stack']):
        gater_below = None
        if state['rec_gating']:
            gater_below = gater_words[si](emb(x))
        reseter_below = None
        if state['rec_reseting']:
            reseter_below = reseter_words[si](emb(x))
        encoder_acts.append(add_op(emb_words[si](emb(x)), 
                x_mask, bs=x_mask.shape[1], 
                si=si, state_below=encoder_acts[-1], 
                gater_below=gater_below,
                reseter_below=reseter_below))
        if state['encoder_stack'] > 1:
            everything += encoder_proj[si](last(encoder_acts[-1]))

    if state['encoder_stack'] <= 1:
        encoder = encoder_acts[-1]
        everything = LastState(ntimes=True,n=y.shape[0])(encoder)
    else:
        everything = encoder_act_layer(everything)
        everything = everything.reshape([1, everything.shape[0], everything.shape[1]])
        everything = LastState(ntimes=True,n=y.shape[0])(everything)

    if state['bias_code']:
        init_state = [bc(everything[-1]) for bc in bias_code]
    else:
        init_state = [None] * state['decoder_stack']  # bias_code is not built in this case

    if state['avg_word']:
        shape = x.shape
        pword = emb(x).out.reshape([shape[0], shape[1], state['rank_n_approx']])
        pword = pword * x_mask.dimshuffle(0, 1, 'x')
        aword = pword.sum(0) / TT.maximum(1., x_mask.sum(0).dimshuffle(0, 'x'))
        aword = word_code(aword, use_noise=False)
    else:
        aword = None
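
    # 'aword' is the mask-weighted average of the source word embeddings,
    # projected through word_code; when state['avg_word'] is on it gives the
    # readout an extra summary of the source sentence.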

    gater_below = None
    if state['rec_gating']:
        gater_below = gater_words_t[0](emb_t(y0))
    reseter_below = None
    if state['rec_reseting']:
        reseter_below = reseter_words_t[0](emb_t(y0))
    has_said = [add_t_op(emb_words_t[0](emb_t(y0)), 
            everything,
            y_mask, bs=y_mask.shape[1], 
            gater_below = gater_below,
            reseter_below = reseter_below,
            init_state=init_state[0], 
            si=0)]
    for si in xrange(1,state['decoder_stack']):
        gater_below = None
        if state['rec_gating']:
            gater_below = gater_words_t[si](emb_t(y0))
        reseter_below = None
        if state['rec_reseting']:
            reseter_below = reseter_words_t[si](emb_t(y0))
        has_said.append(add_t_op(emb_words_t[si](emb_t(y0)), 
                everything,
                y_mask, bs=y_mask.shape[1], 
                state_below = has_said[-1],
                gater_below = gater_below,
                reseter_below = reseter_below,
                init_state=init_state[si], 
                si=si))

    if has_said[0].out.ndim < 3:
        for si in xrange(state['decoder_stack']):
            shape_hs = has_said[si].shape
            if y0.ndim == 1:
                shape = y0.shape
                has_said[si] = Shift()(has_said[si].reshape([shape[0], 1, state['dim_mlp']]))
            else:
                shape = y0.shape
                has_said[si] = Shift()(has_said[si].reshape([shape[0], shape[1], state['dim_mlp']]))
            has_said[si].out = TT.set_subtensor(has_said[si].out[0, :, :], init_state[si])
            has_said[si] = has_said[si].reshape(shape_hs)
    else:
        for si in xrange(state['decoder_stack']):
            has_said[si] = Shift()(has_said[si])
            has_said[si].out = TT.set_subtensor(has_said[si].out[0, :, :], init_state[si])

    model = pop_op(proj_code(everything), has_said, word=y0, aword = aword)

    nll = output_layer.train(state_below=model, target=y0,
                 mask=y_mask, reg=None) / TT.cast(y.shape[0]*y.shape[1], 'float32')
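
    # The training cost is the masked negative log-likelihood of the target
    # words, normalised by y.shape[0] * y.shape[1] (maximum target length
    # times batch size), i.e. by the total number of positions in the batch.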

    valid_fn = None
    noise_fn = None

    x = TT.lvector(name='x')
    n_steps = TT.iscalar('nsteps')
    temp = TT.scalar('temp')
    gater_below = None
    if state['rec_gating']:
        gater_below = gater_words[0](emb(x))
    reseter_below = None
    if state['rec_reseting']:
        reseter_below = reseter_words[0](emb(x))
    encoder_acts = [add_op(emb_words[0](emb(x),use_noise=False), 
            si=0, 
            use_noise=False, 
            gater_below=gater_below,
            reseter_below=reseter_below)]
    if state['encoder_stack'] > 1:
        everything = encoder_proj[0](last(encoder_acts[-1]), use_noise=False)
    for si in xrange(1,state['encoder_stack']):
        gater_below = None
        if state['rec_gating']:
            gater_below = gater_words[si](emb(x))
        reseter_below = None
        if state['rec_reseting']:
            reseter_below = reseter_words[si](emb(x))
        encoder_acts.append(add_op(emb_words[si](emb(x),use_noise=False), 
                si=si, 
                state_below=encoder_acts[-1], use_noise=False,
                gater_below = gater_below, 
                reseter_below = reseter_below))
        if state['encoder_stack'] > 1:
            everything += encoder_proj[si](last(encoder_acts[-1]), use_noise=False)
    if state['encoder_stack'] <= 1:
        encoder = encoder_acts[-1]
        everything = last(encoder)
    else:
        everything = encoder_act_layer(everything)

    init_state = []
    for si in xrange(state['decoder_stack']):
        if state['bias_code']:
            init_state.append(TT.reshape(bias_code[si](everything, 
                use_noise=False), [1, state['dim']]))
        else:
            init_state.append(TT.alloc(numpy.float32(0), 1, state['dim']))

    if state['avg_word']:
        aword = emb(x,use_noise=False).out.mean(0)
        aword = word_code(aword, use_noise=False)
    else:
        aword = None

    def sample_fn(*args):
        aidx = 0; word_tm1 = args[aidx]
        aidx += 1; prob_tm1 = args[aidx]
        has_said_tm1 = []
        for si in xrange(state['decoder_stack']):
            aidx += 1; has_said_tm1.append(args[aidx])
        aidx += 1; ctx = args[aidx]
        awrd = None  # only taken from the scan arguments when avg_word is on
        if state['avg_word']:
            aidx += 1; awrd = args[aidx]
        
        val = pop_op(proj_code(ctx), has_said_tm1, word=word_tm1,
                aword=awrd, one_step=True, use_noise=False)
        sample = output_layer.get_sample(state_below=val, temp=temp)
        logp = output_layer.get_cost(
                state_below=val.out.reshape([1, TT.cast(output_layer.n_in, 'int64')]), 
                temp=temp, target=sample.reshape([1,1]), use_noise=False)
        gater_below = None
        if state['rec_gating']:
            gater_below = gater_words_t[0](emb_t(sample))
        reseter_below = None
        if state['rec_reseting']:
            reseter_below = reseter_words_t[0](emb_t(sample))
        has_said_t = [add_t_op(emb_words_t[0](emb_t(sample)), 
                ctx,
                prev_val=has_said_tm1[0], 
                gater_below=gater_below,
                reseter_below=reseter_below,
                one_step=True, use_noise=True,
                si=0)]
        for si in xrange(1, state['decoder_stack']):
            gater_below = None
            if state['rec_gating']:
                gater_below = gater_words_t[si](emb_t(sample))
            reseter_below = None
            if state['rec_reseting']:
                reseter_below = reseter_words_t[si](emb_t(sample))
            has_said_t.append(add_t_op(emb_words_t[si](emb_t(sample)), 
                    ctx,
                    prev_val=has_said_tm1[si], 
                    gater_below=gater_below,
                    reseter_below=reseter_below,
                    one_step=True, use_noise=True,
                    si=si, state_below=has_said_t[-1]))
        for si in xrange(state['decoder_stack']):
            if isinstance(has_said_t[si], list):
                has_said_t[si] = has_said_t[si][-1]
        rval = [sample, TT.cast(logp, 'float32')] + has_said_t
        return rval

    sampler_params = [everything]
    if state['avg_word']:
        sampler_params.append(aword)

    states = [TT.alloc(numpy.int64(0), n_steps)]
    states.append(TT.alloc(numpy.float32(0), n_steps))
    states += init_state

    outputs, updates = scan(sample_fn,
            states = states,
            params = sampler_params,
            n_steps= n_steps,
            name='sampler_scan'
            )
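    # scan unrolls sample_fn for n_steps: each step samples the next target
    # word, records its log-probability and carries the updated decoder
    # states forward to the next step.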
    samples = outputs[0]
    probs = outputs[1]

    sample_fn = theano.function(
        [n_steps, temp, x], [samples, probs.sum()],
        updates=updates,
        profile=False, name='sample_fn')

    model = LM_Model(
        cost_layer = nll,
        weight_noise_amount=state['weight_noise_amount'],
        valid_fn = valid_fn,
        sample_fn  = sample_fn,
        clean_before_noise_fn = False,
        noise_fn = noise_fn,
        indx_word=state['indx_word_target'],
        indx_word_src=state['indx_word'],
        character_level = False,
        rng = rng)

    if state['loopIters'] > 0: algo = SGD(model, state, train_data)
    else: algo = None

    def hook_fn():
        if not hasattr(model, 'word_indxs'): model.load_dict()
        if not hasattr(model, 'word_indxs_src'):
            model.word_indxs_src = model.word_indxs
        old_offset = train_data.offset
        if state['sample_reset']: train_data.reset()
        ns = 0
        for sidx in xrange(state['sample_n']):
            while True:
                batch = train_data.next()
                if batch:
                    break
            x = batch['x']
            y = batch['y']
            #xbow = batch['x_bow']
            masks = batch['x_mask']
            if x.ndim > 1:
                for idx in xrange(x.shape[1]):
                    ns += 1
                    if ns > state['sample_max']:
                        break
                    print 'Input: ',
                    for k in xrange(x[:,idx].shape[0]):
                        print model.word_indxs_src[x[:,idx][k]],
                        if model.word_indxs_src[x[:,idx][k]] == '<eol>':
                            break
                    print ''
                    print 'Target: ',
                    for k in xrange(y[:,idx].shape[0]):
                        print model.word_indxs[y[:,idx][k]],
                        if model.word_indxs[y[:,idx][k]] == '<eol>':
                            break
                    print ''
                    senlen = len(x[:,idx])
                    if len(numpy.where(masks[:,idx]==0)[0]) > 0:
                        senlen = numpy.where(masks[:,idx]==0)[0][0]
                    if senlen < 1:
                        continue
                    xx = x[:senlen, idx]
                    #xx = xx.reshape([xx.shape[0], 1])
                    model.get_samples(state['seqlen']+1,  1, xx)
            else:
                ns += 1
                model.get_samples(state['seqlen']+1,  1, x)
            if ns > state['sample_max']:
                break
        train_data.offset = old_offset
        return

    main = MainLoop(train_data, valid_data, None, model, algo, state, channel,
            reset = state['reset'], hooks = hook_fn)
    if state['reload']: main.load()
    if state['loopIters'] > 0: main.main()

    if state['sampler_test']:
        # This is a test script: we only sample
        if not hasattr(model, 'word_indxs'): model.load_dict()
        if not hasattr(model, 'word_indxs_src'):
            model.word_indxs_src = model.word_indxs

        indx_word = pkl.load(open(state['word_indx'], 'rb'))

        try:
            while True:
                try:
                    seqin = raw_input('Input Sequence: ')
                    n_samples = int(raw_input('How many samples? '))
                    alpha = float(raw_input('Inverse Temperature? '))

                    seqin = seqin.lower()
                    seqin = seqin.split()

                    seqlen = len(seqin)
                    seq = numpy.zeros(seqlen + 1, dtype='int64')
                    for idx, sx in enumerate(seqin):
                        try:
                            seq[idx] = indx_word[sx]
                        except KeyError:  # out-of-vocabulary word
                            seq[idx] = indx_word[state['oov']]
                    seq[-1] = state['null_sym_source']

                except Exception:
                    print 'Something wrong with your input! Try again!'
                    continue

                sentences = []
                all_probs = []
                for sidx in xrange(n_samples):
                    #import ipdb; ipdb.set_trace()
                    [values, probs] = model.sample_fn(seqlen * 3, alpha, seq)
                    sen = []
                    for k in xrange(values.shape[0]):
                        if model.word_indxs[values[k]] == '<eol>':
                            break
                        sen.append(model.word_indxs[values[k]])
                    sentences.append(" ".join(sen))
                    all_probs.append(-probs)
                sprobs = numpy.argsort(all_probs)
                for pidx in sprobs:
                    print pidx, "(%f):" % (-all_probs[pidx]), sentences[pidx]
                print

        except KeyboardInterrupt:
            print 'Interrupted'
            pass
def do_experiment(state, channel):
    logging.basicConfig(level=logging.DEBUG,
            format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("Starting state: {}".format(state))

    def maxout(x):
        shape = x.shape
        if x.ndim == 1:
            shape1 = TT.cast(shape[0] / state['maxout_part'], 'int64')
            shape2 = TT.cast(state['maxout_part'], 'int64')
            x = x.reshape([shape1, shape2])
            x = x.max(1)
        else:
            shape1 = TT.cast(shape[1] / state['maxout_part'], 'int64')
            shape2 = TT.cast(state['maxout_part'], 'int64')
            x = x.reshape([shape[0], shape1, shape2])
            x = x.max(2)
        return x
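
    # Illustrative sketch (assuming maxout_part == 2): a length-4 input
    # [1, 5, 2, 3] is reshaped to [[1, 5], [2, 3]] and max-reduced over the
    # last axis, giving [5, 3]; for a matrix the same grouping is applied to
    # the feature dimension of every row.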

    logger.info("Start loading")
    rng = numpy.random.RandomState(state['seed'])
    train_data, valid_data, test_data = get_data(state)

    logger.info("Build layers")
    if state['bs'] == 1:
        x = TT.lvector('x')
        x_mask = TT.vector('x_mask')
        y = TT.lvector('y')
        y0 = y
        y_mask = TT.vector('y_mask')
    else:
        x = TT.lmatrix('x')
        x_mask = TT.matrix('x_mask')
        y = TT.lmatrix('y')
        y0 = y
        y_mask = TT.matrix('y_mask')
    scoring_inputs = [x, x_mask, y, y_mask]

    bs = state['bs']

    # Dimensionality of the word embeddings.
    # It is the same as state['dim'] and in fact equals the number of hidden units.
    embdim = state['dim_mlp']

    logger.info("Source sentence")
    # Low-rank embeddings
    emb = MultiLayer(
        rng,
        n_in=state['nins'],
        n_hids=[state['rank_n_approx']],
        activation=[state['rank_n_activ']],
        init_fn=state['weight_init_fn'],
        weight_noise=state['weight_noise'],
        scale=state['weight_scale'],
        name='emb')

    emb_words = []
    if state['rec_gating']:
        gater_words = []
    if state['rec_reseting']:
        reseter_words = []
    # si indexes the RNN in the encoder stack (in practice there is only one)
    for si in xrange(state['encoder_stack']):
        # In the paper this corresponds to the multiplication by W
        emb_words.append(MultiLayer(
            rng,
            n_in=state['rank_n_approx'],
            n_hids=[embdim],
            activation=['lambda x:x'],
            init_fn=state['weight_init_fn'],
            weight_noise=state['weight_noise'],
            scale=state['weight_scale'],
            name='emb_words_%d'%si))
        # In the paper this corresponds to the multiplication by W_z
        if state['rec_gating']:
            gater_words.append(MultiLayer(
                rng,
                n_in=state['rank_n_approx'],
                n_hids=[state['dim']],
                activation=['lambda x:x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                learn_bias = False,
                name='gater_words_%d'%si))
        # In the paper this corresponds to the multiplication by W_r
        if state['rec_reseting']:
            reseter_words.append(MultiLayer(
                rng,
                n_in=state['rank_n_approx'],
                n_hids=[state['dim']],
                activation=['lambda x:x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                learn_bias = False,
                name='reseter_words_%d'%si))

    add_rec_step = []
    rec_proj = []
    if state['rec_gating']:
        rec_proj_gater = []
    if state['rec_reseting']:
        rec_proj_reseter = []
    for si in xrange(state['encoder_stack']):
        if si > 0:
            rec_proj.append(MultiLayer(
                rng,
                n_in=state['dim'],
                n_hids=[embdim],
                activation=['lambda x:x'],
                init_fn=state['rec_weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['rec_weight_scale'],
                name='rec_proj_%d'%si))
            if state['rec_gating']:
                rec_proj_gater.append(MultiLayer(
                    rng,
                    n_in=state['dim'],
                    n_hids=[state['dim']],
                    activation=['lambda x:x'],
                    init_fn=state['weight_init_fn'],
                    weight_noise=state['weight_noise'],
                    scale=state['weight_scale'],
                    learn_bias=False,
                    name='rec_proj_gater_%d'%si))
            if state['rec_reseting']:
                rec_proj_reseter.append(MultiLayer(
                    rng,
                    n_in=state['dim'],
                    n_hids=[state['dim']],
                    activation=['lambda x:x'],
                    init_fn=state['weight_init_fn'],
                    weight_noise=state['weight_noise'],
                    scale=state['weight_scale'],
                    learn_bias=False,
                    name='rec_proj_reseter_%d'%si))

        # This corresponds to U from the paper
        add_rec_step.append(eval(state['rec_layer'])(
                rng,
                n_hids=state['dim'],
                activation = state['activ'],
                bias_scale = state['bias'],
                scale=state['rec_weight_scale'],
                init_fn=state['rec_weight_init_fn'],
                weight_noise=state['weight_noise_rec'],
                dropout=state['dropout_rec'],
                gating=state['rec_gating'],
                gater_activation=state['rec_gater'],
                reseting=state['rec_reseting'],
                reseter_activation=state['rec_reseter'],
                name='add_h_%d'%si))
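
        # Putting the pieces together, a rough sketch of the update these
        # layers implement (using the W / W_z / W_r / U notation from the
        # comments above; the exact form depends on state['rec_layer']):
        #   z_t  = sigmoid(W_z x_t + U_z h_{t-1})
        #   r_t  = sigmoid(W_r x_t + U_r h_{t-1})
        #   h~_t = activ(W x_t + U (r_t * h_{t-1}))
        #   h_t  = z_t * h_{t-1} + (1 - z_t) * h~_t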

    def _add_op(words_embeddings,
                words_mask=None,
                prev_val=None,
                si = 0,
                state_below = None,
                gater_below = None,
                reseter_below = None,
                one_step=False,
                bs=1,
                init_state=None,
                use_noise=True):
        seqlen = words_embeddings.out.shape[0]//bs
        rval = words_embeddings
        gater = None
        reseter = None
        if state['rec_gating']:
            gater = gater_below
        if state['rec_reseting']:
            reseter = reseter_below
        if si > 0:
            rval += rec_proj[si-1](state_below, one_step=one_step,
                    use_noise=use_noise)
            if state['rec_gating']:
                projg = rec_proj_gater[si-1](state_below, one_step=one_step,
                        use_noise = use_noise)
                if gater: gater += projg
                else: gater = projg
            if state['rec_reseting']:
                projg = rec_proj_reseter[si-1](state_below, one_step=one_step,
                        use_noise = use_noise)
                if reseter: reseter += projg
                else: reseter = projg

        if not one_step:
            rval= add_rec_step[si](
                rval,
                nsteps=seqlen,
                batch_size=bs,
                mask=words_mask,
                gater_below = gater,
                reseter_below = reseter,
                one_step=one_step,
                init_state=init_state,
                use_noise = use_noise)
        else:
            # Here we link the encoder part (one-step mode)
            rval= add_rec_step[si](
                rval,
                mask=words_mask,
                state_before=prev_val,
                gater_below = gater,
                reseter_below = reseter,
                one_step=one_step,
                init_state=init_state,
                use_noise = use_noise)
        return rval
    add_op = Operator(_add_op)
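    # add_op runs one encoder layer: the word embeddings (plus, for stacked
    # layers, a projection of the layer below) are fed through the gated
    # recurrent step, scanning over the whole source sequence in training
    # mode and advancing a single step when used inside a sampler.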

    logger.info("Target sequence")
    emb_t = MultiLayer(
        rng,
        n_in=state['nouts'],
        n_hids=[state['rank_n_approx']],
        activation=[state['rank_n_activ']],
        init_fn=state['weight_init_fn'],
        weight_noise=state['weight_noise'],
        scale=state['weight_scale'],
        name='emb_t')

    emb_words_t = []
    if state['rec_gating']:
        gater_words_t = []
    if state['rec_reseting']:
        reseter_words_t = []
    for si in xrange(state['decoder_stack']):
        emb_words_t.append(MultiLayer(
            rng,
            n_in=state['rank_n_approx'],
            n_hids=[embdim],
            activation=['lambda x:x'],
            init_fn=state['weight_init_fn'],
            weight_noise=state['weight_noise'],
            scale=state['weight_scale'],
            name='emb_words_t_%d'%si))
        if state['rec_gating']:
            gater_words_t.append(MultiLayer(
                rng,
                n_in=state['rank_n_approx'],
                n_hids=[state['dim']],
                activation=['lambda x:x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                learn_bias=False,
                name='gater_words_t_%d'%si))
        if state['rec_reseting']:
            reseter_words_t.append(MultiLayer(
                rng,
                n_in=state['rank_n_approx'],
                n_hids=[state['dim']],
                activation=['lambda x:x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                learn_bias=False,
                name='reseter_words_t_%d'%si))

    proj_everything_t = []
    if state['rec_gating']:
        gater_everything_t = []
    if state['rec_reseting']:
        reseter_everything_t = []
    for si in xrange(state['decoder_stack']):
        # This stands for the matrix C from the paper
        proj_everything_t.append(MultiLayer(
            rng,
            n_in=state['dim'],
            n_hids=[embdim],
            activation=['lambda x:x'],
            init_fn=state['weight_init_fn'],
            weight_noise=state['weight_noise'],
            scale=state['weight_scale'],
            name='proj_everything_t_%d'%si,
            learn_bias = False))
        if state['rec_gating']:
            gater_everything_t.append(MultiLayer(
                rng,
                n_in=state['dim'],
                n_hids=[state['dim']],
                activation=['lambda x:x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                name='gater_everything_t_%d'%si,
                learn_bias = False))
        if state['rec_reseting']:
            reseter_everything_t.append(MultiLayer(
                rng,
                n_in=state['dim'],
                n_hids=[state['dim']],
                activation=['lambda x:x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                name='reseter_everything_t_%d'%si,
                learn_bias = False))

    add_rec_step_t = []
    rec_proj_t = []
    if state['rec_gating']:
        rec_proj_t_gater = []
    if state['rec_reseting']:
        rec_proj_t_reseter = []
    for si in xrange(state['decoder_stack']):
        if si > 0:
            rec_proj_t.append(MultiLayer(
                rng,
                n_in=state['dim'],
                n_hids=[embdim],
                activation=['lambda x:x'],
                init_fn=state['rec_weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['rec_weight_scale'],
                name='rec_proj_%d'%si))
            if state['rec_gating']:
                rec_proj_t_gater.append(MultiLayer(
                    rng,
                    n_in=state['dim'],
                    n_hids=[state['dim']],
                    activation=['lambda x:x'],
                    init_fn=state['weight_init_fn'],
                    weight_noise=state['weight_noise'],
                    scale=state['weight_scale'],
                    learn_bias=False,
                    name='rec_proj_t_gater_%d'%si))
            if state['rec_reseting']:
                rec_proj_t_reseter.append(MultiLayer(
                    rng,
                    n_in=state['dim'],
                    n_hids=[state['dim']],
                    activation=['lambda x:x'],
                    init_fn=state['weight_init_fn'],
                    weight_noise=state['weight_noise'],
                    scale=state['weight_scale'],
                    learn_bias=False,
                    name='rec_proj_t_reseter_%d'%si))

        # This applies the gating, resetting and non-linearity steps in the decoder
        add_rec_step_t.append(eval(state['rec_layer'])(
                rng,
                n_hids=state['dim'],
                activation = state['activ'],
                bias_scale = state['bias'],
                scale=state['rec_weight_scale'],
                init_fn=state['rec_weight_init_fn'],
                weight_noise=state['weight_noise_rec'],
                dropout=state['dropout_rec'],
                gating=state['rec_gating'],
                gater_activation=state['rec_gater'],
                reseting=state['rec_reseting'],
                reseter_activation=state['rec_reseter'],
                name='add_h_t_%d'%si))

    if state['encoder_stack'] > 1:
        encoder_proj = []
        for si in xrange(state['encoder_stack']):
            encoder_proj.append(MultiLayer(
                rng,
                n_in=state['dim'],
                n_hids=[state['dim'] * state['maxout_part']],
                activation=['lambda x: x'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                scale=state['weight_scale'],
                name='encoder_proj_%d'%si,
                learn_bias = (si == 0)))

        encoder_act_layer = UnaryOp(activation=eval(state['unary_activ']),
                indim=indim, pieces=pieces, rng=rng)

    # The target-side (decoder) counterpart of add_op
    def _add_t_op(words_embeddings, everything = None, words_mask=None,
                prev_val=None,one_step=False, bs=1,
                init_state=None, use_noise=True,
                gater_below = None,
                reseter_below = None,
                si = 0, state_below = None):
        seqlen = words_embeddings.out.shape[0]//bs

        rval = words_embeddings
        gater = None
        if state['rec_gating']:
            gater = gater_below
        reseter = None
        if state['rec_reseting']:
            reseter = reseter_below
        if si > 0:
            if isinstance(state_below, list):
                state_below = state_below[-1]
            rval += rec_proj_t[si-1](state_below,
                    one_step=one_step, use_noise=use_noise)
            if state['rec_gating']:
                projg = rec_proj_t_gater[si-1](state_below, one_step=one_step,
                        use_noise = use_noise)
                if gater: gater += projg
                else: gater = projg
            if state['rec_reseting']:
                projg = rec_proj_t_reseter[si-1](state_below, one_step=one_step,
                        use_noise = use_noise)
                if reseter: reseter += projg
                else: reseter = projg
        if everything:
            rval = rval + proj_everything_t[si](everything)
            if state['rec_gating']:
                everyg = gater_everything_t[si](everything, one_step=one_step, use_noise=use_noise)
                if gater: gater += everyg
                else: gater = everyg
            if state['rec_reseting']:
                everyg = reseter_everything_t[si](everything, one_step=one_step, use_noise=use_noise)
                if reseter: reseter += everyg
                else: reseter = everyg

        if not one_step:
            rval = add_rec_step_t[si](
                rval,
                nsteps=seqlen,
                batch_size=bs,
                mask=words_mask,
                one_step=one_step,
                init_state=init_state,
                gater_below=gater,
                reseter_below=reseter,
                use_noise=use_noise)
        else:
            # Here we link the decoder part (one-step mode)
            rval = add_rec_step_t[si](
                rval,
                mask=words_mask,
                state_before=prev_val,
                one_step=one_step,
                gater_below=gater,
                reseter_below=reseter,
                use_noise=use_noise)
        return rval
    add_t_op = Operator(_add_t_op)
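    # add_t_op is the decoder analogue: starting from the target word
    # embeddings it adds the projection of the context 'everything' (the
    # matrix C above) and, for stacked layers, of the layer below, then
    # applies the gated recurrent step over the sequence or one step at a time.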

    outdim = state['dim_mlp']
    if not state['deep_out']:
        outdim = state['rank_n_approx']

    if state['bias_code']:
        bias_code = []
        for si in xrange(state['decoder_stack']):
            bias_code.append(MultiLayer(
                rng,
                n_in=state['dim'],
                n_hids=[state['dim']],
                activation = [state['activ']],
                bias_scale = [state['bias']],
                scale=state['weight_scale'],
                init_fn=state['weight_init_fn'],
                weight_noise=state['weight_noise'],
                name='bias_code_%d'%si))

    if state['avg_word']:
        word_code_nin = state['rank_n_approx']
        word_code = MultiLayer(
            rng,
            n_in=word_code_nin,
            n_hids=[outdim],
            activation = 'lambda x:x',
            bias_scale = [state['bias_mlp']/3],
            scale=state['weight_scale'],
            init_fn=state['weight_init_fn'],
            weight_noise=state['weight_noise'],
            learn_bias = False,
            name='word_code')

    proj_code = MultiLayer(
        rng,
        n_in=state['dim'],
        n_hids=[outdim],
        activation = 'lambda x: x',
        bias_scale = [state['bias_mlp']/3],
        scale=state['weight_scale'],
        init_fn=state['weight_init_fn'],
        weight_noise=state['weight_noise'],
        learn_bias = False,
        name='proj_code')

    proj_h = []
    for si in xrange(state['decoder_stack']):
        proj_h.append(MultiLayer(
            rng,
            n_in=state['dim'],
            n_hids=[outdim],
            activation = 'lambda x: x',
            bias_scale = [state['bias_mlp']/3],
            scale=state['weight_scale'],
            init_fn=state['weight_init_fn'],
            weight_noise=state['weight_noise'],
            name='proj_h_%d'%si))

    if state['bigram']:
        proj_word = MultiLayer(
            rng,
            n_in=state['rank_n_approx'],
            n_hids=[outdim],
            activation=['lambda x:x'],
            bias_scale = [state['bias_mlp']/3],
            init_fn=state['weight_init_fn'],
            weight_noise=state['weight_noise'],
            scale=state['weight_scale'],
            learn_bias = False,
            name='emb_words_lm')

    if state['deep_out']:
        indim = 0
        pieces = 0
        act_layer = UnaryOp(activation=eval(state['unary_activ']))
        drop_layer = DropOp(rng=rng, dropout=state['dropout'])

    if state['deep_out']:
        indim = state['dim_mlp'] / state['maxout_part']
        rank_n_approx = state['rank_n_approx']
        rank_n_activ = state['rank_n_activ']
    else:
        indim = state['rank_n_approx']
        rank_n_approx = 0
        rank_n_activ = None
    output_layer = SoftmaxLayer(
            rng,
            indim,
            state['nouts'],
            state['weight_scale'],
            -1,
            rank_n_approx = rank_n_approx,
            rank_n_activ = rank_n_activ,
            weight_noise=state['weight_noise'],
            init_fn=state['weight_init_fn'],
            name='out')

    def _pop_op(everything, accum, everything_max = None,
            everything_min = None, word = None, aword = None,
            one_step=False, use_noise=True):

        rval = proj_h[0](accum[0], one_step=one_step, use_noise=use_noise)
        for si in xrange(1,state['decoder_stack']):
            rval += proj_h[si](accum[si], one_step=one_step, use_noise=use_noise)
        if state['mult_out']:
            rval = rval * everything
        else:
            rval = rval + everything

        if aword and state['avg_word']:
            wcode = aword
            if one_step:
                if state['mult_out']:
                    rval = rval * wcode
                else:
                    rval = rval + wcode
            else:
                if not isinstance(wcode, TT.TensorVariable):
                    wcode = wcode.out
                shape = wcode.shape
                rshape = rval.shape
                rval = rval.reshape([rshape[0]/shape[0], shape[0], rshape[1]])
                wcode = wcode.dimshuffle('x', 0, 1)
                if state['mult_out']:
                    rval = rval * wcode
                else:
                    rval = rval + wcode
                rval = rval.reshape(rshape)
        if word and state['bigram']:
            if one_step:
                if state['mult_out']:
                    rval *= proj_word(emb_t(word, use_noise=use_noise),
                            one_step=one_step, use_noise=use_noise)
                else:
                    rval += proj_word(emb_t(word, use_noise=use_noise),
                            one_step=one_step, use_noise=use_noise)
            else:
                if isinstance(word, TT.TensorVariable):
                    shape = word.shape
                    ndim = word.ndim
                else:
                    shape = word.shape
                    ndim = word.out.ndim
                pword = proj_word(emb_t(word, use_noise=use_noise),
                        one_step=one_step, use_noise=use_noise)
                shape_pword = pword.shape
                if ndim == 1:
                    pword = Shift()(pword.reshape([shape[0], 1, outdim]))
                else:
                    pword = Shift()(pword.reshape([shape[0], shape[1], outdim]))
                if state['mult_out']:
                    rval *= pword.reshape(shape_pword)
                else:
                    rval += pword.reshape(shape_pword)
        if state['deep_out']:
            rval = drop_layer(act_layer(rval), use_noise=use_noise)
        return rval

    pop_op = Operator(_pop_op)

    logger.info("Construct the model")
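    # Training graph: run the encoder stack over the masked source sequence x
    # and summarise it into 'everything', which conditions the decoder.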
    gater_below = None
    if state['rec_gating']:
        gater_below = gater_words[0](emb(x))
    reseter_below = None
    if state['rec_reseting']:
        reseter_below = reseter_words[0](emb(x))
    encoder_acts = [add_op(emb_words[0](emb(x)), x_mask,
        bs=x_mask.shape[1],
        si=0, gater_below=gater_below, reseter_below=reseter_below)]
    if state['encoder_stack'] > 1:
        everything = encoder_proj[0](last(encoder_acts[-1]))
    for si in xrange(1,state['encoder_stack']):
        gater_below = None
        if state['rec_gating']:
            gater_below = gater_words[si](emb(x))
        reseter_below = None
        if state['rec_reseting']:
            reseter_below = reseter_words[si](emb(x))
        encoder_acts.append(add_op(emb_words[si](emb(x)),
                x_mask, bs=x_mask.shape[1],
                si=si, state_below=encoder_acts[-1],
                gater_below=gater_below,
                reseter_below=reseter_below))
        if state['encoder_stack'] > 1:
            everything += encoder_proj[si](last(encoder_acts[-1]))

    if state['encoder_stack'] <= 1:
        encoder = encoder_acts[-1]
        everything = LastState(ntimes=True,n=y.shape[0])(encoder)
    else:
        everything = encoder_act_layer(everything)
        everything = everything.reshape([1, everything.shape[0], everything.shape[1]])
        everything = LastState(ntimes=True,n=y.shape[0])(everything)

    if state['bias_code']:
        init_state = [bc(everything[-1]) for bc in bias_code]
    else:
        init_state = [None for bc in bias_code]

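    # avg_word: average the masked source embeddings per sentence and pass the
    # result through word_code to obtain the auxiliary context 'aword'.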
    if state['avg_word']:
        shape = x.shape
        pword = emb(x).out.reshape([shape[0], shape[1], state['rank_n_approx']])
        pword = pword * x_mask.dimshuffle(0, 1, 'x')
        aword = pword.sum(0) / TT.maximum(1., x_mask.sum(0).dimshuffle(0, 'x'))
        aword = word_code(aword, use_noise=False)
    else:
        aword = None

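    # Decoder stack: recurrent layers over the target sequence y0, each
    # conditioned on the encoder summary 'everything'.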
    gater_below = None
    if state['rec_gating']:
        gater_below = gater_words_t[0](emb_t(y0))
    reseter_below = None
    if state['rec_reseting']:
        reseter_below = reseter_words_t[0](emb_t(y0))
    has_said = [add_t_op(emb_words_t[0](emb_t(y0)),
            everything,
            y_mask, bs=y_mask.shape[1],
            gater_below = gater_below,
            reseter_below = reseter_below,
            init_state=init_state[0],
            si=0)]
    for si in xrange(1,state['decoder_stack']):
        gater_below = None
        if state['rec_gating']:
            gater_below = gater_words_t[si](emb_t(y0))
        reseter_below = None
        if state['rec_reseting']:
            reseter_below = reseter_words_t[si](emb_t(y0))
        has_said.append(add_t_op(emb_words_t[si](emb_t(y0)),
                everything,
                y_mask, bs=y_mask.shape[1],
                state_below = has_said[-1],
                gater_below = gater_below,
                reseter_below = reseter_below,
                init_state=init_state[si],
                si=si))

    # has_said holds the decoder hidden states. Below they are shifted one step
    # in time and their first position is overwritten with the initial state,
    # so that step t only conditions on previously generated words.

    if has_said[0].out.ndim < 3:
        for si in xrange(state['decoder_stack']):
            shape_hs = has_said[si].shape
            if y0.ndim == 1:
                shape = y0.shape
                has_said[si] = Shift()(has_said[si].reshape([shape[0], 1, state['dim_mlp']]))
            else:
                shape = y0.shape
                has_said[si] = Shift()(has_said[si].reshape([shape[0], shape[1], state['dim_mlp']]))
            has_said[si] = TT.set_subtensor(has_said[si][0, :, :], init_state[si])
            has_said[si] = has_said[si].reshape(shape_hs)
    else:
        for si in xrange(state['decoder_stack']):
            has_said[si] = Shift()(has_said[si])
            has_said[si].out = TT.set_subtensor(has_said[si][0, :, :], init_state[si])

    model = pop_op(proj_code(everything), has_said, word=y0, aword = aword)

    nll = output_layer.train(state_below=model, target=y0,
                 mask=y_mask, reg=None) / TT.cast(y.shape[0]*y.shape[1], 'float32')

    valid_fn = None
    noise_fn = None

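    # Sampling graph: x is redefined as a single token sequence (lvector) and
    # the encoder is rebuilt with noise disabled.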
    x = TT.lvector(name='x')
    n_steps = TT.iscalar('nsteps')
    temp = TT.scalar('temp')
    gater_below = None
    if state['rec_gating']:
        gater_below = gater_words[0](emb(x))
    reseter_below = None
    if state['rec_reseting']:
        reseter_below = reseter_words[0](emb(x))
    encoder_acts = [add_op(emb_words[0](emb(x),use_noise=False),
            si=0,
            use_noise=False,
            gater_below=gater_below,
            reseter_below=reseter_below)]
    if state['encoder_stack'] > 1:
        everything = encoder_proj[0](last(encoder_acts[-1]), use_noise=False)
    for si in xrange(1,state['encoder_stack']):
        gater_below = None
        if state['rec_gating']:
            gater_below = gater_words[si](emb(x))
        reseter_below = None
        if state['rec_reseting']:
            reseter_below = reseter_words[si](emb(x))
        encoder_acts.append(add_op(emb_words[si](emb(x),use_noise=False),
                si=si,
                state_below=encoder_acts[-1], use_noise=False,
                gater_below = gater_below,
                reseter_below = reseter_below))
        if state['encoder_stack'] > 1:
            everything += encoder_proj[si](last(encoder_acts[-1]), use_noise=False)
    if state['encoder_stack'] <= 1:
        encoder = encoder_acts[-1]
        everything = last(encoder)
    else:
        everything = encoder_act_layer(everything)

    init_state = []
    for si in xrange(state['decoder_stack']):
        if state['bias_code']:
            init_state.append(TT.reshape(bias_code[si](everything,
                use_noise=False), [1, state['dim']]))
        else:
            init_state.append(TT.alloc(numpy.float32(0), 1, state['dim']))

    if state['avg_word']:
        aword = emb(x,use_noise=False).out.mean(0)
        aword = word_code(aword, use_noise=False)
    else:
        aword = None

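    # One step of the sampler: combine the previous word and decoder states
    # through pop_op, draw the next word from the softmax, score it, and
    # advance every decoder layer.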
    def sample_fn(*args):
        aidx = 0; word_tm1 = args[aidx]
        aidx += 1; prob_tm1 = args[aidx]
        has_said_tm1 = []
        for si in xrange(state['decoder_stack']):
            aidx += 1; has_said_tm1.append(args[aidx])
        aidx += 1; ctx = args[aidx]
        if state['avg_word']:
            aidx += 1; awrd = args[aidx]
        else:
            awrd = None

        val = pop_op(proj_code(ctx), has_said_tm1, word=word_tm1,
                aword=awrd, one_step=True, use_noise=False)
        sample = output_layer.get_sample(state_below=val, temp=temp)
        logp = output_layer.get_cost(
                state_below=val.out.reshape([1, TT.cast(output_layer.n_in, 'int64')]),
                temp=temp, target=sample.reshape([1,1]), use_noise=False)
        gater_below = None
        if state['rec_gating']:
            gater_below = gater_words_t[0](emb_t(sample))
        reseter_below = None
        if state['rec_reseting']:
            reseter_below = reseter_words_t[0](emb_t(sample))
        has_said_t = [add_t_op(emb_words_t[0](emb_t(sample)),
                ctx,
                prev_val=has_said_tm1[0],
                gater_below=gater_below,
                reseter_below=reseter_below,
                one_step=True, use_noise=True,
                si=0)]
        for si in xrange(1, state['decoder_stack']):
            gater_below = None
            if state['rec_gating']:
                gater_below = gater_words_t[si](emb_t(sample))
            reseter_below = None
            if state['rec_reseting']:
                reseter_below = reseter_words_t[si](emb_t(sample))
            has_said_t.append(add_t_op(emb_words_t[si](emb_t(sample)),
                    ctx,
                    prev_val=has_said_tm1[si],
                    gater_below=gater_below,
                    reseter_below=reseter_below,
                    one_step=True, use_noise=True,
                    si=si, state_below=has_said_t[-1]))
        for si in xrange(state['decoder_stack']):
            if isinstance(has_said_t[si], list):
                has_said_t[si] = has_said_t[si][-1]
        rval = [sample, TT.cast(logp, 'float32')] + has_said_t
        return rval

    sampler_params = [everything]
    if state['avg_word']:
        sampler_params.append(aword)

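    # Scan state: sampled word indices, their scores, and one decoder hidden
    # state per layer.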
    states = [TT.alloc(numpy.int64(0), n_steps)]
    states.append(TT.alloc(numpy.float32(0), n_steps))
    states += init_state

    outputs, updates = scan(sample_fn,
            states = states,
            params = sampler_params,
            n_steps= n_steps,
            name='sampler_scan'
            )
    samples = outputs[0]
    probs = outputs[1]

    sample_fn = theano.function(
        [n_steps, temp, x], [samples, probs.sum()],
        updates=updates,
        profile=False, name='sample_fn')

    model = LM_Model(
        cost_layer=nll,
        weight_noise_amount=state['weight_noise_amount'],
        valid_fn=valid_fn,
        sample_fn=sample_fn,
        clean_before_noise_fn = False,
        noise_fn=noise_fn,
        indx_word=state['indx_word_target'],
        indx_word_src=state['indx_word'],
        character_level=False,
        rng=rng)

    algo = SGD(model, state, train_data)

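    # Hook called from the main loop: prints a few source/target pairs from the
    # training data and samples continuations from the current model.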
    def hook_fn():
        if not hasattr(model, 'word_indxs'): model.load_dict()
        if not hasattr(model, 'word_indxs_src'):
            model.word_indxs_src = model.word_indxs
        old_offset = train_data.offset
        if state['sample_reset']: train_data.reset()
        ns = 0
        for sidx in xrange(state['sample_n']):
            while True:
                batch = train_data.next()
                if batch:
                    break
            x = batch['x']
            y = batch['y']
            #xbow = batch['x_bow']
            masks = batch['x_mask']
            if x.ndim > 1:
                for idx in xrange(x.shape[1]):
                    ns += 1
                    if ns > state['sample_max']:
                        break
                    print 'Input: ',
                    for k in xrange(x[:,idx].shape[0]):
                        print model.word_indxs_src[x[:,idx][k]],
                        if model.word_indxs_src[x[:,idx][k]] == '<eol>':
                            break
                    print ''
                    print 'Target: ',
                    for k in xrange(y[:,idx].shape[0]):
                        print model.word_indxs[y[:,idx][k]],
                        if model.word_indxs[y[:,idx][k]] == '<eol>':
                            break
                    print ''
                    senlen = len(x[:,idx])
                    if len(numpy.where(masks[:,idx]==0)[0]) > 0:
                        senlen = numpy.where(masks[:,idx]==0)[0][0]
                    if senlen < 1:
                        continue
                    xx = x[:senlen, idx]
                    #xx = xx.reshape([xx.shape[0], 1])
                    model.get_samples(state['seqlen'] + 1, 1, xx)
            else:
                ns += 1
                model.get_samples(state['seqlen'] + 1, 1, x)
            if ns > state['sample_max']:
                break
        train_data.offset = old_offset
        return

    main = MainLoop(train_data, valid_data, None, model, algo, state, channel,
            reset = state['reset'], hooks = hook_fn)
    if state['reload']: main.load()
    main.main()

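    # Scoring mode: compile a function returning per-sample log-likelihoods
    # (-cost) and write one score per line to state['score_file'].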
    if state["scoring"]:
        score_file = open(state["score_file"], "w")

        logger.info("Compiling score function")
        score_fn = theano.function(scoring_inputs, [-nll.cost_per_sample])

        count = 0
        n_samples = 0
        logger.info('Scoring phrases')
        for batch in train_data:
            if batch is None:
                continue
            if batch['x'].shape[0] <= 0 or \
                    batch['x_mask'].shape[0] <= 0 or \
                    batch['y'].shape[0] <= 0 or \
                    batch['y_mask'].shape[0] <= 0:
                logger.error('Wrong batch!!!')
                continue
            st = time.time()
            [scores] = score_fn(batch['x'], batch['x_mask'],
                    batch['y'], batch['y_mask'])
            up_time = time.time() - st
            for s in scores:
                print >>score_file, "{:.5f}".format(float(s))

            n_samples += batch['x'].shape[1]
            count += 1

            if state['flush_scores'] >= 1 and count % state['flush_scores'] == 0:
                score_file.flush()
                logger.debug("Scores flushed")
            logger.debug("{} batches, {} samples, {} per sample; example scores: {}".format(
                count, n_samples, up_time/scores.shape[0], scores[:5]))
        logger.info("Done")
        score_file.flush()
        score_file.close()

    if state['sampler_test']:
        # This is a test script: we only sample
        if not hasattr(model, 'word_indxs'): model.load_dict()
        if not hasattr(model, 'word_indxs_src'):
            model.word_indxs_src = model.word_indxs

        indx_word = pkl.load(open(state['word_indx'], 'rb'))

        try:
            while True:
                try:
                    seqin = raw_input('Input Sequence: ')
                    n_samples = int(raw_input('How many samples? '))
                    alpha = float(raw_input('Inverse Temperature? '))

                    seqin = seqin.lower()
                    seqin = seqin.split()

                    seqlen = len(seqin)
                    seq = numpy.zeros(seqlen+1, dtype='int64')
                    for idx,sx in enumerate(seqin):
                        try:
                            seq[idx] = indx_word[sx]
                        except KeyError:
                            # unknown word: fall back to the OOV symbol
                            seq[idx] = indx_word[state['oov']]
                    seq[-1] = state['null_sym_source']

                except Exception:
                    print 'Something went wrong with your input! Try again!'
                    continue

                sentences = []
                all_probs = []
                for sidx in xrange(n_samples):
                    #import ipdb; ipdb.set_trace()
                    [values, probs] = model.sample_fn(seqlen * 3, alpha, seq)
                    sen = []
                    for k in xrange(values.shape[0]):
                        if model.word_indxs[values[k]] == '<eol>':
                            break
                        sen.append(model.word_indxs[values[k]])
                    sentences.append(" ".join(sen))
                    all_probs.append(-probs)
                sprobs = numpy.argsort(all_probs)
                for pidx in sprobs:
                    print pidx, "(%f):" % (-all_probs[pidx]), sentences[pidx]
                print

        except KeyboardInterrupt:
            print 'Interrupted'