def sample_step(x_tm1, h1_tm1, h2_tm1, h3_tm1, k_tm1, w_tm1, ctx):
    """Build one symbolic sampling step of the attention RNN.

    Runs the three stacked recurrent cells for one timestep, advances the
    attention window over ``ctx``, and samples one class index per feature
    from an ``n_features`` x ``n_softmax`` softmax output.

    Parameters
    ----------
    x_tm1 : symbolic tensor
        Previous sample, projected into every layer via
        ``inp_to_h{1,2,3}.proj``.
    h1_tm1, h2_tm1, h3_tm1 : symbolic tensors
        Previous hidden states of ``cell1``, ``cell2`` and ``cell3``.
    k_tm1 : symbolic tensor
        Previous attention position. ``k_t`` adds ``exp(...)`` (strictly
        positive) to it, so the position never decreases.
    w_tm1 : symbolic tensor
        Previous attention read vector, fed to the first layer.
    ctx : symbolic tensor
        Sequence attended over. It is dimshuffled with ``(1, 0, 2)`` before
        weighting; presumably laid out (length, batch, features) -- TODO
        confirm against the caller.

    Returns
    -------
    tuple
        ``(x_t, h1_t, h2_t, h3_t, k_t, w_t, ss_t, sh_t)``. ``x_t`` holds the
        sampled class per feature with shape ``(N, n_features)``, ``ss_t``
        are the attention weights evaluated at ``u`` and ``sh_t`` the same
        window evaluated at ``u_max`` (the stopping criterion).

    Notes
    -----
    Layers, weights and helpers (``cell*``, ``*_proj``, ``calc_phi``,
    ``softmax``, ``sample_softmax``, ``srng``, ...) are free names defined
    outside this function.
    """
    # Project the previous sample into each recurrent layer.
    xinp_h1_t, xgate_h1_t = inp_to_h1.proj(x_tm1)
    xinp_h2_t, xgate_h2_t = inp_to_h2.proj(x_tm1)
    xinp_h3_t, xgate_h3_t = inp_to_h3.proj(x_tm1)

    # The first layer also sees the previous attention read.
    attinp_h1, attgate_h1 = att_to_h1.proj(w_tm1)

    h1_t = cell1.step(xinp_h1_t + attinp_h1, xgate_h1_t + attgate_h1,
                      h1_tm1)
    h1inp_h2, h1gate_h2 = h1_to_h2.proj(h1_t)
    h1inp_h3, h1gate_h3 = h1_to_h3.proj(h1_t)

    # Attention parameters; exp keeps them positive, so k_t only moves
    # forward.
    a_t = tensor.exp(h1_t.dot(h1_to_att_a))
    b_t = tensor.exp(h1_t.dot(h1_to_att_b))
    k_t = k_tm1 + tensor.exp(h1_t.dot(h1_to_att_k))

    ss_t = calc_phi(k_t, a_t, b_t, u)
    # calculate and return stopping criteria
    sh_t = calc_phi(k_t, a_t, b_t, u_max)
    # Attention read: context weighted by ss_t and summed over axis 1.
    ss5 = ss_t.dimshuffle(0, 1, 'x')
    ss6 = ss5 * ctx.dimshuffle(1, 0, 2)
    w_t = ss6.sum(axis=1)

    attinp_h2, attgate_h2 = att_to_h2.proj(w_t)
    attinp_h3, attgate_h3 = att_to_h3.proj(w_t)

    h2_t = cell2.step(xinp_h2_t + h1inp_h2 + attinp_h2,
                      xgate_h2_t + h1gate_h2 + attgate_h2, h2_tm1)

    h2inp_h3, h2gate_h3 = h2_to_h3.proj(h2_t)

    h3_t = cell3.step(xinp_h3_t + h1inp_h3 + h2inp_h3 + attinp_h3,
                      xgate_h3_t + h1gate_h3 + h2gate_h3 + attgate_h3,
                      h3_tm1)
    out_t = h1_t.dot(h1_to_outs) + h2_t.dot(h2_to_outs) + h3_t.dot(
        h3_to_outs)

    # Three-layer ReLU MLP on top of the combined RNN output.
    l1_t = relu(out_t.dot(l1_proj) + b_l1_proj)
    l2_t = relu(l1_t.dot(l2_proj) + b_l2_proj)
    l3_t = relu(l2_t.dot(l3_proj) + b_l3_proj)
    pred_t = l3_t.dot(softmax_proj) + b_softmax_proj
    # One n_softmax-way distribution per feature; the logits are scaled by
    # (1 + softmax_bias_sym) before normalisation.
    pred_t = pred_t.reshape((-1, n_features, n_softmax))
    pred_t = softmax(pred_t * (1. + softmax_bias_sym))

    # Sample on a 2D (rows, n_softmax) view, then fold the samples back to
    # (N, n_features).
    shp = pred_t.shape
    flat_pred_t = pred_t.reshape((-1, shp[-1]))
    samp_t = sample_softmax(flat_pred_t, srng)
    x_t = samp_t.reshape((shp[0], shp[1]))
    return x_t, h1_t, h2_t, h3_t, k_t, w_t, ss_t, sh_t
