コード例 #1
0
def get_model(inputdim,
              outputdim,
              regularization_strength=0.01,
              lr=0.000,
              cosine=False,
              **kwargs):
    """Build and compile a shared-weight model that scores pairs of embeddings.

    Both inputs go through one shared ``Dense(inputdim)`` layer
    (identity-initialised, with an ``Orthogonal()`` weight constraint).
    The first ``outputdim`` columns of each transformed vector (``p1``, ``p2``)
    are kept, and a per-pair score ``d`` is computed:

    * ``cosine=False``: Euclidean distance ``||p1 - p2||``.
    * ``cosine=True``: negative cosine similarity
      ``-sum(p1/||p1|| * p2/||p2||)``.

    The model is compiled with loss ``mean(y * d)`` on the single output
    ``'y'`` and optimizer ``SimpleSGD()``.

    Args:
        inputdim: Dimensionality of each input embedding (also the width of
            the shared transformation).
        outputdim: Number of leading transformed dimensions used for the
            comparison.
        regularization_strength: Accepted but not used by this function.
        lr: Accepted but not used; ``SimpleSGD()`` is built with no arguments.
        cosine: If true, score with negative cosine similarity instead of
            Euclidean distance.
        **kwargs: Accepted and ignored.

    Returns:
        The compiled model, with inputs ``'embeddings1'``/``'embeddings2'``
        and output ``'y'``.
    """

    # Both helpers reference only the backend module K (no closure state).
    # keepdims=True leaves a (batch, 1) column that broadcasts against x.
    # It replaces the old reshape to the symbolic x.shape[0], which is not
    # a concrete size on every backend.
    def _row_l2_norm(x):
        # L2 norm of each row of a 2-D tensor.
        return K.sqrt(K.sum(x * x, axis=1, keepdims=True))

    def _l2_normalize(x):
        # Scale each row to unit L2 norm.
        return x / K.sqrt(K.sum(x * x, axis=1, keepdims=True))

    transformation = Dense(inputdim,
                           init='identity',
                           W_constraint=Orthogonal())

    model = Model()
    model.add_input(name='embeddings1', input_shape=(inputdim, ))
    model.add_input(name='embeddings2', input_shape=(inputdim, ))
    # A single layer applied to both inputs, so both sides share weights.
    model.add_shared_node(transformation,
                          name='transformation',
                          inputs=['embeddings1', 'embeddings2'],
                          outputs=['transformed1', 'transformed2'])
    # Keep the first `outputdim` dimensions. The second side is negated, so
    # a 'sum' merge below yields p1 - p2 and a 'mul' merge yields -p1 * p2.
    model.add_node(Lambda(lambda x: x[:, :outputdim]),
                   input='transformed1',
                   name='projected1')
    model.add_node(Lambda(lambda x: -x[:, :outputdim]),
                   input='transformed2',
                   name='negprojected2')

    if cosine:
        # NOTE(review): an all-zero projection divides by zero here and
        # produces NaN; consider adding K.epsilon() to the denominator.
        model.add_node(Lambda(_l2_normalize),
                       name='normalized1',
                       input='projected1')
        model.add_node(Lambda(_l2_normalize),
                       name='negnormalized2',
                       input='negprojected2')
        # 'mul' merge then row sum gives the negative cosine similarity.
        model.add_node(Lambda(lambda x: K.sum(x, axis=1, keepdims=True)),
                       name='distances',
                       inputs=['normalized1', 'negnormalized2'],
                       merge_mode='mul')
    else:
        # 'sum' merge of p1 and -p2 is p1 - p2; take its row-wise L2 norm.
        model.add_node(Lambda(_row_l2_norm),
                       name='distances',
                       inputs=['projected1', 'negprojected2'],
                       merge_mode='sum')

    model.add_output(name='y', input='distances')
    # Loss is mean(label * score). NOTE(review): presumably y encodes pair
    # similarity (e.g. +1 for matching pairs, -1 otherwise); confirm with
    # the caller.
    model.compile(loss={'y': lambda y, d: K.mean(y * d)},
                  optimizer=SimpleSGD())
    return model