Example #1
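The snippets below are excerpted from a larger Theano/Lasagne training script, so module-level imports are not shown. A minimal set that the code appears to rely on (nn_layers and utils are project-local modules, and Example #2 additionally references QuerySliceLayer and GatedAttentionLayerWithQueryAttention, presumably from the same nn_layers module) would look roughly like:

import logging

import theano
import theano.tensor as T
import lasagne
from lasagne import nonlinearities

import nn_layers
import utils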
def build_fn(args, embeddings):
    """
        Build training and testing functions.
    """
    in_x1 = T.imatrix('x1')
    in_x2 = T.imatrix('x2')
    in_x3 = T.imatrix('x3')
    in_mask1 = T.matrix('mask1')
    in_mask2 = T.matrix('mask2')
    in_mask3 = T.matrix('mask3')
    in_y = T.ivector('y')

    # Extra input: per-word features, shaped batch x word_num x mea_num.
    in_x4 = T.ftensor3('x4')

    l_in1 = lasagne.layers.InputLayer((None, None), in_x1)
    l_mask1 = lasagne.layers.InputLayer((None, None), in_mask1)
    l_emb1 = lasagne.layers.EmbeddingLayer(l_in1,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=embeddings)

    l_in2 = lasagne.layers.InputLayer((None, None), in_x2)
    l_mask2 = lasagne.layers.InputLayer((None, None), in_mask2)
    l_emb2 = lasagne.layers.EmbeddingLayer(l_in2,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=l_emb1.W)

    l_in3 = lasagne.layers.InputLayer((None, None), in_x3)
    l_mask3 = lasagne.layers.InputLayer((None, None), in_mask3)
    l_emb3 = lasagne.layers.EmbeddingLayer(l_in3,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=l_emb1.W)
    # Input layer for the extra per-word feature tensor x4.
    l_in4 = lasagne.layers.InputLayer((None, None, args.mea_num), in_x4)

    if not args.tune_embedding:
        l_emb1.params[l_emb1.W].remove('trainable')
        l_emb2.params[l_emb2.W].remove('trainable')
        l_emb3.params[l_emb3.W].remove('trainable')

    args.rnn_output_size = args.hidden_size * 2 if args.bidir else args.hidden_size
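    # Only the default reader is supported in this variant (no GA branch).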
    assert args.model is None
    network1 = nn_layers.stack_rnn(l_emb1,
                                   l_mask1,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=(args.att_func == 'last'),
                                   bidir=args.bidir,
                                   name='d',
                                   rnn_layer=args.rnn_layer)

    network2 = nn_layers.stack_rnn(l_emb2,
                                   l_mask2,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=True,
                                   bidir=args.bidir,
                                   name='q',
                                   rnn_layer=args.rnn_layer)
    if args.att_func == 'mlp':
        att = nn_layers.MLPAttentionLayer([network1, network2],
                                          args.rnn_output_size,
                                          mask_input=l_mask1)
    elif args.att_func == 'bilinear':
        att = nn_layers.BilinearAttentionLayer([network1, network2],
                                               args.rnn_output_size,
                                               mask_input=l_mask1)
        att_weightLayer = nn_layers.BilinearAttentionWeightLayer(
            [network1, network2], args.rnn_output_size, mask_input=l_mask1)
    elif args.att_func == 'avg':
        att = nn_layers.AveragePoolingLayer(network1, mask_input=l_mask1)
    elif args.att_func == 'last':
        att = network1
    elif args.att_func == 'dot':
        att = nn_layers.DotProductAttentionLayer([network1, network2],
                                                 mask_input=l_mask1)
    else:
        raise NotImplementedError('att_func = %s' % args.att_func)

    # Weighted-mean passage embedding: per-word weights are predicted from the extra
    # features in x4. Note that this overwrites the attention output computed above.
    l_weight = lasagne.layers.DenseLayer(l_in4, 1, num_leading_axes=-1)
    att = nn_layers.WeightedAverageLayer([l_emb1, l_weight, l_mask1])

    network3 = nn_layers.stack_rnn(l_emb3,
                                   l_mask3,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=True,
                                   bidir=args.bidir,
                                   name='o',
                                   rnn_layer=args.rnn_layer)
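    # The option encoder sees batch * 4 option sequences; fold them back to
    # (batch, 4, rnn_output_size) so BilinearDotLayer can score the four options
    # of each question against the attended passage representation.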
    network3 = lasagne.layers.ReshapeLayer(
        network3, (in_x1.shape[0], 4, args.rnn_output_size))
    network = nn_layers.BilinearDotLayer([network3, att], args.rnn_output_size)
    if args.pre_trained is not None:
        dic = utils.load_params(args.pre_trained)
        lasagne.layers.set_all_param_values(network, dic['params'])
        del dic['params']
        logging.info('Loaded pre-trained model: %s' % args.pre_trained)
        for dic_param in dic.iteritems():
            logging.info(dic_param)

    logging.info('#params: %d' %
                 lasagne.layers.count_params(network, trainable=True))
    logging.info('#fixed params: %d' %
                 lasagne.layers.count_params(network, trainable=False))
    for layer in lasagne.layers.get_all_layers(network):
        logging.info(layer)

    # Test functions
    test_prob = lasagne.layers.get_output(network, deterministic=True)
    test_prediction = T.argmax(test_prob, axis=-1)
    acc = T.sum(T.eq(test_prediction, in_y))
    test_fn = theano.function([in_x1, in_mask1, in_x3, in_mask3, in_y, in_x4],
                              [acc, test_prediction],
                              on_unused_input='warn')

    # Train functions
    train_prediction = lasagne.layers.get_output(network)
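    # Clip probabilities away from 0 and 1 so the cross-entropy stays finite.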
    train_prediction = T.clip(train_prediction, 1e-7, 1.0 - 1e-7)
    loss = lasagne.objectives.categorical_crossentropy(train_prediction,
                                                       in_y).mean()

    # Attention weights (built for inspection; this variant does not return attention_fn)
    att_weight = lasagne.layers.get_output(att_weightLayer, deterministic=True)
    attention_fn = theano.function([in_x1, in_mask1, in_x2, in_mask2],
                                   att_weight,
                                   on_unused_input='warn')
    # TODO: lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
    #    params = lasagne.layers.get_all_params(network)#, trainable=True)
    params = lasagne.layers.get_all_params(network, trainable=True)
    all_params = lasagne.layers.get_all_params(network)
    if args.optimizer == 'sgd':
        updates = lasagne.updates.sgd(loss, params, args.learning_rate)
    elif args.optimizer == 'adam':
        updates = lasagne.updates.adam(loss,
                                       params,
                                       learning_rate=args.learning_rate)
    elif args.optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss,
                                          params,
                                          learning_rate=args.learning_rate)
    else:
        raise NotImplementedError('optimizer = %s' % args.optimizer)
    train_fn = theano.function([in_x1, in_mask1, in_x3, in_mask3, in_y, in_x4],
                               loss,
                               updates=updates,
                               on_unused_input='warn')

    return train_fn, test_fn, params, all_params
Example #2
def build_fn(args, embeddings):
    """
        Build training and testing functions.
    """
    in_x1 = T.imatrix('x1')
    in_x2 = T.imatrix('x2')
    in_x3 = T.imatrix('x3')
    in_mask1 = T.matrix('mask1')
    in_mask2 = T.matrix('mask2')
    in_mask3 = T.matrix('mask3')
    in_y = T.ivector('y')

    l_in1 = lasagne.layers.InputLayer((None, None), in_x1)
    l_mask1 = lasagne.layers.InputLayer((None, None), in_mask1)
    l_emb1 = lasagne.layers.EmbeddingLayer(l_in1,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=embeddings)

    l_in2 = lasagne.layers.InputLayer((None, None), in_x2)
    l_mask2 = lasagne.layers.InputLayer((None, None), in_mask2)
    l_emb2 = lasagne.layers.EmbeddingLayer(l_in2,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=l_emb1.W)

    l_in3 = lasagne.layers.InputLayer((None, None), in_x3)
    l_mask3 = lasagne.layers.InputLayer((None, None), in_mask3)
    l_emb3 = lasagne.layers.EmbeddingLayer(l_in3,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=l_emb1.W)

    if not args.tune_embedding:
        l_emb1.params[l_emb1.W].remove('trainable')
        l_emb2.params[l_emb2.W].remove('trainable')
        l_emb3.params[l_emb3.W].remove('trainable')

    args.rnn_output_size = args.hidden_size * 2 if args.bidir else args.hidden_size
    if args.model == "GA":
        l_d = l_emb1
        # NOTE: This implementation slightly differs from the original GA reader. Specifically:
        # 1. The query GRU is shared across hops.
        # 2. Dropout is applied to all hops (including the initial hop).
        # 3. Gated-attention is applied at the final layer as well.
        # 4. No character-level embeddings are used.

        l_q = nn_layers.stack_rnn(l_emb2,
                                  l_mask2,
                                  1,
                                  args.hidden_size,
                                  grad_clipping=args.grad_clipping,
                                  dropout_rate=args.dropout_rate,
                                  only_return_final=False,
                                  bidir=args.bidir,
                                  name='q',
                                  rnn_layer=args.rnn_layer)
        q_length = nn_layers.LengthLayer(l_mask2)
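        # Slice out each query's hidden state at its true final time step (given by LengthLayer).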
        network2 = QuerySliceLayer([l_q, q_length])
        for layer_num in xrange(args.num_GA_layers):
            l_d = nn_layers.stack_rnn(l_d,
                                      l_mask1,
                                      1,
                                      args.hidden_size,
                                      grad_clipping=args.grad_clipping,
                                      dropout_rate=args.dropout_rate,
                                      only_return_final=False,
                                      bidir=args.bidir,
                                      name='d' + str(layer_num),
                                      rnn_layer=args.rnn_layer)
            l_d = GatedAttentionLayerWithQueryAttention([l_d, l_q, l_mask2])
        network1 = l_d
    else:
        assert args.model is None
        network1 = nn_layers.stack_rnn(
            l_emb1,
            l_mask1,
            args.num_layers,
            args.hidden_size,
            grad_clipping=args.grad_clipping,
            dropout_rate=args.dropout_rate,
            only_return_final=(args.att_func == 'last'),
            bidir=args.bidir,
            name='d',
            rnn_layer=args.rnn_layer)

        network2 = nn_layers.stack_rnn(l_emb2,
                                       l_mask2,
                                       args.num_layers,
                                       args.hidden_size,
                                       grad_clipping=args.grad_clipping,
                                       dropout_rate=args.dropout_rate,
                                       only_return_final=True,
                                       bidir=args.bidir,
                                       name='q',
                                       rnn_layer=args.rnn_layer)
    if args.att_func == 'mlp':
        att = nn_layers.MLPAttentionLayer([network1, network2],
                                          args.rnn_output_size,
                                          mask_input=l_mask1)
    elif args.att_func == 'bilinear':
        att = nn_layers.BilinearAttentionLayer([network1, network2],
                                               args.rnn_output_size,
                                               mask_input=l_mask1)
    elif args.att_func == 'avg':
        att = nn_layers.AveragePoolingLayer(network1, mask_input=l_mask1)
    elif args.att_func == 'last':
        att = network1
    elif args.att_func == 'dot':
        att = nn_layers.DotProductAttentionLayer([network1, network2],
                                                 mask_input=l_mask1)
    else:
        raise NotImplementedError('att_func = %s' % args.att_func)
    network3 = nn_layers.stack_rnn(l_emb3,
                                   l_mask3,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=True,
                                   bidir=args.bidir,
                                   name='o',
                                   rnn_layer=args.rnn_layer)
    network3 = lasagne.layers.ReshapeLayer(
        network3, (in_x1.shape[0], 4, args.rnn_output_size))
    network = nn_layers.BilinearDotLayer([network3, att], args.rnn_output_size)
    if args.pre_trained is not None:
        dic = utils.load_params(args.pre_trained)
        lasagne.layers.set_all_param_values(network, dic['params'])
        del dic['params']
        logging.info('Loaded pre-trained model: %s' % args.pre_trained)
        for dic_param in dic.iteritems():
            logging.info(dic_param)

    logging.info('#params: %d' %
                 lasagne.layers.count_params(network, trainable=True))
    logging.info('#fixed params: %d' %
                 lasagne.layers.count_params(network, trainable=False))
    for layer in lasagne.layers.get_all_layers(network):
        logging.info(layer)

    # Test functions
    test_prob = lasagne.layers.get_output(network, deterministic=True)
    test_prediction = T.argmax(test_prob, axis=-1)
    acc = T.sum(T.eq(test_prediction, in_y))
    test_fn = theano.function(
        [in_x1, in_mask1, in_x2, in_mask2, in_x3, in_mask3, in_y],
        [acc, test_prediction],
        on_unused_input='warn')

    # Train functions
    train_prediction = lasagne.layers.get_output(network)
    train_prediction = T.clip(train_prediction, 1e-7, 1.0 - 1e-7)
    loss = lasagne.objectives.categorical_crossentropy(train_prediction,
                                                       in_y).mean()
    # TODO: lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
    # Note: unlike the other variants, the trainable=True filter is commented out here,
    # so the optimizer also updates parameters whose 'trainable' tag was removed.
    params = lasagne.layers.get_all_params(network)  #, trainable=True)
    all_params = lasagne.layers.get_all_params(network)
    if args.optimizer == 'sgd':
        updates = lasagne.updates.sgd(loss, params, args.learning_rate)
    elif args.optimizer == 'adam':
        updates = lasagne.updates.adam(loss,
                                       params,
                                       learning_rate=args.learning_rate)
    elif args.optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss,
                                          params,
                                          learning_rate=args.learning_rate)
    else:
        raise NotImplementedError('optimizer = %s' % args.optimizer)
    train_fn = theano.function(
        [in_x1, in_mask1, in_x2, in_mask2, in_x3, in_mask3, in_y],
        loss,
        updates=updates,
        on_unused_input='warn')

    return train_fn, test_fn, params, all_params
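A minimal sketch of how the Example #2 builder might be driven, assuming an argparse-style args namespace whose fields mirror the attributes read above. rnn_layer is assumed to be a Lasagne recurrent layer class (e.g. GRULayer), and all shapes and values are illustrative dummies; the real pipeline feeds tokenized passages, questions, and four candidate options per question (options arrive as batch * 4 rows, matching the ReshapeLayer above).

import numpy as np
import theano
import lasagne
from argparse import Namespace

# Hypothetical driver for the Example #2 build_fn; every value here is a placeholder.
args = Namespace(
    vocab_size=50000, embedding_size=100, hidden_size=128, num_layers=1,
    bidir=True, rnn_layer=lasagne.layers.GRULayer,  # assumed; could also be LSTMLayer
    att_func='bilinear', model=None, grad_clipping=10.0, dropout_rate=0.2,
    tune_embedding=True, optimizer='sgd', learning_rate=0.1, pre_trained=None)

embeddings = np.random.uniform(
    -0.1, 0.1, (args.vocab_size, args.embedding_size)).astype(theano.config.floatX)
train_fn, test_fn, params, all_params = build_fn(args, embeddings)

batch, d_len, q_len, o_len = 2, 30, 10, 5
x1 = np.random.randint(0, args.vocab_size, (batch, d_len)).astype('int32')       # passages
x2 = np.random.randint(0, args.vocab_size, (batch, q_len)).astype('int32')       # questions
x3 = np.random.randint(0, args.vocab_size, (batch * 4, o_len)).astype('int32')   # 4 options each
mask1 = np.ones((batch, d_len), dtype=theano.config.floatX)
mask2 = np.ones((batch, q_len), dtype=theano.config.floatX)
mask3 = np.ones((batch * 4, o_len), dtype=theano.config.floatX)
y = np.zeros(batch, dtype='int32')                                                # gold option index

loss = train_fn(x1, mask1, x2, mask2, x3, mask3, y)
acc, pred = test_fn(x1, mask1, x2, mask2, x3, mask3, y)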
Example #3
def build_fn(args, embeddings):
    """
        Build training and testing functions.
    """
    in_x1 = T.imatrix('x1')
    in_x2 = T.imatrix('x2')
    in_mask1 = T.matrix('mask1')
    in_mask2 = T.matrix('mask2')
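    # l is a {0,1} matrix over the label space; multiplying the network output by it
    # restricts predictions to the candidates that are valid for each example.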
    in_l = T.matrix('l')
    in_y = T.ivector('y')

    l_in1 = lasagne.layers.InputLayer((None, None), in_x1)
    l_mask1 = lasagne.layers.InputLayer((None, None), in_mask1)
    l_emb1 = lasagne.layers.EmbeddingLayer(l_in1,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=embeddings)

    l_in2 = lasagne.layers.InputLayer((None, None), in_x2)
    l_mask2 = lasagne.layers.InputLayer((None, None), in_mask2)
    l_emb2 = lasagne.layers.EmbeddingLayer(l_in2,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=l_emb1.W)

    network1 = nn_layers.stack_rnn(l_emb1,
                                   l_mask1,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=(args.att_func == 'last'),
                                   bidir=args.bidir,
                                   name='d',
                                   rnn_layer=args.rnn_layer)

    network2 = nn_layers.stack_rnn(l_emb2,
                                   l_mask2,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=True,
                                   bidir=args.bidir,
                                   name='q',
                                   rnn_layer=args.rnn_layer)

    args.rnn_output_size = args.hidden_size * 2 if args.bidir else args.hidden_size

    if args.att_func == 'mlp':
        att = nn_layers.MLPAttentionLayer([network1, network2],
                                          args.rnn_output_size,
                                          mask_input=l_mask1)
    elif args.att_func == 'bilinear':
        att = nn_layers.BilinearAttentionLayer([network1, network2],
                                               args.rnn_output_size,
                                               mask_input=l_mask1)
    elif args.att_func == 'avg':
        att = nn_layers.AveragePoolingLayer(network1, mask_input=l_mask1)
    elif args.att_func == 'last':
        att = network1
    elif args.att_func == 'dot':
        att = nn_layers.DotProductAttentionLayer([network1, network2],
                                                 mask_input=l_mask1)
    else:
        raise NotImplementedError('att_func = %s' % args.att_func)

    network = lasagne.layers.DenseLayer(
        att, args.num_labels, nonlinearity=lasagne.nonlinearities.softmax)

    if args.pre_trained is not None:
        dic = utils.load_params(args.pre_trained)
        lasagne.layers.set_all_param_values(network,
                                            dic['params'],
                                            trainable=True)
        del dic['params']
        logging.info('Loaded pre-trained model: %s' % args.pre_trained)
        for dic_param in dic.iteritems():
            logging.info(dic_param)

    logging.info('#params: %d' %
                 lasagne.layers.count_params(network, trainable=True))
    for layer in lasagne.layers.get_all_layers(network):
        logging.info(layer)

    # Test functions
    test_prob = lasagne.layers.get_output(network, deterministic=True) * in_l
    test_prediction = T.argmax(test_prob, axis=-1)
    acc = T.sum(T.eq(test_prediction, in_y))
    test_fn = theano.function([in_x1, in_mask1, in_x2, in_mask2, in_l, in_y],
                              acc)

    # Train functions
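    # Mask invalid labels, renormalize the remaining probabilities, and clip for numerical stability.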
    train_prediction = lasagne.layers.get_output(network) * in_l
    train_prediction = train_prediction / \
        train_prediction.sum(axis=1).reshape((train_prediction.shape[0], 1))
    train_prediction = T.clip(train_prediction, 1e-7, 1.0 - 1e-7)
    loss = lasagne.objectives.categorical_crossentropy(train_prediction,
                                                       in_y).mean()
    # TODO: lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
    params = lasagne.layers.get_all_params(network, trainable=True)

    if args.optimizer == 'sgd':
        updates = lasagne.updates.sgd(loss, params, args.learning_rate)
    elif args.optimizer == 'adam':
        updates = lasagne.updates.adam(loss, params)
    elif args.optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss, params)
    else:
        raise NotImplementedError('optimizer = %s' % args.optimizer)
    train_fn = theano.function([in_x1, in_mask1, in_x2, in_mask2, in_l, in_y],
                               loss,
                               updates=updates)

    return train_fn, test_fn, params
Example #4
def build_fn(args, embeddings):
    """
        Build training and testing functions.
    """
    in_x1 = T.imatrix('x1')
    in_x2 = T.imatrix('x2')
    in_mask1 = T.matrix('mask1')
    in_mask2 = T.matrix('mask2')
    in_l = T.matrix('l')
    in_y = T.ivector('y')

    l_in1 = lasagne.layers.InputLayer((None, None), in_x1)
    l_mask1 = lasagne.layers.InputLayer((None, None), in_mask1)
    l_emb1 = lasagne.layers.EmbeddingLayer(l_in1, args.vocab_size,
                                           args.embedding_size, W=embeddings)

    l_in2 = lasagne.layers.InputLayer((None, None), in_x2)
    l_mask2 = lasagne.layers.InputLayer((None, None), in_mask2)
    l_emb2 = lasagne.layers.EmbeddingLayer(l_in2, args.vocab_size,
                                           args.embedding_size, W=l_emb1.W)

    network1 = nn_layers.stack_rnn(l_emb1, l_mask1, args.num_layers, args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=(args.att_func == 'last'),
                                   bidir=args.bidir,
                                   name='d',
                                   rnn_layer=args.rnn_layer)

    network2 = nn_layers.stack_rnn(l_emb2, l_mask2, args.num_layers, args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=True,
                                   bidir=args.bidir,
                                   name='q',
                                   rnn_layer=args.rnn_layer)

    args.rnn_output_size = args.hidden_size * 2 if args.bidir else args.hidden_size

    if args.att_func == 'mlp':
        att = nn_layers.MLPAttentionLayer([network1, network2], args.rnn_output_size,
                                          mask_input=l_mask1)
    elif args.att_func == 'bilinear':
        att = nn_layers.BilinearAttentionLayer([network1, network2], args.rnn_output_size,
                                               mask_input=l_mask1)
    elif args.att_func == 'avg':
        att = nn_layers.AveragePoolingLayer(network1, mask_input=l_mask1)
    elif args.att_func == 'last':
        att = network1
    elif args.att_func == 'dot':
        att = nn_layers.DotProductAttentionLayer([network1, network2], mask_input=l_mask1)
    else:
        raise NotImplementedError('att_func = %s' % args.att_func)

    network = lasagne.layers.DenseLayer(att, args.num_labels,
                                        nonlinearity=lasagne.nonlinearities.softmax)

    if args.pre_trained is not None:
        dic = utils.load_params(args.pre_trained)
        lasagne.layers.set_all_param_values(network, dic['params'], trainable=True)
        del dic['params']
        logging.info('Loaded pre-trained model: %s' % args.pre_trained)
        for dic_param in dic.iteritems():
            logging.info(dic_param)

    logging.info('#params: %d' % lasagne.layers.count_params(network, trainable=True))
    for layer in lasagne.layers.get_all_layers(network):
        logging.info(layer)

    # Test functions
    test_prob = lasagne.layers.get_output(network, deterministic=True) * in_l
    test_prediction = T.argmax(test_prob, axis=-1)
    acc = T.sum(T.eq(test_prediction, in_y))
    test_fn = theano.function([in_x1, in_mask1, in_x2, in_mask2, in_l, in_y], acc)

    # Train functions
    train_prediction = lasagne.layers.get_output(network) * in_l
    train_prediction = train_prediction / \
        train_prediction.sum(axis=1).reshape((train_prediction.shape[0], 1))
    train_prediction = T.clip(train_prediction, 1e-7, 1.0 - 1e-7)
    loss = lasagne.objectives.categorical_crossentropy(train_prediction, in_y).mean()
    # TODO: lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
    params = lasagne.layers.get_all_params(network, trainable=True)

    if args.optimizer == 'sgd':
        updates = lasagne.updates.sgd(loss, params, args.learning_rate)
    elif args.optimizer == 'adam':
        updates = lasagne.updates.adam(loss, params)
    elif args.optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss, params)
    else:
        raise NotImplementedError('optimizer = %s' % args.optimizer)
    train_fn = theano.function([in_x1, in_mask1, in_x2, in_mask2, in_l, in_y],
                               loss, updates=updates)

    return train_fn, test_fn, params
Example #5
def build_fn(args, embeddings):
    """
        Build training and testing functions.
    """
    if args.para_shared_model is not None:
        dic = utils.load_params(args.para_shared_model)
        params_shared = dic['params']
        params_name = [
            'W', 'o_layer1.W_in_to_updategate', 'o_layer1.W_hid_to_updategate',
            'o_layer1.b_updategate', 'o_layer1.W_in_to_resetgate',
            'o_layer1.W_hid_to_resetgate', 'o_layer1.b_resetgate',
            'o_layer1.W_in_to_hidden_update',
            'o_layer1.W_hid_to_hidden_update', 'o_layer1.b_hidden_update',
            'o_layer1.hid_init', 'o_back_layer1.W_in_to_updategate',
            'o_back_layer1.W_hid_to_updategate', 'o_back_layer1.b_updategate',
            'o_back_layer1.W_in_to_resetgate',
            'o_back_layer1.W_hid_to_resetgate', 'o_back_layer1.b_resetgate',
            'o_back_layer1.W_in_to_hidden_update',
            'o_back_layer1.W_hid_to_hidden_update',
            'o_back_layer1.b_hidden_update', 'o_back_layer1.hid_init',
            'd_layer1.W_in_to_updategate', 'd_layer1.W_hid_to_updategate',
            'd_layer1.b_updategate', 'd_layer1.W_in_to_resetgate',
            'd_layer1.W_hid_to_resetgate', 'd_layer1.b_resetgate',
            'd_layer1.W_in_to_hidden_update',
            'd_layer1.W_hid_to_hidden_update', 'd_layer1.b_hidden_update',
            'd_layer1.hid_init', 'd_back_layer1.W_in_to_updategate',
            'd_back_layer1.W_hid_to_updategate', 'd_back_layer1.b_updategate',
            'd_back_layer1.W_in_to_resetgate',
            'd_back_layer1.W_hid_to_resetgate', 'd_back_layer1.b_resetgate',
            'd_back_layer1.W_in_to_hidden_update',
            'd_back_layer1.W_hid_to_hidden_update',
            'd_back_layer1.b_hidden_update', 'd_back_layer1.hid_init',
            'q_layer1.W_in_to_updategate', 'q_layer1.W_hid_to_updategate',
            'q_layer1.b_updategate', 'q_layer1.W_in_to_resetgate',
            'q_layer1.W_hid_to_resetgate', 'q_layer1.b_resetgate',
            'q_layer1.W_in_to_hidden_update',
            'q_layer1.W_hid_to_hidden_update', 'q_layer1.b_hidden_update',
            'q_layer1.hid_init', 'q_back_layer1.W_in_to_updategate',
            'q_back_layer1.W_hid_to_updategate', 'q_back_layer1.b_updategate',
            'q_back_layer1.W_in_to_resetgate',
            'q_back_layer1.W_hid_to_resetgate', 'q_back_layer1.b_resetgate',
            'q_back_layer1.W_in_to_hidden_update',
            'q_back_layer1.W_hid_to_hidden_update',
            'q_back_layer1.b_hidden_update', 'q_back_layer1.hid_init',
            'W_bilinear', 'W_bilinear'
        ]
    in_x1 = T.imatrix('x1')
    in_x3 = T.imatrix('x3')
    in_mask1 = T.matrix('mask1')
    in_mask3 = T.matrix('mask3')
    in_y = T.ivector('y')

    #batch x word_num x mea_num
    in_x4 = T.ftensor3('x4')

    l_in1 = lasagne.layers.InputLayer((None, None), in_x1)
    l_mask1 = lasagne.layers.InputLayer((None, None), in_mask1)
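    # Look up the shared embedding matrix by its name in the loaded parameter list.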
    Embed_W = params_shared[params_name.index('W')]
    l_emb1 = lasagne.layers.EmbeddingLayer(l_in1,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=Embed_W)

    l_in3 = lasagne.layers.InputLayer((None, None), in_x3)
    l_mask3 = lasagne.layers.InputLayer((None, None), in_mask3)
    l_emb3 = lasagne.layers.EmbeddingLayer(l_in3,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=l_emb1.W)
    # x4 is the human attention
    l_in4 = lasagne.layers.InputLayer((None, None, args.mea_num), in_x4)

    if not args.tune_embedding:
        l_emb1.params[l_emb1.W].remove('trainable')
        l_emb3.params[l_emb3.W].remove('trainable')

    args.rnn_output_size = args.hidden_size * 2 if args.bidir else args.hidden_size
    assert args.model is None
    network1 = nn_layers.stack_rnn(l_emb1,
                                   l_mask1,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=(args.att_func == 'last'),
                                   bidir=args.bidir,
                                   name='d',
                                   rnn_layer=args.rnn_layer)
    # Weighted-mean passage embedding.
    # (An alternative initialization with fixed identity weights, kept for reference:)
    #    weight_mlp_np = np.array([[1.]])
    #    b_mlp = np.array([0.])
    #    l_weight = lasagne.layers.DenseLayer(l_in4, 1, num_leading_axes=-1,
    #                                         name='w_dense', W=weight_mlp_np, b=b_mlp)
    # Project the per-word features through a dense layer with sigmoid activation to get
    # human-attention weights l_weight of shape batch x word_num x 1.
    l_weight = lasagne.layers.DenseLayer(l_in4,
                                         1,
                                         num_leading_axes=-1,
                                         nonlinearity=nonlinearities.sigmoid,
                                         name='w_dense')

    att = nn_layers.WeightedAverageLayer([network1, l_weight, l_mask1],
                                         name='w_aver')
    # RAW and SAG are module-level flags assumed to be defined elsewhere in the original file.
    if RAW:
        # Use the raw per-word features in x4 directly as attention weights.
        att = nn_layers.WeightedAverageLayer([network1, l_in4, l_mask1],
                                             name='w_aver')
    if SAG:
        # SAG variant (1x1 convolutions over network1 and l_in4) is not implemented here.
        pass

    # Encode the candidate options.
    network3 = nn_layers.stack_rnn(l_emb3,
                                   l_mask3,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=True,
                                   bidir=args.bidir,
                                   name='o',
                                   rnn_layer=args.rnn_layer)
    network3 = lasagne.layers.ReshapeLayer(
        network3, (in_x1.shape[0], 4, args.rnn_output_size))
    # Score each option against the attended passage representation.
    network = nn_layers.BilinearDotLayer([network3, att], args.rnn_output_size)
    # if not args.tune_embedding:
    #     network.params[network.W].remove('trainable')
    # Parameter sharing: copy the shared model's parameters into this network by name.
    # The freshly initialized w_dense layer keeps its own values, and W_bilinear takes
    # the last entry of the shared parameter list.
    params_initial = lasagne.layers.get_all_params(network)
    params_set = []
    for params_initial_tmp in params_initial:
        if str(params_initial_tmp) in ['w_dense.W', 'w_dense.b']:
            params_set = params_set + [params_initial_tmp.get_value()]
        elif str(params_initial_tmp) == 'W_bilinear':
            params_set = params_set + [params_shared[-1]]
        else:
            params_set = params_set + [
                params_shared[params_name.index(str(params_initial_tmp))]
            ]
    lasagne.layers.set_all_param_values(network, params_set)

    if args.pre_trained is not None:
        dic = utils.load_params(args.pre_trained)
        lasagne.layers.set_all_param_values(network, dic['params'])
        del dic['params']
        logging.info('Loaded pre-trained model: %s' % args.pre_trained)
        for dic_param in dic.iteritems():
            logging.info(dic_param)

    logging.info('#params: %d' %
                 lasagne.layers.count_params(network, trainable=True))
    logging.info('#fixed params: %d' %
                 lasagne.layers.count_params(network, trainable=False))
    for layer in lasagne.layers.get_all_layers(network):
        logging.info(layer)

    # Test functions
    test_prob = lasagne.layers.get_output(network, deterministic=True)
    test_prediction = T.argmax(test_prob, axis=-1)
    acc = T.sum(T.eq(test_prediction, in_y))
    test_fn = theano.function([in_x1, in_mask1, in_x3, in_mask3, in_y, in_x4],
                              [acc, test_prediction],
                              on_unused_input='warn')

    # Train functions
    train_prediction = lasagne.layers.get_output(network)
    train_prediction = T.clip(train_prediction, 1e-7, 1.0 - 1e-7)
    loss = lasagne.objectives.categorical_crossentropy(train_prediction,
                                                       in_y).mean()

    # TODO: lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
    #    params = lasagne.layers.get_all_params(network)#, trainable=True)
    params_init = lasagne.layers.get_all_params(network, trainable=True)
    params = lasagne.layers.get_all_params(network, trainable=True)
    if not args.tune_sar:
        # Train only the w_dense attention-weight layer: remove every other
        # parameter from the update list.
        for params_tmp in params_init:
            if str(params_tmp) not in ['w_dense.W', 'w_dense.b']:
                print(params_tmp)
                params.remove(params_tmp)
                print(len(params))
                print(params)
            else:
                print(params_tmp)
    all_params = lasagne.layers.get_all_params(network)
    if args.optimizer == 'sgd':
        updates = lasagne.updates.sgd(loss, params, args.learning_rate)
    elif args.optimizer == 'adam':
        updates = lasagne.updates.adam(loss,
                                       params,
                                       learning_rate=args.learning_rate)
    elif args.optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss,
                                          params,
                                          learning_rate=args.learning_rate)
    else:
        raise NotImplementedError('optimizer = %s' % args.optimizer)
    train_fn = theano.function([in_x1, in_mask1, in_x3, in_mask3, in_y, in_x4],
                               loss,
                               updates=updates,
                               on_unused_input='warn')

    return train_fn, test_fn, params, all_params