示例#1
0
def build_fn(args, embeddings):
    """
    Build training and testing functions.

    Builds a multiple-choice reader over a passage (``x1``), a question
    (``x2``), answer options (``x3``) and an extra per-word feature tensor
    (``x4``, batch x word_num x mea_num). The passage representation that
    reaches the final scorer is a weighted average of the passage word
    embeddings, with per-word weights produced by a dense layer over ``x4``.

    Args:
        args: configuration namespace. Reads vocab_size, embedding_size,
            mea_num, tune_embedding, hidden_size, bidir, model, num_layers,
            grad_clipping, dropout_rate, att_func, rnn_layer, pre_trained,
            optimizer and learning_rate, and sets ``args.rnn_output_size``.
        embeddings: value/initializer for the word-embedding matrix that is
            shared by all three token inputs.

    Returns:
        (train_fn, test_fn, params, all_params):
        train_fn(x1, mask1, x3, mask3, y, x4) -> loss, applying updates;
        test_fn(x1, mask1, x3, mask3, y, x4) -> [num_correct, predictions];
        params are the trainable parameters and all_params every parameter.

    Raises:
        NotImplementedError: for an unknown att_func or optimizer.
    """
    in_x1 = T.imatrix('x1')
    in_x2 = T.imatrix('x2')
    in_x3 = T.imatrix('x3')
    in_mask1 = T.matrix('mask1')
    in_mask2 = T.matrix('mask2')
    in_mask3 = T.matrix('mask3')
    in_y = T.ivector('y')

    # Extra input: batch x word_num x mea_num.
    in_x4 = T.ftensor3('x4')

    l_in1 = lasagne.layers.InputLayer((None, None), in_x1)
    l_mask1 = lasagne.layers.InputLayer((None, None), in_mask1)
    l_emb1 = lasagne.layers.EmbeddingLayer(l_in1,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=embeddings)

    # The question and option embeddings reuse l_emb1's weight matrix.
    l_in2 = lasagne.layers.InputLayer((None, None), in_x2)
    l_mask2 = lasagne.layers.InputLayer((None, None), in_mask2)
    l_emb2 = lasagne.layers.EmbeddingLayer(l_in2,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=l_emb1.W)

    l_in3 = lasagne.layers.InputLayer((None, None), in_x3)
    l_mask3 = lasagne.layers.InputLayer((None, None), in_mask3)
    l_emb3 = lasagne.layers.EmbeddingLayer(l_in3,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=l_emb1.W)
    l_in4 = lasagne.layers.InputLayer((None, None, args.mea_num), in_x4)

    if not args.tune_embedding:
        # Each layer keeps its own tag set for the shared W, so the
        # 'trainable' tag has to be removed from all three.
        l_emb1.params[l_emb1.W].remove('trainable')
        l_emb2.params[l_emb2.W].remove('trainable')
        l_emb3.params[l_emb3.W].remove('trainable')

    args.rnn_output_size = args.hidden_size * 2 if args.bidir else args.hidden_size
    assert args.model is None
    network1 = nn_layers.stack_rnn(l_emb1,
                                   l_mask1,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=(args.att_func == 'last'),
                                   bidir=args.bidir,
                                   name='d',
                                   rnn_layer=args.rnn_layer)

    network2 = nn_layers.stack_rnn(l_emb2,
                                   l_mask2,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=True,
                                   bidir=args.bidir,
                                   name='q',
                                   rnn_layer=args.rnn_layer)

    # Only the bilinear attention builds a separate layer exposing its
    # attention weights; it stays None for every other att_func.
    att_weightLayer = None
    if args.att_func == 'mlp':
        att = nn_layers.MLPAttentionLayer([network1, network2],
                                          args.rnn_output_size,
                                          mask_input=l_mask1)
    elif args.att_func == 'bilinear':
        att = nn_layers.BilinearAttentionLayer([network1, network2],
                                               args.rnn_output_size,
                                               mask_input=l_mask1)
        att_weightLayer = nn_layers.BilinearAttentionWeightLayer(
            [network1, network2], args.rnn_output_size, mask_input=l_mask1)
    elif args.att_func == 'avg':
        att = nn_layers.AveragePoolingLayer(network1, mask_input=l_mask1)
    elif args.att_func == 'last':
        att = network1
    elif args.att_func == 'dot':
        att = nn_layers.DotProductAttentionLayer([network1, network2],
                                                 mask_input=l_mask1)
    else:
        raise NotImplementedError('att_func = %s' % args.att_func)

    # Weighted mean of the passage word embeddings (passage embedding), with
    # per-word weights produced by a dense layer over the x4 features.
    # NOTE(review): this replaces the `att` chosen above, so network1 and
    # network2 do not reach the final network; the branch above only
    # validates att_func and builds att_weightLayer.
    l_weight = lasagne.layers.DenseLayer(l_in4, 1, num_leading_axes=-1)
    att = nn_layers.WeightedAverageLayer([l_emb1, l_weight, l_mask1])

    network3 = nn_layers.stack_rnn(l_emb3,
                                   l_mask3,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=True,
                                   bidir=args.bidir,
                                   name='o',
                                   rnn_layer=args.rnn_layer)
    # Regroup the option encodings as (batch, 4, rnn_output_size).
    network3 = lasagne.layers.ReshapeLayer(
        network3, (in_x1.shape[0], 4, args.rnn_output_size))
    network = nn_layers.BilinearDotLayer([network3, att], args.rnn_output_size)
    if args.pre_trained is not None:
        dic = utils.load_params(args.pre_trained)
        lasagne.layers.set_all_param_values(network, dic['params'])
        del dic['params']
        logging.info('Loaded pre-trained model: %s', args.pre_trained)
        # Log whatever else was stored alongside the parameters.
        for dic_param in dic.items():
            logging.info(dic_param)

    logging.info('#params: %d',
                 lasagne.layers.count_params(network, trainable=True))
    logging.info('#fixed params: %d',
                 lasagne.layers.count_params(network, trainable=False))
    for layer in lasagne.layers.get_all_layers(network):
        logging.info(layer)

    # Test functions
    test_prob = lasagne.layers.get_output(network, deterministic=True)
    test_prediction = T.argmax(test_prob, axis=-1)
    acc = T.sum(T.eq(test_prediction, in_y))
    test_fn = theano.function([in_x1, in_mask1, in_x3, in_mask3, in_y, in_x4],
                              [acc, test_prediction],
                              on_unused_input='warn')

    # Train functions
    train_prediction = lasagne.layers.get_output(network)
    # Keep probabilities strictly inside (0, 1) before the log in the loss.
    train_prediction = T.clip(train_prediction, 1e-7, 1.0 - 1e-7)
    loss = lasagne.objectives.categorical_crossentropy(train_prediction,
                                                       in_y).mean()

    # Attention functions: only buildable when the bilinear attention layer
    # exists; referencing att_weightLayer for any other att_func would fail.
    # NOTE(review): attention_fn is compiled but not returned to the caller.
    if att_weightLayer is not None:
        att_weight = lasagne.layers.get_output(att_weightLayer,
                                               deterministic=True)
        attention_fn = theano.function([in_x1, in_mask1, in_x2, in_mask2],
                                       att_weight,
                                       on_unused_input='warn')

    # TODO: lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
    params = lasagne.layers.get_all_params(network, trainable=True)
    all_params = lasagne.layers.get_all_params(network)
    if args.optimizer == 'sgd':
        updates = lasagne.updates.sgd(loss, params, args.learning_rate)
    elif args.optimizer == 'adam':
        updates = lasagne.updates.adam(loss,
                                       params,
                                       learning_rate=args.learning_rate)
    elif args.optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss,
                                          params,
                                          learning_rate=args.learning_rate)
    else:
        raise NotImplementedError('optimizer = %s' % args.optimizer)
    train_fn = theano.function([in_x1, in_mask1, in_x3, in_mask3, in_y, in_x4],
                               loss,
                               updates=updates,
                               on_unused_input='warn')

    return train_fn, test_fn, params, all_params
