    def test_force_projection_no_binding(self):
        input_layer = keras.layers.Input(shape=(None,))
        embed_layer = AdaptiveEmbedding(
            input_dim=3,
            output_dim=16,
            force_projection=True,
            return_embeddings=True,
            return_projections=True,
        )(input_layer)
        softmax_layer = AdaptiveSoftmax(
            input_dim=16,
            output_dim=3,
            force_projection=True,
        )(embed_layer)
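        # force_projection=True keeps projection matrices in both layers even
        # though embed_dim defaults to output_dim here and no projection would
        # otherwise be needed, so the save/load round-trip below also covers
        # those weights (a reading of the parameter from its use in these
        # tests, not a library-doc quote).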
        model = keras.models.Model(input_layer, softmax_layer)
        model_path = os.path.join(
            tempfile.gettempdir(),
            'test_ada_softmax_%f.h5' % np.random.random())
        model.save(model_path)
        model = keras.models.load_model(
            model_path,
            custom_objects={
                'AdaptiveEmbedding': AdaptiveEmbedding,
                'AdaptiveSoftmax': AdaptiveSoftmax,
            })
        model.summary()

    def test_cutoffs_no_projection_bind(self):
        input_layer = keras.layers.Input(shape=(None,))
        embed_layer = AdaptiveEmbedding(
            input_dim=30,
            output_dim=8,
            cutoffs=[10, 20, 25],
            div_val=2,
            mask_zero=True,
            force_projection=False,
            return_embeddings=True,
            return_projections=True,
        )(input_layer)
        softmax_layer = AdaptiveSoftmax(
            input_dim=8,
            output_dim=30,
            cutoffs=[10, 20, 25],
            div_val=2,
            force_projection=False,
            bind_embeddings=True,
            bind_projections=True,
        )(embed_layer)
        model = keras.models.Model(input_layer, softmax_layer)
        model_path = os.path.join(
            tempfile.gettempdir(),
            'test_ada_softmax_%f.h5' % np.random.random())
        model.save(model_path)
        model = keras.models.load_model(
            model_path,
            custom_objects={
                'AdaptiveEmbedding': AdaptiveEmbedding,
                'AdaptiveSoftmax': AdaptiveSoftmax,
            })
        model.summary()

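    # How cutoffs/div_val shape the adaptive weights above, as a sketch:
    # cutoffs [10, 20, 25] on a 30-token vocabulary split it into clusters
    # [0, 10), [10, 20), [20, 25), [25, 30), and cluster i keeps roughly
    # embed_dim // div_val**i embedding dimensions -- here 8, 4, 2, 1
    # (inferred from the weight shapes used in test_sample below, not from
    # the library's documentation).
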
    def test_fit(self):
        input_layer = keras.layers.Input(shape=(None, ))
        embed_layer = AdaptiveEmbedding(
            input_dim=30,
            output_dim=32,
            cutoffs=[5, 15, 25],
            div_val=2,
            return_embeddings=True,
            return_projections=True,
            mask_zero=True,
        )(input_layer)
        dense_layer = keras.layers.Dense(
            units=32,
            activation='tanh',
        )(embed_layer[0])
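        # A note on the wiring below: with return_embeddings and
        # return_projections, the embedding layer's output list is the
        # sequence output followed by the per-cluster embedding tables and
        # projections, so `[dense_layer] + embed_layer[1:]` hands all of
        # them to the softmax, which reuses them as its output weights via
        # bind_embeddings/bind_projections instead of allocating its own.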
        softmax_layer = AdaptiveSoftmax(
            input_dim=32,
            output_dim=30,
            cutoffs=[5, 15, 25],
            div_val=2,
            bind_embeddings=True,
            bind_projections=True,
        )([dense_layer] + embed_layer[1:])
        model = keras.models.Model(inputs=input_layer, outputs=softmax_layer)
        model.compile('adam', 'sparse_categorical_crossentropy')
        model.summary()

        inputs = np.random.randint(0, 30, (4096, 10))
        outputs = np.expand_dims(inputs, axis=-1)
        model.fit(
            inputs,
            outputs,
            epochs=100,
            callbacks=[
                keras.callbacks.EarlyStopping(monitor='loss',
                                              min_delta=1e-4,
                                              patience=2),
            ],
        )

        model = keras.models.Model(input_layer, softmax_layer)
        model_path = os.path.join(
            tempfile.gettempdir(),
            'test_ada_softmax_%f.h5' % np.random.random())
        model.save(model_path)
        model = keras.models.load_model(
            model_path,
            custom_objects={
                'AdaptiveEmbedding': AdaptiveEmbedding,
                'AdaptiveSoftmax': AdaptiveSoftmax,
            })

        inputs = np.random.randint(0, 30, (128, 10))
        outputs = model.predict(inputs).argmax(axis=-1)
        outputs *= np.not_equal(inputs, 0).astype('int32')
        diff = np.sum(np.not_equal(inputs, outputs))
        self.assertLess(diff, 5)

    def test_sample(self):
        embed_0 = np.array([
            [
                0.7562694862279867,
                -0.7532437781410828,
                -0.2882295795429552,
                -1.6990371818805843,
                -0.09864164817566004,
                -0.5235034477186453,
                -1.600153091413999,
                0.03441732751250957,
            ],
            [
                -0.3680529905261407,
                1.1673600332887637,
                -0.6914459306809843,
                -0.7645030146906124,
                2.0434827620248606,
                -0.2743642839675437,
                0.04834288951969495,
                -1.0368596183756285,
            ],
            [
                -0.8440324158987662,
                0.05585795322288273,
                -0.5827731797867599,
                1.502853709909658,
                -0.09311037618863122,
                1.366316512453695,
                -0.3834091917878978,
                -1.2647642860801802,
            ],
            [
                1.5212768184170435,
                -0.7854311748221854,
                -0.4674213048014483,
                -1.0460200278367862,
                0.3705555995848165,
                -0.12273261562651422,
                1.8138708310050653,
                -0.26957084415202764,
            ],
            [
                -0.15162771245260723,
                -0.19654664890167275,
                -1.77930041719533,
                -0.6987101769248606,
                0.32681036318004547,
                0.19156716698736181,
                0.8386004334587568,
                -1.8390076172747616,
            ],
            [
                -1.1363779747587972,
                -0.15233131547247872,
                0.158423477487577,
                -0.6984487776859649,
                1.2424950830966563,
                -0.16130616338419873,
                -1.6298737099566283,
                1.7229575808498785,
            ],
            [
                0.613169803410901,
                -1.5391239758406403,
                -1.2476893436624792,
                -0.05514513857644962,
                -0.5537408608863357,
                -0.9965187549427492,
                -0.6842234254089083,
                -1.2420165307152238,
            ],
            [
                -0.4086071455923046,
                -0.7286151488450243,
                1.2938629380821804,
                0.7450912596769113,
                -0.13042129128885002,
                -1.4269400640112133,
                -0.713571658756806,
                -0.5036154349645774,
            ],
            [
                0.7326026846217363,
                0.12752591749386982,
                0.7387086112901271,
                -1.4161019970745967,
                -0.6396944907214142,
                -2.0010110577965783,
                0.5843029435066284,
                -0.4033331631189724,
            ],
            [
                1.22301664512685,
                -0.024541032664251092,
                -0.27128167541306714,
                1.910258142043872,
                -0.9673069099782774,
                0.6614265651081772,
                -1.165650716838653,
                -0.5085143504562967,
            ],
        ])
        embed_1 = np.array([
            [
                0.6593494357199338, -0.06233478795012013, 0.3394579881849406,
                0.05894554241531747
            ],
            [
                1.0015451559801243, 0.7487130375684998, -0.4244371286817957,
                -0.45182923128222996
            ],
            [
                -0.41965070720383035, -0.2875756074838825, 1.8712603426351773,
                2.531083895835167
            ],
            [
                -0.6800689195006436, -0.39454047242128376, 0.5442439581019961,
                -0.21672610899025968
            ],
            [
                -1.3119449289237803, 1.5645034642903253, 1.3203132828621442,
                1.7673879116655695
            ],
            [
                -0.8817194029613362, -0.6655645822150862, 0.2341787847442309,
                -0.7641095447924122
            ],
            [
                -0.47497798682688624, 1.0109350638555383, -0.5514102704837403,
                -0.1450007600387442
            ],
            [
                -0.531267085230172, 0.12862169808408846, 0.18339345878624577,
                1.5279135983387981
            ],
            [
                0.43338928943049837, 0.2660771849859784, 1.4227633495535283,
                -0.5072818940455809
            ],
            [
                0.8704222505796531, 0.9361117741463981, 0.7442665348863866,
                0.91392694614948
            ],
        ])
        embed_2 = np.array([
            [1.2712292341556446, 1.009655780936284],
            [0.4420362222435132, 1.5186087787070979],
            [-0.10018465175352317, -0.09182475290216006],
            [-1.246047485363712, 1.6404603895987184],
            [1.4427767754835976, 1.2102150762070925],
        ])
        embed_3 = np.array([
            [0.8285545743394414],
            [0.7111875779008273],
            [0.35799413043562894],
            [-0.15005629449852656],
            [0.6263946579941496],
        ])
        proj_0 = np.array([
            [0.3409731658714878, 0.032745006392315756, 0.668797744010083],
            [-0.3082491589087075, -1.0028023345331745, 0.2122102239605163],
            [-0.3751562822576601, -0.5825445529201775, 0.43389258576225614],
            [0.26067868083146517, 0.8192897299406429, 0.073726048897453],
            [1.1346146882950412, -2.456072992985481, -0.054474463562940736],
            [-1.0283521269636255, -0.1983876737118115, 1.0132159972212373],
            [2.72334361610427, 0.5683724225575054, 2.403638230905517],
            [-0.2137114185905606, 0.3048293347650425, 1.510425235737199],
        ])
        proj_1 = np.array([
            [0.42186259731067743, 0.6034344571434473, 2.362015513199549],
            [-0.9313583984951119, -0.8242699945665621, 0.2596454482698166],
            [0.8871149648450185, -0.663397984939589, -1.195129355668761],
            [0.8016784490871957, 0.13830808473255815, -0.6580242457235711],
        ])
        proj_2 = np.array([
            [1.4802477891158519, 0.12638370704617574, -0.18503256737397666],
            [-0.3900434531439191, 0.14771223879593204, -0.8863321455068343],
        ])
        proj_3 = np.array(
            [[-0.589729339138385, 2.018799784975004, -0.08431336326635828]])

        cluster_kernel = np.array([
            [0.23014518853189528, -1.907450615160342, -0.5760690735239482],
            [0.15888698361555206, 0.16596164514332107, -1.3874452811285074],
            [-0.43498605862409073, -0.9533547594248867, 1.376861108688103],
            [2.0713086892043306, 0.3189268504371047, 0.17466615249716405],
            [-0.995949973463762, 0.043604908747614204, -1.6117698906413622],
            [0.6066490394919954, -0.5688549027107426, 0.4277926952413127],
            [-0.045942286252255375, 1.269447988095889, -2.0882415532533565],
            [0.8578254069980026, 0.6013204537529426, -1.5562555397638154],
        ])
        cluster_bias = np.array(
            [-1.2832837684769247, -0.39755882729529285, -1.6266054548863331])

        bias_0 = np.array([
            -0.44961683466248237,
            1.1354573774120789,
            1.2744817355039493,
            -1.5496828034299275,
            -0.21836162127739225,
            -0.37565781060494785,
            -0.17156518058295334,
            0.983434075647771,
            -0.3062002489865936,
            0.12998179587118727,
        ])
        bias_1 = np.array([
            -0.2091536758128447,
            -0.6470589074952924,
            0.3477127052723791,
            -0.9045321990801439,
            -0.21297856640874574,
            0.3459416954179376,
            0.37443354120881234,
            -1.1654497053299575,
            1.6909447574735716,
            0.23900953544990225,
        ])
        bias_2 = np.array([
            0.3099136565556444,
            -0.9158122257114607,
            -0.16860676319583162,
            -1.2395468248816244,
            1.204462915844038,
        ])
        bias_3 = np.array([
            1.291426908829937, -0.6694533566338665, 0.2625003902625795,
            0.9797242029047042, 1.599378867487272
        ])

        inputs = np.array([
            [
                [0.744236572859694, 0.016713611741487267, 1.4682734369173418],
                [0.27153908377796215, -1.469963926716969, -0.8287408146483969],
                [-2.12471847178894, -1.908653889589087, 0.6152713069444428],
                [0.9054803804104959, -1.2524010188123476, 0.673952005987055],
                [
                    -0.05409017774217415, -0.7869076720861053,
                    -0.8608013367536177
                ],
            ],
            [
                [0.5928070143642264, -0.1365080521672495, -1.8938283201202142],
                [1.8238080368340632, -0.8134981522315549, -0.2736867043672396],
                [
                    -0.6324104033897957, -1.1823330727729813,
                    -1.4800297849679227
                ],
                [1.3222282804156642, 1.7723967951065012, 0.38944790892928965],
                [-0.9808710814446125, 0.6626326119592982, 0.8039459587763045],
            ],
        ])

        weights = [
            embed_0,
            embed_1,
            embed_2,
            embed_3,
            proj_0,
            proj_1,
            proj_2,
            proj_3,
        ]

        input_layer = keras.layers.Input(shape=(None, 3))
        append_layer = AppendWeights(weights)
        softmax_layer = AdaptiveSoftmax(
            input_dim=3,
            output_dim=30,
            embed_dim=8,
            cutoffs=[10, 20, 25],
            div_val=2,
            bind_embeddings=True,
            bind_projections=True,
        )
        func = K.function([input_layer],
                          [softmax_layer(append_layer(input_layer))])
        append_layer.set_weights(weights)
        softmax_layer.set_weights(
            [cluster_kernel, cluster_bias, bias_0, bias_1, bias_2, bias_3])
        predicted = func([inputs])[0]
        expected = np.array([
            [
                [
                    5.605619080029101e-09,
                    5.3742809541290626e-05,
                    2.6568095563561656e-06,
                    0.9891002774238586,
                    7.272975926753134e-05,
                    9.171863979418049e-08,
                    4.264499864348181e-08,
                    4.891299454357068e-07,
                    0.0001877533650258556,
                    2.692615908017615e-07,
                    1.873376459116116e-05,
                    9.539959137327969e-05,
                    4.360527228186584e-08,
                    5.719440565599143e-08,
                    1.124294546350768e-09,
                    1.749220928104478e-07,
                    9.25613619529031e-07,
                    5.130279845388941e-08,
                    1.775680539140012e-05,
                    2.2025182261131704e-05,
                    0.0024439117405563593,
                    0.0001602671982254833,
                    0.002785446122288704,
                    2.3448987121810205e-05,
                    0.005013651214540005,
                    3.8959341395333746e-13,
                    5.834099344034435e-14,
                    1.785878980505376e-13,
                    4.786831738282094e-13,
                    5.899208530522893e-13,
                ],
                [
                    4.812825136468746e-05,
                    0.9990597367286682,
                    5.242341103439685e-06,
                    2.8096915016817547e-08,
                    1.5739469745312817e-05,
                    0.0008400182705372572,
                    8.577513312957308e-07,
                    2.8549273338285275e-05,
                    1.5727113122920855e-06,
                    2.7855088902128955e-08,
                    4.444893905659929e-15,
                    2.8687949779297045e-16,
                    1.4623736249719244e-11,
                    9.033296029595239e-14,
                    7.310696492623947e-11,
                    1.6607075686413814e-13,
                    4.8921424636999555e-14,
                    8.19115215651249e-14,
                    5.590938953123348e-13,
                    3.239051618608192e-14,
                    1.574426100603432e-08,
                    4.194554925618377e-09,
                    3.735754816602821e-09,
                    1.7098933380310655e-09,
                    4.4564735901531094e-08,
                    7.2173618193005495e-09,
                    1.4542597126521173e-09,
                    1.0874863676235691e-08,
                    1.0534140670870329e-07,
                    1.822166062481756e-08,
                ],
                [
                    0.000926146749407053,
                    0.001165713299997151,
                    1.2146524568379391e-05,
                    3.022600182298052e-12,
                    3.9759040504350196e-10,
                    0.9977163076400757,
                    1.305691249564589e-10,
                    3.6172119166621997e-07,
                    8.730076106466811e-10,
                    2.2465255824499764e-06,
                    3.152966386241185e-12,
                    3.184204844242089e-10,
                    1.6164958877484727e-15,
                    1.4817423546822917e-12,
                    1.9586689908868138e-11,
                    1.1893032565712947e-11,
                    3.2308891118049132e-09,
                    1.8932036114586298e-13,
                    7.211550107077969e-11,
                    1.3474238218236234e-11,
                    2.6987896103005185e-14,
                    1.444208793353885e-13,
                    2.029298820996339e-12,
                    3.8475198721465986e-11,
                    3.6226284558932287e-14,
                    1.2148435416747816e-05,
                    2.334001010240172e-06,
                    1.5123150660656393e-05,
                    0.00011920313409063965,
                    2.8255606594029814e-05,
                ],
                [
                    1.3934656806213752e-07,
                    0.9742105007171631,
                    1.6341533637387329e-06,
                    0.02507493644952774,
                    0.0002712457499001175,
                    2.178921022277791e-05,
                    2.0800779765295374e-08,
                    1.7274820720558637e-06,
                    0.00010041205678135157,
                    5.4775970426135245e-09,
                    2.01842809133268e-09,
                    1.3333264492487729e-09,
                    4.06389899509918e-09,
                    2.0069401696076739e-10,
                    8.946644536322879e-10,
                    3.6186006968641493e-10,
                    6.276996145082592e-10,
                    2.0115159538036664e-10,
                    2.6643403927550935e-08,
                    9.023438884980806e-09,
                    8.10222263680771e-05,
                    5.552933998842491e-06,
                    4.113625254831277e-05,
                    5.870374479854945e-07,
                    0.00018922182789538056,
                    6.140000210737642e-14,
                    1.2461146379355321e-14,
                    9.52243519496479e-14,
                    9.516042467905272e-13,
                    1.5695048884677154e-13,
                ],
                [
                    0.017724955454468727,
                    0.9227734804153442,
                    0.003521926701068878,
                    7.439290357069694e-07,
                    0.0008589967619627714,
                    0.03857409581542015,
                    0.0015136339934542775,
                    0.011509685777127743,
                    0.000178004804183729,
                    0.00017167421174235642,
                    1.2908007995804383e-10,
                    1.3230781054085483e-11,
                    8.642934545832759e-08,
                    1.9794590411237323e-09,
                    3.62592459168809e-07,
                    5.063550023720609e-09,
                    1.6391373813817722e-09,
                    1.607139976655958e-09,
                    7.201423013469821e-09,
                    4.98530883241699e-10,
                    2.105740350089036e-06,
                    8.830390925140819e-07,
                    6.429553423004108e-07,
                    7.17075693046354e-07,
                    5.8689788602350745e-06,
                    0.00046149574336595833,
                    7.730671495664865e-05,
                    0.0003315970825497061,
                    0.0014439808437600732,
                    0.0008476407965645194,
                ],
            ],
            [
                [
                    0.019987842068076134,
                    0.19081537425518036,
                    0.00488634780049324,
                    1.745469688785306e-07,
                    0.009911534376442432,
                    0.000864819681737572,
                    0.5589132905006409,
                    0.1608048379421234,
                    0.0006605738890357316,
                    0.00029775931034237146,
                    8.567404191682851e-14,
                    2.7650286773307244e-16,
                    1.0644154002648065e-07,
                    2.0991774291045928e-11,
                    2.813413679803034e-08,
                    3.624691866099816e-11,
                    4.0573834981898205e-13,
                    3.6105844009037824e-11,
                    9.637011674779039e-12,
                    2.933819907933316e-13,
                    2.9287286906765075e-06,
                    6.513750463454926e-07,
                    7.162674364735722e-08,
                    7.26421092167584e-08,
                    1.1740429727069568e-05,
                    0.01245346013456583,
                    0.0018510496011003852,
                    0.005540691316127777,
                    0.014380054548382759,
                    0.018616652116179466,
                ],
                [
                    3.4378548008362486e-08,
                    0.9630267024040222,
                    6.557401661666518e-07,
                    0.030607404187321663,
                    0.005104635842144489,
                    2.3386729708363418e-07,
                    1.8453529264661483e-06,
                    8.987089131551329e-06,
                    0.00046148416004143655,
                    1.0203488054472132e-09,
                    4.58596648069548e-13,
                    7.34142005406847e-15,
                    3.0094788883161527e-09,
                    5.886948372356426e-13,
                    2.4638501308626992e-11,
                    5.974854942400465e-13,
                    3.0690504120283596e-14,
                    1.4213487158076799e-12,
                    1.4414204205226433e-11,
                    2.241537684632977e-12,
                    0.00017593242228031158,
                    4.26023416366661e-06,
                    5.525962023966713e-06,
                    3.286314864681117e-08,
                    0.0006022691377438605,
                    3.5567560588379774e-14,
                    6.867705443371714e-15,
                    4.5175658596992643e-14,
                    3.63893133291035e-13,
                    8.34426725306904e-14,
                ],
                [
                    0.11111080646514893,
                    0.4013337790966034,
                    0.0010328179923817515,
                    4.1730782024407276e-11,
                    1.2512266948760953e-05,
                    0.24545560777187347,
                    0.000354167161276564,
                    0.004766426980495453,
                    1.5340842764999252e-06,
                    5.2795141527894884e-05,
                    3.2432775083336913e-15,
                    2.201742370634856e-16,
                    2.4322148461930482e-11,
                    6.073784595932163e-13,
                    1.7714454347839137e-09,
                    1.73516311995775e-12,
                    5.266814545254461e-13,
                    3.685409453741545e-13,
                    6.401695037787369e-13,
                    1.8601455079492353e-14,
                    8.367688208998914e-10,
                    9.737708417389968e-10,
                    3.437621853841222e-10,
                    3.2822229378837164e-09,
                    2.3505164481463225e-09,
                    0.0271458700299263,
                    0.0047686779871582985,
                    0.023600326851010323,
                    0.12625090777873993,
                    0.054113760590553284,
                ],
                [
                    1.3648338459404386e-08,
                    2.7086382914376372e-08,
                    4.8717127356212586e-05,
                    0.7242416143417358,
                    0.002194569678977132,
                    1.5252779594909782e-10,
                    0.0023464339319616556,
                    6.96199422236532e-05,
                    0.004193915985524654,
                    9.601204510545358e-05,
                    0.006397495046257973,
                    0.0010100157232955098,
                    0.008361614309251308,
                    0.00016940751811489463,
                    2.352720002818387e-06,
                    0.0004652136121876538,
                    4.8424004489788786e-05,
                    0.0003627230762504041,
                    0.003415221581235528,
                    0.002616040175780654,
                    0.05781353637576103,
                    0.0021764079574495554,
                    0.0038422606885433197,
                    4.1606021113693714e-05,
                    0.18008683621883392,
                    3.456036790083772e-09,
                    3.515727153846626e-10,
                    3.361662892498174e-10,
                    1.689842571428457e-10,
                    2.6885360604467223e-09,
                ],
                [
                    0.0023310158867388964,
                    2.793478643070557e-06,
                    0.04028953239321709,
                    0.00010354104597354308,
                    3.747312803170644e-05,
                    0.0016634142957627773,
                    0.0006405340973287821,
                    0.002110437024384737,
                    0.0002885640715248883,
                    0.41590359807014465,
                    0.01737625151872635,
                    0.3610380291938782,
                    5.221741503191879e-06,
                    0.0005045500583946705,
                    1.4334115803649183e-05,
                    0.004071446601301432,
                    0.06609751284122467,
                    0.00018661090871319175,
                    0.01566864363849163,
                    0.010083985514938831,
                    5.475956277223304e-05,
                    5.0246228056494147e-05,
                    0.00035087967989966273,
                    0.0004574395134113729,
                    9.8565717053134e-05,
                    0.026114653795957565,
                    0.0029584828298538923,
                    0.003910995088517666,
                    0.003132845275104046,
                    0.024453623220324516,
                ],
            ],
        ])
        self.assertTrue(np.allclose(expected, predicted))
        sums = np.sum(predicted, axis=-1)
        self.assertTrue(np.allclose(np.ones_like(sums), sums))
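

# What test_sample pins down, as a standalone sketch (a hypothetical
# two-cluster example, not the library's code): the head softmax covers the
# shortlist plus one logit per tail cluster, and a tail token's probability
# is P(cluster) * P(token | cluster), so the full vector still sums to one.
import numpy as np


def _softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


head_logits = np.random.randn(5 + 1)     # 5 shortlist tokens + 1 cluster logit
tail_logits = np.random.randn(3)         # 3 tokens inside the tail cluster
head = _softmax(head_logits)
tail = head[-1] * _softmax(tail_logits)  # scale by the cluster probability
probs = np.concatenate([head[:-1], tail])
assert np.isclose(probs.sum(), 1.0)      # still a valid distribution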
def build_albert(token_num,
                 pos_num=512,
                 seq_len=512,
                 embed_dim=128,
                 hidden_dim=768,
                 transformer_num=12,
                 head_num=12,
                 feed_forward_dim=3072,
                 dropout_rate=0.1,
                 attention_activation=None,
                 feed_forward_activation='gelu',
                 training=True,
                 trainable=None,
                 output_layers=None):
    """Get ALBERT model.
    See: https://arxiv.org/pdf/1909.11942.pdf
    :param token_num: Number of tokens.
    :param pos_num: Maximum position.
    :param seq_len: Maximum length of the input sequence or None.
    :param embed_dim: Dimensions of embeddings.
    :param hidden_dim: Dimensions of hidden layers.
    :param transformer_num: Number of transformers.
    :param head_num: Number of heads in multi-head attention
                    in each transformer.
    :param feed_forward_dim: Dimension of the feed forward layer
                             in each transformer.
    :param dropout_rate: Dropout rate.
    :param attention_activation: Activation for attention layers.
    :param feed_forward_activation: Activation for feed-forward layers.
    :param training: A built model with MLM and NSP outputs will be returned
                     if it is `True`, otherwise the input layers and the last
                     feature extraction layer will be returned.
    :param trainable: Whether the model is trainable.
    :param output_layers: A list of indices of output layers.
    """
    if attention_activation == 'gelu':
        attention_activation = gelu
    if feed_forward_activation == 'gelu':
        feed_forward_activation = gelu
    if trainable is None:
        trainable = training

    def _trainable(_layer):
        if isinstance(trainable, (list, tuple, set)):
            for prefix in trainable:
                if _layer.name.startswith(prefix):
                    return True
            return False
        return trainable

    # Build inputs
    input_token = keras.layers.Input(shape=(seq_len, ), name='Input-Token')
    input_segment = keras.layers.Input(shape=(seq_len, ), name='Input-Segment')
    inputs = [input_token, input_segment]
    if training:
        # Positions of the masked tokens, used to restrict the MLM loss.
        input_mask = keras.layers.Input(shape=(seq_len, ), name='Input-Masked')
        inputs.append(input_mask)

    # Build embeddings
    embed_token, embed_weights, embed_projection = AdaptiveEmbedding(
        input_dim=token_num,
        output_dim=hidden_dim,
        embed_dim=embed_dim,
        mask_zero=True,
        trainable=trainable,
        return_embeddings=True,
        return_projections=True,
        name='Embed-Token',
    )(input_token)
    embed_segment = keras.layers.Embedding(
        input_dim=2,
        output_dim=hidden_dim,
        trainable=trainable,
        name='Embed-Segment',
    )(input_segment)
    embed_layer = keras.layers.Add(name='Embed-Token-Segment')(
        [embed_token, embed_segment])
    embed_layer = PositionEmbedding(
        input_dim=pos_num,
        output_dim=hidden_dim,
        mode=PositionEmbedding.MODE_ADD,
        trainable=trainable,
        name='Embedding-Position',
    )(embed_layer)

    if dropout_rate > 0.0:
        dropout_layer = keras.layers.Dropout(
            rate=dropout_rate,
            name='Embedding-Dropout',
        )(embed_layer)
    else:
        dropout_layer = embed_layer
    embed_layer = LayerNormalization(
        trainable=trainable,
        name='Embedding-Norm',
    )(dropout_layer)

    # Build shared transformer
    attention_layer = MultiHeadAttention(
        head_num=head_num,
        activation=attention_activation,
        name='Attention',
    )
    attention_normal = LayerNormalization(name='Attention-Normal')
    feed_forward_layer = FeedForward(units=feed_forward_dim,
                                     activation=feed_forward_activation,
                                     name='Feed-Forward')
    feed_forward_normal = LayerNormalization(name='Feed-Forward-Normal')

    transformed = embed_layer
    transformed_layers = []
    for i in range(transformer_num):
        attention_input = transformed
        transformed = attention_layer(transformed)
        if dropout_rate > 0.0:
            transformed = keras.layers.Dropout(
                rate=dropout_rate,
                name='Attention-Dropout-{}'.format(i + 1),
            )(transformed)
        transformed = keras.layers.Add(
            name='Attention-Add-{}'.format(i + 1))(
                [attention_input, transformed])
        transformed = attention_normal(transformed)

        feed_forward_input = transformed
        transformed = feed_forward_layer(transformed)
        if dropout_rate > 0.0:
            transformed = keras.layers.Dropout(
                rate=dropout_rate,
                name='Feed-Forward-Dropout-{}'.format(i + 1),
            )(transformed)
        transformed = keras.layers.Add(
            name='Feed-Forward-Add-{}'.format(i + 1))(
                [feed_forward_input, transformed])
        transformed = feed_forward_normal(transformed)
        transformed_layers.append(transformed)

    if training:
        # Build tasks
        mlm_dense_layer = keras.layers.Dense(
            units=hidden_dim,
            activation=feed_forward_activation,
            name='MLM-Dense',
        )(transformed)
        mlm_norm_layer = LayerNormalization(name='MLM-Norm')(mlm_dense_layer)
        mlm_pred_layer = AdaptiveSoftmax(
            input_dim=hidden_dim,
            output_dim=token_num,
            embed_dim=embed_dim,
            bind_embeddings=True,
            bind_projections=True,
            name='MLM-Sim',
        )([mlm_norm_layer, embed_weights, embed_projection])
        masked_layer = Masked(name='MLM')([mlm_pred_layer, inputs[-1]])
        extract_layer = Extract(index=0, name='Extract')(transformed)
        nsp_dense_layer = keras.layers.Dense(
            units=hidden_dim,
            activation='tanh',
            name='SOP-Dense',
        )(extract_layer)
        nsp_pred_layer = keras.layers.Dense(
            units=2,
            activation='softmax',
            name='SOP',
        )(nsp_dense_layer)
        model = keras.models.Model(inputs=inputs,
                                   outputs=[masked_layer, nsp_pred_layer])
        for layer in model.layers:
            layer.trainable = _trainable(layer)
        return model
    if output_layers is not None:
        if isinstance(output_layers, list):
            output_layers = [
                transformed_layers[index] for index in output_layers
            ]
            output = keras.layers.Concatenate(name='Output')(output_layers)
        else:
            output = transformed_layers[output_layers]
        model = keras.models.Model(inputs=inputs, outputs=output)
        return model
    model = keras.models.Model(inputs=inputs, outputs=transformed)
    for layer in model.layers:
        layer.trainable = _trainable(layer)
    return inputs, transformed
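

# Usage sketch for build_albert (illustrative sizes, not the ALBERT-base
# defaults; assumes the layers imported by this module are available).
albert = build_albert(token_num=30000, training=True)  # MLM + SOP outputs
albert.summary()

# With training=False and no output_layers, the inputs and the final feature
# layer are returned so task-specific layers can be stacked on top.
inputs, features = build_albert(token_num=30000, training=False)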
def build_transformer_xl(units,
                         embed_dim,
                         hidden_dim,
                         num_token,
                         num_block,
                         num_head,
                         batch_size,
                         memory_len,
                         target_len,
                         dropout=0.0,
                         attention_dropout=0.0,
                         cutoffs=None,
                         div_val=1,
                         force_projection=None,
                         bind_embeddings=True,
                         bind_projections=True,
                         clamp_len=None,
                         share_biases=True):
    """Build transformer-XL model.

    :param units: Units inside the transformer.
    :param embed_dim: Dimension of embeddings.
    :param hidden_dim: Dimension inside position-wise feed-forward layer.
    :param num_token: Number of distinct input tokens.
    :param num_block: Number of basic encoder blocks.
    :param num_head: Number of heads for attention.
    :param batch_size: Maximum batch size.
    :param memory_len: The maximum length of memories.
    :param target_len: The length of the prediction block.
    :param dropout: General dropout rate.
    :param attention_dropout: Dropout rate inside attention layer.
    :param cutoffs: Cutoffs of the adaptive embedding.
    :param div_val: Divisor that shrinks the embedding size of each
                    successive cluster of the adaptive embedding.
    :param force_projection: Whether to add projections even when the
                             embedding and transformer dimensions are equal.
    :param bind_embeddings: Whether to bind embeddings to adaptive softmax.
    :param bind_projections: Whether to bind projections to adaptive softmax.
    :param clamp_len: The maximum value of relative position.
    :param share_biases: Whether to use the same biases for all layers.
    :return: The built model.
    """
    token_input = keras.layers.Input(shape=(target_len,), name='Input-Token')
    memory_length_input = keras.layers.Input(shape=(1,), name='Input-Memory-Length')
    inputs = [token_input, memory_length_input]

    results = AdaptiveEmbedding(
        input_dim=num_token,
        output_dim=units,
        embed_dim=embed_dim,
        cutoffs=cutoffs,
        div_val=div_val,
        mask_zero=True,
        force_projection=force_projection,
        return_embeddings=True,
        return_projections=True,
        name='Embed-Token',
    )(token_input)
    token_embed, embedding_weights = results[0], results[1:]
    token_embed = Scale(scale=np.sqrt(units), name='Embed-Token-Scaled')(token_embed)
    last_memory = Memory(
        batch_size=batch_size,
        memory_len=memory_len,
        target_len=target_len,
        output_dim=units,
        name='Memory-0',
    )([token_embed, memory_length_input])

    position_embed = PositionalEmbedding(
        output_dim=units,
        clamp_len=clamp_len,
        name='Embed-Position',
    )([token_input, last_memory])

    if 0.0 < dropout < 1.0:
        token_embed = keras.layers.Dropout(rate=dropout, name='Embed-Token-Dropped')(token_embed)
        position_embed = keras.layers.Dropout(rate=dropout, name='Embed-Position-Dropped')(position_embed)

    context_bias, relative_bias = None, None
    if share_biases:
        context_bias, relative_bias = RelativeBias(units=units, name='Biases')(last_memory)

    outputs = [token_embed]
    for i in range(num_block):
        block_input, block_output = outputs[-1], outputs[-1]
        if not share_biases:
            context_bias, relative_bias = RelativeBias(units=units, name='Biases-{}'.format(i + 1))(last_memory)
        block_output = RelativePartialMultiHeadSelfAttention(
            units=units,
            num_head=num_head,
            use_bias=False,
            attention_dropout=attention_dropout,
            name='Attention-{}'.format(i + 1),
        )([block_output, position_embed, last_memory, context_bias, relative_bias])
        if 0.0 < dropout < 1.0:
            block_output = keras.layers.Dropout(rate=dropout, name='Attention-Dropped-{}'.format(i + 1))(block_output)
        block_output = keras.layers.Add(name='Attention-Res-{}'.format(i + 1))([block_input, block_output])
        block_output = LayerNormalization(name='Attention-Norm-{}'.format(i + 1))(block_output)

        block_input = block_output
        block_output = FeedForward(
            units=hidden_dim,
            dropout_rate=dropout,
            name='FeedForward-{}'.format(i + 1),
        )(block_output)
        if 0.0 < dropout < 1.0:
            block_output = keras.layers.Dropout(rate=dropout, name='FeedForward-Dropped-{}'.format(i + 1))(block_output)
        block_output = keras.layers.Add(name='FeedForward-Res-{}'.format(i + 1))([block_input, block_output])
        block_output = LayerNormalization(name='FeedForward-Norm-{}'.format(i + 1))(block_output)

        if i < num_block - 1:
            last_memory = Memory(
                batch_size=batch_size,
                memory_len=memory_len,
                target_len=target_len,
                output_dim=units,
                name='Memory-{}'.format(i + 1),
            )([block_output, memory_length_input])

        outputs.append(block_output)

    if 0.0 < dropout < 1.0:
        outputs[-1] = keras.layers.Dropout(rate=dropout, name='Output-Dropped')(outputs[-1])
    softmax = AdaptiveSoftmax(
        input_dim=units,
        output_dim=num_token,
        embed_dim=embed_dim,
        cutoffs=cutoffs,
        div_val=div_val,
        force_projection=force_projection,
        bind_embeddings=bind_embeddings,
        bind_projections=bind_projections,
        name='Softmax',
    )(outputs[-1:] + embedding_weights)

    model = keras.models.Model(inputs=inputs, outputs=softmax)
    return model
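

# Usage sketch for build_transformer_xl (illustrative sizes; `cutoffs` are
# assumed to be increasing and smaller than `num_token`, matching the tests
# above).
model = build_transformer_xl(
    units=64,
    embed_dim=64,
    hidden_dim=128,
    num_token=10000,
    num_block=3,
    num_head=4,
    batch_size=16,
    memory_len=32,
    target_len=16,
    cutoffs=[1000, 5000],
    div_val=2,
)
model.summary()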