Example #2
0
    def sample_step(x_tm1, h1_tm1, h2_tm1, h3_tm1):
        """Build one symbolic sampling step of the frame-autoregressive RNN.

        Advances the three stacked recurrent cells one step, then samples
        ``n_frame`` values one at a time. Each value is predicted by an MLP
        from its own ``n_hid``-wide slice of the RNN output, concatenated
        with a fixed-length sliding window over the one-hot history (the
        previous frame followed by the values sampled so far).

        Parameters
        ----------
        x_tm1 : symbolic tensor
            Previous frame of bin indices, one-hot encoded with ``n_bins``
            classes. Presumably (batch, n_frame) -- TODO confirm.
        h1_tm1, h2_tm1, h3_tm1 : symbolic tensors
            Previous hidden states of ``cell1``, ``cell2`` and ``cell3``.

        Returns
        -------
        tuple
            ``(x_t, h1_t, h2_t, h3_t)``. ``x_t`` holds the indices of the
            entries appended during this step (history positions
            ``n_frame`` onward), cast to ``theano.config.floatX``.

        Notes
        -----
        Layers, weights and helpers are free names defined outside this
        function.
        """
        xinp_h1_t, xgate_h1_t = inp_to_h1.proj(x_tm1)
        xinp_h2_t, xgate_h2_t = inp_to_h2.proj(x_tm1)
        xinp_h3_t, xgate_h3_t = inp_to_h3.proj(x_tm1)

        h1_t = cell1.step(xinp_h1_t, xgate_h1_t, h1_tm1)
        h1inp_h2, h1gate_h2 = h1_to_h2.proj(h1_t)
        h1inp_h3, h1gate_h3 = h1_to_h3.proj(h1_t)

        h2_t = cell2.step(xinp_h2_t + h1inp_h2,
                          xgate_h2_t + h1gate_h2, h2_tm1)

        h2inp_h3, h2gate_h3 = h2_to_h3.proj(h2_t)

        h3_t = cell3.step(xinp_h3_t + h1inp_h3 + h2inp_h3,
                          xgate_h3_t + h1gate_h3 + h2gate_h3,
                          h3_tm1)
        out_t = h1_t.dot(h1_to_outs) + h2_t.dot(h2_to_outs) + h3_t.dot(
            h3_to_outs) + b_to_outs

        # prev_t accumulates the one-hot history along axis 1: the previous
        # frame, followed by each value sampled so far in this step.
        prev_t = theano_one_hot(x_tm1, n_classes=n_bins)
        for i in range(n_frame):
            # Position i reads its own n_hid-wide slice of the RNN output.
            partial_out_t = out_t[:, i * n_hid:(i + 1) * n_hid]
            # Sliding window: drop the i oldest history entries (one sample
            # is appended per iteration) and flatten the rest.
            shp = prev_t.shape
            prev_ti = prev_t[:, i:].reshape((shp[0], -1))
            features_t = tensor.concatenate((partial_out_t, prev_ti),
                                            axis=1)
            mlp1_t = relu(features_t.dot(mlp1_w) + mlp1_b)
            mlp2_t = relu(mlp1_t.dot(mlp2_w) + mlp2_b)
            mlp3_t = relu(mlp2_t.dot(mlp3_w) + mlp3_b)
            pred_t = softmax(mlp3_t.dot(pred_w) + pred_b)
            samp_t = sample_softmax(pred_t, srng)
            # Append the new sample as one more one-hot entry on axis 1.
            samp_t_oh = theano_one_hot(samp_t, n_classes=n_bins)
            samp_t_oh = samp_t_oh.dimshuffle(0, 'x', 1)
            prev_t = tensor.concatenate((prev_t, samp_t_oh), axis=1)
        # Recover the indices of the entries appended above from their
        # one-hot encoding.
        sampled_idx = prev_t[:, n_frame:].argmax(axis=-1)
        x_t = tensor.cast(sampled_idx, theano.config.floatX)
        return x_t, h1_t, h2_t, h3_t
    def sample_step(x_tm1, h1_tm1, h2_tm1, h3_tm1, k_tm1, w_tm1, ctx):
        """Symbolic single step of the attention RNN with a plain softmax head.

        The previous sample and the previous attention read drive the first
        recurrent cell. Its state sets the attention window over ``ctx``, and
        the resulting read feeds the second and third cells. One value is
        then sampled from a softmax over the summed per-layer output
        projections.

        Parameters
        ----------
        x_tm1 : symbolic tensor
            Previous sample.
        h1_tm1, h2_tm1, h3_tm1 : symbolic tensors
            Previous hidden states of the three stacked cells.
        k_tm1 : symbolic tensor
            Previous attention position. A strictly positive ``exp`` term is
            added each step, so it never decreases.
        w_tm1 : symbolic tensor
            Previous attention read vector.
        ctx : symbolic tensor
            Attended sequence, dimshuffled with ``(1, 0, 2)`` before
            weighting (presumably (length, batch, features) -- TODO confirm).

        Returns
        -------
        tuple
            ``(x_t, h1_t, h2_t, h3_t, k_t, w_t, ss_t, sh_t)``.
        """
        # (value, gate) contributions of the previous sample to each layer.
        x_in1, x_gate1 = inp_to_h1.proj(x_tm1)
        x_in2, x_gate2 = inp_to_h2.proj(x_tm1)
        x_in3, x_gate3 = inp_to_h3.proj(x_tm1)

        # Layer 1 sees the input plus the previous attention read.
        prev_att_in, prev_att_gate = att_to_h1.proj(w_tm1)
        h1_t = cell1.step(x_in1 + prev_att_in,
                          x_gate1 + prev_att_gate,
                          h1_tm1)
        h1_in2, h1_gate2 = h1_to_h2.proj(h1_t)
        h1_in3, h1_gate3 = h1_to_h3.proj(h1_t)

        # Attention parameters, forced positive through exp.
        alpha = tensor.exp(h1_t.dot(h1_to_att_a))
        beta = tensor.exp(h1_t.dot(h1_to_att_b))
        k_t = k_tm1 + tensor.exp(h1_t.dot(h1_to_att_k))

        ss_t = calc_phi(k_t, alpha, beta, u)
        # Same window evaluated at u_max: the stopping criterion.
        sh_t = calc_phi(k_t, alpha, beta, u_max)
        # Attention read: context weighted by ss_t, summed over axis 1.
        weighted_ctx = ss_t.dimshuffle(0, 1, 'x') * ctx.dimshuffle(1, 0, 2)
        w_t = weighted_ctx.sum(axis=1)

        att_in2, att_gate2 = att_to_h2.proj(w_t)
        att_in3, att_gate3 = att_to_h3.proj(w_t)

        h2_t = cell2.step(x_in2 + h1_in2 + att_in2,
                          x_gate2 + h1_gate2 + att_gate2,
                          h2_tm1)
        h2_in3, h2_gate3 = h2_to_h3.proj(h2_t)
        h3_t = cell3.step(x_in3 + h1_in3 + h2_in3 + att_in3,
                          x_gate3 + h1_gate3 + h2_gate3 + att_gate3,
                          h3_tm1)

        logits = (h1_t.dot(h1_to_outs) + h2_t.dot(h2_to_outs)
                  + h3_t.dot(h3_to_outs))
        x_t = sample_softmax(softmax(logits), srng)
        return x_t, h1_t, h2_t, h3_t, k_t, w_t, ss_t, sh_t
    theano.printing.Print("sliced.shape")(sliced.shape)

    border_mode = "half"
    blur_conv = conv2d(sliced, w_blurconv, border_mode=border_mode)
    theano.printing.Print("w_blurconv.shape")(w_blurconv.shape)
    theano.printing.Print("blur_conv.shape")(blur_conv.shape)
    final_conv = conv2d(blur_conv, w_finalconv, border_mode=border_mode)
    theano.printing.Print("w_finalconv.shape")(w_finalconv.shape)
    theano.printing.Print("final_conv.shape")(final_conv.shape)
    """

    outs_deconv = final_conv[:, :, :, :input_dim]
    outs_deconv = outs_deconv.dimshuffle(2, 0, 3, 1)
    outs_deconv = outs_deconv[:target.shape[0]]
    theano.printing.Print("outs_deconv.shape")(outs_deconv.shape)
    preds = softmax(outs_deconv + b_softmax)
    theano.printing.Print("preds.shape")(preds.shape)
    theano.printing.Print("target.shape")(target.shape)
    target = theano_one_hot(target, r=n_bins)
    theano.printing.Print("target.shape")(target.shape)
    cost = categorical_crossentropy(preds, target)
    theano.printing.Print("cost.shape")(cost.shape)
    theano.printing.Print("mask.shape")(mask.shape)
    cost = cost * mask.dimshuffle(0, 1, 'x')
    cost = cost.sum() / (target.shape[0] * target.shape[1])
    grads = tensor.grad(cost, params)

    init_x = as_shared(np_zeros((minibatch_size, n_out)))
    srng = RandomStreams(1999)

    """
    init_hidden = tensor.zeros((shuff_inpt.shape[1], n_v_proj),
                                dtype=theano.config.floatX)
    theano.printing.Print("init_hidden.shape")(init_hidden.shape)
    v_h1, updates = theano.scan(
        fn=out_step,
        sequences=[vinp_h1, vgate_h1],
        outputs_info=[init_hidden])
    pre_pred = v_h1.dot(pred_proj) + pred_b
    pre_pred = pre_pred.dimshuffle(1, 0, 2)
    shp = pre_pred.shape
    # Have to undo the minibatch_size * time features the same way they came in
    pre_pred = pre_pred.reshape((minibatch_size, shp[0] // minibatch_size,
                                 shp[1], shp[2]))
    pre_pred = pre_pred.dimshuffle(1, 0, 2, 3)
    theano.printing.Print("pre_pred.shape")(pre_pred.shape)
    pred = softmax(pre_pred)
    theano.printing.Print("target.shape")(target.shape)
    target = theano_one_hot(target, n_classes=n_bins)
    theano.printing.Print("target.shape")(target.shape)
    theano.printing.Print("pred.shape")(pred.shape)

    cost = categorical_crossentropy(pred, target)
    cost = cost * mask.dimshuffle(0, 1, 'x')
    # sum over sequence length and features, mean over minibatch
    cost = cost.dimshuffle(0, 2, 1)
    theano.printing.Print("cost.shape")(cost.shape)
    cost = cost.reshape((-1, cost.shape[2]))
    theano.printing.Print("cost.shape")(cost.shape)
    cost = cost.sum(axis=0).mean()
    """
    # optimize sum of probabilities rather than product?
Example #6
0
    theano.printing.Print("sliced.shape")(sliced.shape)

    border_mode = "half"
    blur_conv = conv2d(sliced, w_blurconv, border_mode=border_mode)
    theano.printing.Print("w_blurconv.shape")(w_blurconv.shape)
    theano.printing.Print("blur_conv.shape")(blur_conv.shape)
    final_conv = conv2d(blur_conv, w_finalconv, border_mode=border_mode)
    theano.printing.Print("w_finalconv.shape")(w_finalconv.shape)
    theano.printing.Print("final_conv.shape")(final_conv.shape)
    """

    outs_deconv = final_conv[:, :, :, :input_dim]
    outs_deconv = outs_deconv.dimshuffle(2, 0, 3, 1)
    outs_deconv = outs_deconv[:target.shape[0]]
    theano.printing.Print("outs_deconv.shape")(outs_deconv.shape)
    preds = softmax(outs_deconv + b_softmax)
    theano.printing.Print("preds.shape")(preds.shape)
    theano.printing.Print("target.shape")(target.shape)
    target = theano_one_hot(target, r=n_bins)
    theano.printing.Print("target.shape")(target.shape)
    cost = categorical_crossentropy(preds, target)
    theano.printing.Print("cost.shape")(cost.shape)
    theano.printing.Print("mask.shape")(mask.shape)
    cost = cost * mask.dimshuffle(0, 1, 'x')
    cost = cost.sum() / (target.shape[0] * target.shape[1])
    grads = tensor.grad(cost, params)

    init_x = as_shared(np_zeros((minibatch_size, n_out)))
    srng = RandomStreams(1999)
    """
    # Used to calculate stopping heuristic from sections 5.3
Example #7
0
    def sample_step(x_tm1, h1_tm1, h2_tm1, h3_tm1, k_tm1, w_tm1, ctx):
        """Build one symbolic sampling step of the two-head attention RNN.

        Columns 0 and 1 of the previous sample are class indices for two
        separate softmax heads (``n_softmax1`` and ``n_softmax2`` classes).
        They are one-hot encoded, concatenated and linearly projected
        (``inp_proj``/``inp_b``) to drive the three stacked recurrent cells
        with attention over ``ctx``. One index is then sampled from each
        head, and the two are stacked as the columns of ``x_t``.

        Parameters
        ----------
        x_tm1 : symbolic tensor
            Previous sample; only columns 0 and 1 are read.
        h1_tm1, h2_tm1, h3_tm1 : symbolic tensors
            Previous hidden states of ``cell1``, ``cell2`` and ``cell3``.
        k_tm1 : symbolic tensor
            Previous attention position. A strictly positive ``exp`` term is
            added each step, so it never decreases.
        w_tm1 : symbolic tensor
            Previous attention read vector.
        ctx : symbolic tensor
            Attended sequence, dimshuffled with ``(1, 0, 2)`` before
            weighting (presumably (length, batch, features) -- TODO confirm).

        Returns
        -------
        tuple
            ``(x_t, h1_t, h2_t, h3_t, k_t, w_t, ss_t, sh_t)``. ``x_t`` has two
            columns in the same layout as ``x_tm1``.

        Notes
        -----
        Layers, weights and helpers are free names defined outside this
        function.
        """
        # One-hot encode both columns of the previous sample, then project the
        # concatenation linearly.
        pt1 = theano_one_hot(x_tm1[:, 0], n_classes=n_softmax1)
        pt2 = theano_one_hot(x_tm1[:, 1], n_classes=n_softmax2)
        x_tm1_oh = tensor.concatenate((pt1, pt2), axis=-1)
        x_tm1_reduced = x_tm1_oh.dot(inp_proj) + inp_b
        xinp_h1_t, xgate_h1_t = inp_to_h1.proj(x_tm1_reduced)
        xinp_h2_t, xgate_h2_t = inp_to_h2.proj(x_tm1_reduced)
        xinp_h3_t, xgate_h3_t = inp_to_h3.proj(x_tm1_reduced)

        # The first layer also sees the previous attention read.
        attinp_h1, attgate_h1 = att_to_h1.proj(w_tm1)

        h1_t = cell1.step(xinp_h1_t + attinp_h1, xgate_h1_t + attgate_h1,
                          h1_tm1)
        h1inp_h2, h1gate_h2 = h1_to_h2.proj(h1_t)
        h1inp_h3, h1gate_h3 = h1_to_h3.proj(h1_t)

        # Attention parameters; exp keeps them positive, so k_t only moves
        # forward.
        a_t = tensor.exp(h1_t.dot(h1_to_att_a))
        b_t = tensor.exp(h1_t.dot(h1_to_att_b))
        k_t = k_tm1 + tensor.exp(h1_t.dot(h1_to_att_k))

        ss_t = calc_phi(k_t, a_t, b_t, u)
        # calculate and return stopping criteria
        sh_t = calc_phi(k_t, a_t, b_t, u_max)
        # Attention read: context weighted by ss_t and summed over axis 1.
        ss5 = ss_t.dimshuffle(0, 1, 'x')
        ss6 = ss5 * ctx.dimshuffle(1, 0, 2)
        w_t = ss6.sum(axis=1)

        attinp_h2, attgate_h2 = att_to_h2.proj(w_t)
        attinp_h3, attgate_h3 = att_to_h3.proj(w_t)

        h2_t = cell2.step(xinp_h2_t + h1inp_h2 + attinp_h2,
                          xgate_h2_t + h1gate_h2 + attgate_h2, h2_tm1)

        h2inp_h3, h2gate_h3 = h2_to_h3.proj(h2_t)

        h3_t = cell3.step(xinp_h3_t + h1inp_h3 + h2inp_h3 + attinp_h3,
                          xgate_h3_t + h1gate_h3 + h2gate_h3 + attgate_h3,
                          h3_tm1)
        out_t = h1_t.dot(h1_to_outs) + h2_t.dot(h2_to_outs) + h3_t.dot(
            h3_to_outs)
        pred1_t = softmax(out_t.dot(softmax1_proj) + softmax1_b)
        pred2_t = softmax(out_t.dot(softmax2_proj) + softmax2_b)

        # Draw one index per head (head 1 first), then stack them as the two
        # columns of x_t, matching how x_tm1 is read above.
        s1_t = sample_softmax(pred1_t, srng).dimshuffle(0, 'x')
        s2_t = sample_softmax(pred2_t, srng).dimshuffle(0, 'x')
        x_t = tensor.concatenate((s1_t, s2_t), axis=1)
        return x_t, h1_t, h2_t, h3_t, k_t, w_t, ss_t, sh_t
Example #8
0
    (h1, h2, h3, kappa, w), updates = theano.scan(
        fn=step,
        sequences=[inp_h1, inpgate_h1, inp_h2, inpgate_h2, inp_h3, inpgate_h3],
        outputs_info=[init_h1, init_h2, init_h3, init_kappa, init_w],
        non_sequences=[context])

    outs = h1.dot(h1_to_outs) + h2.dot(h2_to_outs) + h3.dot(h3_to_outs)
    theano.printing.Print("outs.shape")(outs.shape)
    outs_shape = outs.shape

    #l1 = relu(outs.dot(l1_proj) + l1_b)
    #l2 = relu(l1.dot(l2_proj) + l2_b)
    #theano.printing.Print("l1.shape")(l1.shape)
    #theano.printing.Print("l2.shape")(l2.shape)

    pred1 = softmax(outs.dot(softmax1_proj) + softmax1_b)
    pred2 = softmax(outs.dot(softmax2_proj) + softmax2_b)
    theano.printing.Print("pred1.shape")(pred1.shape)
    theano.printing.Print("pred2.shape")(pred2.shape)

    # Make one hot targets
    theano.printing.Print("target.shape")(target.shape)
    target1 = target[:, :, 0]
    target2 = target[:, :, 1]
    theano.printing.Print("target1.shape")(target1.shape)
    theano.printing.Print("target2.shape")(target2.shape)

    shp = target1.shape
    target1 = target1.ravel()
    target1 = theano_one_hot(target1, n_classes=n_softmax1)
    target1 = target1.reshape((shp[0], shp[1], n_softmax1))
Example #9
0
        joint = tensor.concatenate((inpt_oh, next_oh), axis=2)
        sliced_context = joint[:, :, i:i + n_frame]
        theano.printing.Print("sliced_context.shape")(sliced_context.shape)
        shp = sliced_context.shape
        sliced_context = sliced_context.reshape((shp[0], shp[1], -1))
        features = tensor.concatenate((partial_outs, sliced_context), axis=2)
        theano.printing.Print("partial_outs.shape")(partial_outs.shape)
        theano.printing.Print("joint.shape")(joint.shape)
        theano.printing.Print("sliced_context.shape")(sliced_context.shape)
        theano.printing.Print("features.shape")(features.shape)
        shp = features.shape
        mlp_inpt = features.reshape((-1, shp[-1]))
        mlp1 = relu(mlp_inpt.dot(mlp1_w) + mlp1_b)
        mlp2 = relu(mlp1.dot(mlp2_w) + mlp2_b)
        mlp3 = relu(mlp2.dot(mlp3_w) + mlp3_b)
        pred = softmax(mlp3.dot(pred_w) + pred_b)
        theano.printing.Print("pred.shape")(pred.shape)
        pred = pred.reshape((shp[0], shp[1], -1))
        theano.printing.Print("pred.shape")(pred.shape)
        pred_i.append(pred.dimshuffle(0, 1, 2, 'x'))
    pred = tensor.concatenate(pred_i, axis=-1).dimshuffle(0, 1, 3, 2)
    theano.printing.Print("pred.shape")(pred.shape)
    theano.printing.Print("target.shape")(target.shape)
    target = theano_one_hot(target, n_classes=n_bins)
    theano.printing.Print("target.shape")(target.shape)
    # dimshuffle so batch is on last axis
    cost = categorical_crossentropy(pred, target, eps=1E-9)
    theano.printing.Print("cost.shape")(cost.shape)
    theano.printing.Print("mask.shape")(mask.shape)

    cost = cost * mask.dimshuffle(0, 1, 'x')
    theano.printing.Print("init_pred.shape")(init_pred.shape)
    theano.printing.Print("init_hidden.shape")(init_hidden.shape)
    r, updates = theano.scan(fn=out_step,
                             sequences=[shuff_inpt, vinp],
                             outputs_info=[init_pred, init_hidden])
    (pre_pred, v_h1) = r
    theano.printing.Print("pre_pred.shape")(pre_pred.shape)
    pre_pred = pre_pred.dimshuffle(1, 0, 'x')
    shp = pre_pred.shape
    theano.printing.Print("pre_pred.shape")(pre_pred.shape)
    pre_pred = pre_pred.reshape(
        (minibatch_size, shp[0] // minibatch_size, shp[1], shp[2]))
    theano.printing.Print("pre_pred.shape")(pre_pred.shape)
    pre_pred = pre_pred.dimshuffle(1, 0, 2, 3)
    theano.printing.Print("pre_pred.shape")(pre_pred.shape)
    pred = softmax(pre_pred)
    theano.printing.Print("pred.shape")(pred.shape)
    theano.printing.Print("target.shape")(target.shape)
    target = theano_one_hot(target, n_classes=n_bins)
    theano.printing.Print("target.shape")(target.shape)
    raise ValueError()

    cost = categorical_crossentropy(pred, target)
    cost = cost * mask.dimshuffle(0, 1, 'x')
    # sum over sequence length and features, mean over minibatch
    cost = cost.dimshuffle(0, 2, 1)
    cost = cost.reshape((-1, cost.shape[2]))
    cost = cost.sum(axis=0).mean()

    l2_penalty = 0
    for p in list(set(params) - set(biases)):
        h2_t = GRU(h1_h2_t, h1gate_h2_t, h2_tm1,
                   n_hid, n_hid, random_state)

        h2_h3_t, h2gate_h3_t = GRUFork([h2_t], [n_hid], n_hid, random_state)

        h3_t = GRU(h2_h3_t,
                   h2gate_h3_t, h3_tm1,
                   n_hid, n_hid, random_state)
        return h1_t, h2_t, h3_t

    (h1, h2, h3), updates = theano.scan(
        fn=step,
        sequences=[in_h1, ingate_h1],
        outputs_info=[init_h1, init_h2, init_h3])
    out = Linear([h3], [n_hid], n_bins, random_state)
    pred = softmax(out)
    shp = target.shape
    target = target.reshape((shp[0], shp[1]))
    target = theano_one_hot(target, n_classes=n_bins)
    # dimshuffle so batch is on last axis
    cost = categorical_crossentropy(pred, target)
    cost = cost * mask.dimshuffle(0, 1)
    # sum over sequence length and features, mean over minibatch
    cost = cost.dimshuffle(1, 0)
    cost = cost.mean()
    # convert to bits vs nats
    cost = cost * tensor.cast(1.44269504089, theano.config.floatX)

    params = param_search(cost, lambda x: hasattr(x, "param"))
    print_param_info(params)
Example #12
0
        h1_t = GRU(in_h1_t, ingate_h1_t, h1_tm1, n_hid, n_hid, random_state)
        h1_h2_t, h1gate_h2_t = GRUFork([h1_t], [n_hid], n_hid, random_state)

        h2_t = GRU(h1_h2_t, h1gate_h2_t, h2_tm1, n_hid, n_hid, random_state)

        h2_h3_t, h2gate_h3_t = GRUFork([h2_t], [n_hid], n_hid, random_state)

        h3_t = GRU(h2_h3_t, h2gate_h3_t, h3_tm1, n_hid, n_hid, random_state)
        return h1_t, h2_t, h3_t

    (h1, h2,
     h3), updates = theano.scan(fn=step,
                                sequences=[in_h1, ingate_h1],
                                outputs_info=[init_h1, init_h2, init_h3])
    out = Linear([h3], [n_hid], n_bins, random_state)
    pred = softmax(out)
    shp = target.shape
    target = target.reshape((shp[0], shp[1]))
    target = theano_one_hot(target, n_classes=n_bins)
    # dimshuffle so batch is on last axis
    cost = categorical_crossentropy(pred, target)
    cost = cost * mask.dimshuffle(0, 1)
    # sum over sequence length and features, mean over minibatch
    cost = cost.dimshuffle(1, 0)
    cost = cost.mean()
    # convert to bits vs nats
    cost = cost * tensor.cast(1.44269504089, theano.config.floatX)

    params = param_search(cost, lambda x: hasattr(x, "param"))
    print_param_info(params)
    (h1, h2, h3, kappa, w), updates = theano.scan(
        fn=step,
        sequences=[inp_h1, inpgate_h1, inp_h2, inpgate_h2, inp_h3, inpgate_h3],
        outputs_info=[init_h1, init_h2, init_h3, init_kappa, init_w],
        non_sequences=[context])

    outs = h1.dot(h1_to_outs) + h2.dot(h2_to_outs) + h3.dot(h3_to_outs)
    l1 = relu(outs.dot(l1_proj) + b_l1_proj)
    l2 = relu(l1.dot(l2_proj) + b_l2_proj)
    l3 = relu(l2.dot(l3_proj) + b_l3_proj)
    shp = l3.shape

    l3 = l3.reshape((-1, shp[-1]))
    preds = l3.dot(softmax_proj) + b_softmax_proj
    preds = preds.reshape((shp[0], shp[1], n_features, n_softmax))
    preds = softmax(preds * (1. + softmax_bias_sym))
    theano.printing.Print("preds.shape")(preds.shape)
    theano.printing.Print("target.shape")(target.shape)
    target = theano_one_hot(target, n_softmax)
    theano.printing.Print("target.shape")(target.shape)
    cost = categorical_crossentropy(preds, target, eps=1E-9)
    theano.printing.Print("cost.shape")(cost.shape)

    cost = cost * mask.dimshuffle(0, 1, 'x')
    cost = cost.sum() / cut_len
    grads = tensor.grad(cost, params)
    grads = gradient_clipping(grads, 10.)

    learning_rate = 1E-4

    opt = adam(params, learning_rate)
             init_h1, init_h2, init_h3, init_kappa, init_w, context)

    r = step(inp_h1[1], inpgate_h1[1], inp_h2[1], inpgate_h2[1],
             inp_h3[1], inpgate_h3[1],
             r[0], r[1], r[2], r[3], r[4], context)
    """
    (h1, h2, h3, kappa, w), updates = theano.scan(
        fn=step,
        sequences=[inp_h1, inpgate_h1,
                   inp_h2, inpgate_h2,
                   inp_h3, inpgate_h3],
        outputs_info=[init_h1, init_h2, init_h3, init_kappa, init_w],
        non_sequences=[context])

    outs = h1.dot(h1_to_outs) + h2.dot(h2_to_outs) + h3.dot(h3_to_outs)
    outs = softmax(outs)
    cost = categorical_crossentropy(outs, target)

    cost = cost * mask
    cost = cost.sum() / cut_len
    grads = tensor.grad(cost, params)
    grads = gradient_clipping(grads, 10.)

    learning_rate = 1E-4

    opt = adam(params, learning_rate)
    updates = opt.updates(params, grads)

    train_function = theano.function([X_sym, X_mask_sym, c_sym, c_mask_sym,
                                      init_h1, init_h2, init_h3, init_kappa,
                                      init_w],