示例#2
0
def build_fn(args, embeddings):
    """
    Build training and testing functions.

    Builds a reader that encodes a passage (``x1``) and a question (``x2``)
    with stacked RNNs, combines them with the attention selected by
    ``args.att_func`` and classifies into ``args.num_labels`` classes with a
    softmax dense layer. The output distribution is multiplied by ``l``
    before prediction/loss; presumably ``l`` is a 0/1 mask of candidate
    labels per example (batch x num_labels) -- TODO confirm with the caller.

    Args:
        args: configuration namespace. Reads vocab_size, embedding_size,
            num_layers, hidden_size, grad_clipping, dropout_rate, att_func,
            bidir, rnn_layer, num_labels, pre_trained, optimizer and
            learning_rate, and sets ``args.rnn_output_size``.
        embeddings: value/initializer for the word-embedding matrix shared
            by the passage and question inputs.

    Returns:
        (train_fn, test_fn, params):
        train_fn(x1, mask1, x2, mask2, l, y) -> loss, applying updates;
        test_fn(x1, mask1, x2, mask2, l, y) -> number of correct predictions;
        params are the trainable parameters.

    Raises:
        NotImplementedError: for an unknown att_func or optimizer.
    """
    in_x1 = T.imatrix('x1')
    in_x2 = T.imatrix('x2')
    in_mask1 = T.matrix('mask1')
    in_mask2 = T.matrix('mask2')
    in_l = T.matrix('l')
    in_y = T.ivector('y')

    l_in1 = lasagne.layers.InputLayer((None, None), in_x1)
    l_mask1 = lasagne.layers.InputLayer((None, None), in_mask1)
    l_emb1 = lasagne.layers.EmbeddingLayer(l_in1,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=embeddings)

    # The question embedding reuses l_emb1's weight matrix.
    l_in2 = lasagne.layers.InputLayer((None, None), in_x2)
    l_mask2 = lasagne.layers.InputLayer((None, None), in_mask2)
    l_emb2 = lasagne.layers.EmbeddingLayer(l_in2,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=l_emb1.W)

    network1 = nn_layers.stack_rnn(l_emb1,
                                   l_mask1,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=(args.att_func == 'last'),
                                   bidir=args.bidir,
                                   name='d',
                                   rnn_layer=args.rnn_layer)

    network2 = nn_layers.stack_rnn(l_emb2,
                                   l_mask2,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=True,
                                   bidir=args.bidir,
                                   name='q',
                                   rnn_layer=args.rnn_layer)

    args.rnn_output_size = args.hidden_size * 2 if args.bidir else args.hidden_size

    if args.att_func == 'mlp':
        att = nn_layers.MLPAttentionLayer([network1, network2],
                                          args.rnn_output_size,
                                          mask_input=l_mask1)
    elif args.att_func == 'bilinear':
        att = nn_layers.BilinearAttentionLayer([network1, network2],
                                               args.rnn_output_size,
                                               mask_input=l_mask1)
    elif args.att_func == 'avg':
        att = nn_layers.AveragePoolingLayer(network1, mask_input=l_mask1)
    elif args.att_func == 'last':
        att = network1
    elif args.att_func == 'dot':
        att = nn_layers.DotProductAttentionLayer([network1, network2],
                                                 mask_input=l_mask1)
    else:
        raise NotImplementedError('att_func = %s' % args.att_func)

    network = lasagne.layers.DenseLayer(
        att, args.num_labels, nonlinearity=lasagne.nonlinearities.softmax)

    if args.pre_trained is not None:
        dic = utils.load_params(args.pre_trained)
        lasagne.layers.set_all_param_values(network,
                                            dic['params'],
                                            trainable=True)
        del dic['params']
        logging.info('Loaded pre-trained model: %s', args.pre_trained)
        # Log whatever else was stored alongside the parameters.
        for dic_param in dic.items():
            logging.info(dic_param)

    logging.info('#params: %d',
                 lasagne.layers.count_params(network, trainable=True))
    for layer in lasagne.layers.get_all_layers(network):
        logging.info(layer)

    # Test functions: zero out the labels excluded by in_l before argmax.
    test_prob = lasagne.layers.get_output(network, deterministic=True) * in_l
    test_prediction = T.argmax(test_prob, axis=-1)
    acc = T.sum(T.eq(test_prediction, in_y))
    test_fn = theano.function([in_x1, in_mask1, in_x2, in_mask2, in_l, in_y],
                              acc)

    # Train functions: mask with in_l, then renormalise each row to sum to 1.
    train_prediction = lasagne.layers.get_output(network) * in_l
    train_prediction = train_prediction / \
        train_prediction.sum(axis=1, keepdims=True)
    # Keep probabilities strictly inside (0, 1) before the log in the loss.
    train_prediction = T.clip(train_prediction, 1e-7, 1.0 - 1e-7)
    loss = lasagne.objectives.categorical_crossentropy(train_prediction,
                                                       in_y).mean()
    # TODO: lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
    params = lasagne.layers.get_all_params(network, trainable=True)

    # Every optimizer honours args.learning_rate; otherwise adam/rmsprop
    # would silently fall back to the library defaults.
    if args.optimizer == 'sgd':
        updates = lasagne.updates.sgd(loss, params, args.learning_rate)
    elif args.optimizer == 'adam':
        updates = lasagne.updates.adam(loss,
                                       params,
                                       learning_rate=args.learning_rate)
    elif args.optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss,
                                          params,
                                          learning_rate=args.learning_rate)
    else:
        raise NotImplementedError('optimizer = %s' % args.optimizer)
    train_fn = theano.function([in_x1, in_mask1, in_x2, in_mask2, in_l, in_y],
                               loss,
                               updates=updates)

    return train_fn, test_fn, params
示例#3
0
def build_fn(args, embeddings):
    """
    Build training and testing functions.

    Builds a multiple-choice reader over a passage (``x1``), a question
    (``x2``) and answer options (``x3``). The passage/question encoder is
    either a plain stacked RNN (``args.model is None``) or a gated-attention
    stack (``args.model == "GA"``). The options are encoded by a separate
    RNN and scored against the attended passage with ``BilinearDotLayer``.

    Args:
        args: configuration namespace. Reads vocab_size, embedding_size,
            tune_embedding, hidden_size, bidir, model, num_GA_layers (GA
            only), num_layers, grad_clipping, dropout_rate, att_func,
            rnn_layer, pre_trained, optimizer and learning_rate, and sets
            ``args.rnn_output_size``.
        embeddings: value/initializer for the word-embedding matrix shared
            by all three token inputs.

    Returns:
        (train_fn, test_fn, params, all_params):
        train_fn(x1, mask1, x2, mask2, x3, mask3, y) -> loss, applying updates;
        test_fn(x1, mask1, x2, mask2, x3, mask3, y) -> [num_correct, predictions];
        params are the trainable parameters (the ones the optimizer updates)
        and all_params every parameter of the network.

    Raises:
        NotImplementedError: for an unknown att_func or optimizer.
    """
    in_x1 = T.imatrix('x1')
    in_x2 = T.imatrix('x2')
    in_x3 = T.imatrix('x3')
    in_mask1 = T.matrix('mask1')
    in_mask2 = T.matrix('mask2')
    in_mask3 = T.matrix('mask3')
    in_y = T.ivector('y')

    l_in1 = lasagne.layers.InputLayer((None, None), in_x1)
    l_mask1 = lasagne.layers.InputLayer((None, None), in_mask1)
    l_emb1 = lasagne.layers.EmbeddingLayer(l_in1,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=embeddings)

    # The question and option embeddings reuse l_emb1's weight matrix.
    l_in2 = lasagne.layers.InputLayer((None, None), in_x2)
    l_mask2 = lasagne.layers.InputLayer((None, None), in_mask2)
    l_emb2 = lasagne.layers.EmbeddingLayer(l_in2,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=l_emb1.W)

    l_in3 = lasagne.layers.InputLayer((None, None), in_x3)
    l_mask3 = lasagne.layers.InputLayer((None, None), in_mask3)
    l_emb3 = lasagne.layers.EmbeddingLayer(l_in3,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=l_emb1.W)

    if not args.tune_embedding:
        # Each layer keeps its own tag set for the shared W, so the
        # 'trainable' tag has to be removed from all three.
        l_emb1.params[l_emb1.W].remove('trainable')
        l_emb2.params[l_emb2.W].remove('trainable')
        l_emb3.params[l_emb3.W].remove('trainable')

    args.rnn_output_size = args.hidden_size * 2 if args.bidir else args.hidden_size
    if args.model == "GA":
        l_d = l_emb1
        # NOTE: This implementation slightly differs from the original GA reader. Specifically:
        # 1. The query GRU is shared across hops.
        # 2. Dropout is applied to all hops (including the initial hop).
        # 3. Gated-attention is applied at the final layer as well.
        # 4. No character-level embeddings are used.

        l_q = nn_layers.stack_rnn(l_emb2,
                                  l_mask2,
                                  1,
                                  args.hidden_size,
                                  grad_clipping=args.grad_clipping,
                                  dropout_rate=args.dropout_rate,
                                  only_return_final=False,
                                  bidir=args.bidir,
                                  name='q',
                                  rnn_layer=args.rnn_layer)
        q_length = nn_layers.LengthLayer(l_mask2)
        network2 = QuerySliceLayer([l_q, q_length])
        # One single-layer document RNN per hop, each followed by gated
        # attention against the (shared) query sequence.
        for layer_num in range(args.num_GA_layers):
            l_d = nn_layers.stack_rnn(l_d,
                                      l_mask1,
                                      1,
                                      args.hidden_size,
                                      grad_clipping=args.grad_clipping,
                                      dropout_rate=args.dropout_rate,
                                      only_return_final=False,
                                      bidir=args.bidir,
                                      name='d' + str(layer_num),
                                      rnn_layer=args.rnn_layer)
            l_d = GatedAttentionLayerWithQueryAttention([l_d, l_q, l_mask2])
        network1 = l_d
    else:
        assert args.model is None
        network1 = nn_layers.stack_rnn(
            l_emb1,
            l_mask1,
            args.num_layers,
            args.hidden_size,
            grad_clipping=args.grad_clipping,
            dropout_rate=args.dropout_rate,
            only_return_final=(args.att_func == 'last'),
            bidir=args.bidir,
            name='d',
            rnn_layer=args.rnn_layer)

        network2 = nn_layers.stack_rnn(l_emb2,
                                       l_mask2,
                                       args.num_layers,
                                       args.hidden_size,
                                       grad_clipping=args.grad_clipping,
                                       dropout_rate=args.dropout_rate,
                                       only_return_final=True,
                                       bidir=args.bidir,
                                       name='q',
                                       rnn_layer=args.rnn_layer)
    if args.att_func == 'mlp':
        att = nn_layers.MLPAttentionLayer([network1, network2],
                                          args.rnn_output_size,
                                          mask_input=l_mask1)
    elif args.att_func == 'bilinear':
        att = nn_layers.BilinearAttentionLayer([network1, network2],
                                               args.rnn_output_size,
                                               mask_input=l_mask1)
    elif args.att_func == 'avg':
        att = nn_layers.AveragePoolingLayer(network1, mask_input=l_mask1)
    elif args.att_func == 'last':
        att = network1
    elif args.att_func == 'dot':
        att = nn_layers.DotProductAttentionLayer([network1, network2],
                                                 mask_input=l_mask1)
    else:
        raise NotImplementedError('att_func = %s' % args.att_func)
    network3 = nn_layers.stack_rnn(l_emb3,
                                   l_mask3,
                                   args.num_layers,
                                   args.hidden_size,
                                   grad_clipping=args.grad_clipping,
                                   dropout_rate=args.dropout_rate,
                                   only_return_final=True,
                                   bidir=args.bidir,
                                   name='o',
                                   rnn_layer=args.rnn_layer)
    # Regroup the option encodings as (batch, 4, rnn_output_size).
    network3 = lasagne.layers.ReshapeLayer(
        network3, (in_x1.shape[0], 4, args.rnn_output_size))
    network = nn_layers.BilinearDotLayer([network3, att], args.rnn_output_size)
    if args.pre_trained is not None:
        dic = utils.load_params(args.pre_trained)
        lasagne.layers.set_all_param_values(network, dic['params'])
        del dic['params']
        logging.info('Loaded pre-trained model: %s', args.pre_trained)
        # Log whatever else was stored alongside the parameters.
        for dic_param in dic.items():
            logging.info(dic_param)

    logging.info('#params: %d',
                 lasagne.layers.count_params(network, trainable=True))
    logging.info('#fixed params: %d',
                 lasagne.layers.count_params(network, trainable=False))
    for layer in lasagne.layers.get_all_layers(network):
        logging.info(layer)

    # Test functions
    test_prob = lasagne.layers.get_output(network, deterministic=True)
    test_prediction = T.argmax(test_prob, axis=-1)
    acc = T.sum(T.eq(test_prediction, in_y))
    test_fn = theano.function(
        [in_x1, in_mask1, in_x2, in_mask2, in_x3, in_mask3, in_y],
        [acc, test_prediction],
        on_unused_input='warn')

    # Train functions
    train_prediction = lasagne.layers.get_output(network)
    # Keep probabilities strictly inside (0, 1) before the log in the loss.
    train_prediction = T.clip(train_prediction, 1e-7, 1.0 - 1e-7)
    loss = lasagne.objectives.categorical_crossentropy(train_prediction,
                                                       in_y).mean()
    # TODO: lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
    # Only trainable parameters go to the optimizer; otherwise the embedding
    # frozen above (tune_embedding == False) would still be updated.
    params = lasagne.layers.get_all_params(network, trainable=True)
    all_params = lasagne.layers.get_all_params(network)
    if args.optimizer == 'sgd':
        updates = lasagne.updates.sgd(loss, params, args.learning_rate)
    elif args.optimizer == 'adam':
        updates = lasagne.updates.adam(loss,
                                       params,
                                       learning_rate=args.learning_rate)
    elif args.optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss,
                                          params,
                                          learning_rate=args.learning_rate)
    else:
        raise NotImplementedError('optimizer = %s' % args.optimizer)
    train_fn = theano.function(
        [in_x1, in_mask1, in_x2, in_mask2, in_x3, in_mask3, in_y],
        loss,
        updates=updates,
        on_unused_input='warn')

    return train_fn, test_fn, params, all_params
示例#4
0
def build_fn(args, embeddings):
    """
    Build training and testing functions.

    Builds a feature-weighted bag-of-words reader. The passage (``x1``) is
    represented by ``WeightedAverageLayer`` over its word embeddings, with
    per-word weights computed from the extra features in ``x4``
    (batch x word_num x mea_num). Each answer option (``x3``) is pooled with
    ``AveragePoolingLayer``. Options are scored against the passage with
    ``DotLayer``.

    Args:
        args: configuration namespace. Reads vocab_size, embedding_size,
            mea_num, tune_embedding, model, freezeMlP, pre_trained,
            optimizer and learning_rate.
        embeddings: value/initializer for the word-embedding matrix shared
            by the passage and option inputs.

    Returns:
        (train_fn, test_fn, params, all_params):
        train_fn(x1, mask1, x3, mask3, y, x4) -> loss, applying updates;
        test_fn(x1, mask1, x3, mask3, y, x4) ->
            [num_correct, predictions, probabilities, word_weights, loss];
        params are the trainable parameters and all_params every parameter.

    Raises:
        NotImplementedError: for an unknown optimizer.
    """
    in_x1 = T.imatrix('x1')
    in_x3 = T.imatrix('x3')
    in_mask1 = T.matrix('mask1')
    in_mask3 = T.matrix('mask3')
    in_y = T.ivector('y')

    # batch x word_num x mea_num
    in_x4 = T.ftensor3('x4')

    l_in1 = lasagne.layers.InputLayer((None, None), in_x1)
    l_mask1 = lasagne.layers.InputLayer((None, None), in_mask1)
    l_emb1 = lasagne.layers.EmbeddingLayer(l_in1,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=embeddings)

    # The option embedding reuses l_emb1's weight matrix.
    l_in3 = lasagne.layers.InputLayer((None, None), in_x3)
    l_mask3 = lasagne.layers.InputLayer((None, None), in_mask3)
    l_emb3 = lasagne.layers.EmbeddingLayer(l_in3,
                                           args.vocab_size,
                                           args.embedding_size,
                                           W=l_emb1.W)

    l_in4 = lasagne.layers.InputLayer((None, None, args.mea_num), in_x4)
    if not args.tune_embedding:
        # Each layer keeps its own tag set for the shared W, so the
        # 'trainable' tag has to be removed from both.
        l_emb1.params[l_emb1.W].remove('trainable')
        l_emb3.params[l_emb3.W].remove('trainable')

    # This builder has no model variants; checked regardless of whether the
    # embeddings are tuned.
    assert args.model is None

    # Weighted mean: passage embedding. Per-word weights come from x4.
    if args.freezeMlP:
        # Frozen identity map (W = [[1.]], b = [0.], no nonlinearity): the
        # feature value is used as the weight unchanged.
        # NOTE(review): W has shape (1, 1), which presumably requires
        # args.mea_num == 1 -- confirm.
        weight_mlp_np = np.array([[1.]])
        b_mlp = np.array([0.])
        l_weight = lasagne.layers.DenseLayer(l_in4,
                                             1,
                                             num_leading_axes=-1,
                                             W=weight_mlp_np,
                                             b=b_mlp,
                                             nonlinearity=None)
        l_weight.params[l_weight.W].remove('trainable')
        l_weight.params[l_weight.b].remove('trainable')
    else:
        # Trainable linear map initialised to pick the fifth-from-last
        # feature (indexing [-5] requires args.mea_num >= 5), followed by
        # nn_layers.WeightedNormLayer.
        weight_mlp_np = np.zeros((args.mea_num, 1))
        weight_mlp_np[-5] = 1.
        b_mlp = np.array([0.])
        l_weight1 = lasagne.layers.DenseLayer(l_in4,
                                              1,
                                              num_leading_axes=-1,
                                              W=weight_mlp_np,
                                              b=b_mlp,
                                              nonlinearity=None)
        l_weight = nn_layers.WeightedNormLayer(l_weight1)

    att = nn_layers.WeightedAverageLayer([l_emb1, l_weight, l_mask1])
    # Mean: option embedding, regrouped as (batch, 4, embedding_size).
    network3 = nn_layers.AveragePoolingLayer(l_emb3, mask_input=l_mask3)
    network3 = lasagne.layers.ReshapeLayer(
        network3, (in_x1.shape[0], 4, args.embedding_size))
    # Predict answer
    network = nn_layers.DotLayer([network3, att], args.embedding_size)
    if args.pre_trained is not None:
        dic = utils.load_params(args.pre_trained)
        lasagne.layers.set_all_param_values(network, dic['params'])
        del dic['params']
        logging.info('Loaded pre-trained model: %s', args.pre_trained)
        # Log whatever else was stored alongside the parameters.
        for dic_param in dic.items():
            logging.info(dic_param)

    logging.info('#params: %d',
                 lasagne.layers.count_params(network, trainable=True))
    logging.info('#fixed params: %d',
                 lasagne.layers.count_params(network, trainable=False))
    for layer in lasagne.layers.get_all_layers(network):
        logging.info(layer)

    # Test functions
    weight = lasagne.layers.get_output(l_weight, deterministic=True)

    test_prob = lasagne.layers.get_output(network, deterministic=True)
    # Clip (as in training) only for the reported loss so a zero probability
    # cannot make it inf; test_prob itself is returned unclipped.
    loss_test = lasagne.objectives.categorical_crossentropy(
        T.clip(test_prob, 1e-7, 1.0 - 1e-7), in_y).mean()
    test_prediction = T.argmax(test_prob, axis=-1)
    acc = T.sum(T.eq(test_prediction, in_y))
    test_fn = theano.function(
        [in_x1, in_mask1, in_x3, in_mask3, in_y, in_x4],
        [acc, test_prediction, test_prob, weight, loss_test],
        on_unused_input='warn')

    # Train functions
    train_prediction = lasagne.layers.get_output(network)
    # Keep probabilities strictly inside (0, 1) before the log in the loss.
    train_prediction = T.clip(train_prediction, 1e-7, 1.0 - 1e-7)
    loss = lasagne.objectives.categorical_crossentropy(train_prediction,
                                                       in_y).mean()

    # TODO: lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)
    params = lasagne.layers.get_all_params(network, trainable=True)
    all_params = lasagne.layers.get_all_params(network)
    if args.optimizer == 'sgd':
        updates = lasagne.updates.sgd(loss, params, args.learning_rate)
    elif args.optimizer == 'adam':
        updates = lasagne.updates.adam(loss,
                                       params,
                                       learning_rate=args.learning_rate)
    elif args.optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss,
                                          params,
                                          learning_rate=args.learning_rate)
    else:
        raise NotImplementedError('optimizer = %s' % args.optimizer)

    train_fn = theano.function([in_x1, in_mask1, in_x3, in_mask3, in_y, in_x4],
                               loss,
                               updates=updates,
                               on_unused_input='warn')

    return train_fn, test_fn, params, all_params