def _pre_load_args(args):
    """Merges model configs, predefined hparams and the YAML config file into one dict."""
    cfg_file_args = yaml_load_checking(
        load_from_config_path(
            flatten_string_list(
                getattr(args, flags_core.DEFAULT_CONFIG_FLAG.name))))
    model_dirs = flatten_string_list(
        args.model_dir or cfg_file_args.get("model_dir", None))
    hparams_set = args.hparams_set
    if hparams_set is None:
        hparams_set = cfg_file_args.get("hparams_set", None)
    predefined_parameters = get_hyper_parameters(hparams_set)
    # Lift the model-level keys out of the predefined set; whatever remains
    # is treated as entry-level parameters.
    formatted_parameters = {}
    if "model.class" in predefined_parameters:
        formatted_parameters["model.class"] = predefined_parameters.pop("model.class")
    if "model" in predefined_parameters:
        formatted_parameters["model"] = predefined_parameters.pop("model")
    if "model.params" in predefined_parameters:
        formatted_parameters["model.params"] = predefined_parameters.pop("model.params")
    if len(predefined_parameters) > 0:
        formatted_parameters["entry.params"] = predefined_parameters

    try:
        model_cfgs = ModelConfigs.load(model_dirs[0])
        return deep_merge_dict(
            deep_merge_dict(model_cfgs, formatted_parameters), cfg_file_args)
    except Exception:
        return deep_merge_dict(formatted_parameters, cfg_file_args)
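def _demo_pre_load_args():
    # Illustrative sketch only (not part of the original module). Assuming
    # deep_merge_dict lets its second argument win, the precedence above is:
    # configs restored from the model dir form the base, predefined hparams
    # override them, and values from the YAML config file win last.
    # `SimpleNamespace` stands in for the parsed CLI args; the attribute name
    # `config_paths` is an assumption for flags_core.DEFAULT_CONFIG_FLAG.name.
    from types import SimpleNamespace
    args = SimpleNamespace(config_paths=["configs/train.yml"],
                           model_dir="checkpoints/run1",
                           hparams_set="transformer_base")
    merged = _pre_load_args(args)
    print(merged.get("model.class"), merged.get("model_dir"))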
def test_openai_gpt2():
    from transformers import GPT2Model, GPT2Tokenizer

    input_text = "Here is some text to encode"
    pt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    pt_model = GPT2Model.from_pretrained("gpt2", return_dict=True)
    pt_outputs = pt_model(**pt_tokenizer([input_text], return_tensors="pt"))
    task = build_task({
        "class": "lm",
        "params": {
            "data_pipeline.class": "GPT2DataPipeline",
            "max_len": 50,
            "begin_of_sentence": "eos"
        }
    })
    model_cfgs = get_hyper_parameters("gpt2_117m")
    model = task.build_model(model_cfgs)
    restore_checkpoint_if_possible_v2(model, "117M", model_name="OpenAIGPT2")
    input_ids = task._data_pipeline.process(input_text)
    tf_inputs = {
        "trg_input": tf.convert_to_tensor([input_ids], tf.int64),
        "trg_length": tf.convert_to_tensor([len(input_ids)], tf.int64)
    }
    _, gen_init = model.get_symbols_to_logits_fn(tf_inputs, is_training=False,
                                                 is_inference=False)
    tf_outputs = model.get_decoder_output(gen_init["decoder_input"],
                                          cache=gen_init["decoder_internal_cache"],
                                          is_training=False)
    # The last decoder position is dropped before comparison (the NeurST
    # pipeline presumably appends an extra token, e.g. EOS, that has no
    # HuggingFace counterpart).
    assert_equal_numpy(pt_outputs.last_hidden_state.detach().numpy(),
                       tf_outputs[:, :-1].numpy(), 5e-4)
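# Sketch of the comparison helper used throughout these tests (hypothetical
# stand-in; the real assert_equal_numpy lives in the shared test utilities).
# The parity checks only need element-wise closeness within an absolute
# tolerance, which numpy.testing provides directly:
def _assert_equal_numpy_sketch(a, b, epsilon=1e-9):
    import numpy as np
    np.testing.assert_allclose(a, b, atol=epsilon, rtol=0)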
def new(cls, args, src_meta, trg_meta, name=None):
    """ Builds a sequence to sequence model.

    Args:
        args: A dict containing all model parameters.
        src_meta: A dict containing source-side vocabulary meta data,
            e.g. eos_id, vocab_size.
        trg_meta: A dict containing target-side vocabulary meta data,
            e.g. eos_id, vocab_size.
        name: The name of the model.

    Returns:
        An encoder decoder model.
    """
    # build source and target modality
    src_modality, trg_modality = cls.build_modalities(args, src_meta, trg_meta)
    encoder_params, decoder_params = {}, {}
    for f in cls.class_or_method_args():
        if f.name in args:
            # Strip the "encoder."/"decoder." prefix (8 characters each).
            if f.name.startswith("encoder."):
                encoder_params[f.name[8:]] = args[f.name]
            elif f.name.startswith("decoder."):
                decoder_params[f.name[8:]] = args[f.name]
    # build encoder and decoder
    encoder = None
    if args["bert_mode"] != "bert_as_encoder":
        encoder = build_encoder({
            "encoder.class": "TransformerEncoder",
            "encoder.params": encoder_params})
    decoder = build_decoder({
        "decoder.class": "TransformerDecoder",
        "decoder.params": decoder_params})
    with tf.name_scope(name or "ctnmt"):
        bert_model = Bert.new(
            get_hyper_parameters(args["bert_config"])["model.params"],
            vocab_meta=src_meta, name="bert")
        model = cls(args, bert_model, src_meta, trg_meta, src_modality,
                    trg_modality, encoder, decoder, name=(name or "ctnmt"))
        # Run a forward pass on dummy inputs to force variable creation.
        _ = model({"src": tf.convert_to_tensor([[1, 2, 3]], tf.int64),
                   "src_padding": tf.convert_to_tensor([[0., 0., 0.]], tf.float32),
                   "trg_input": tf.convert_to_tensor([[1, 2, 3]], tf.int64)})
    return model
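# Illustrative usage sketch (not from the original source; the class name
# `CtnmtModel` and the hparams-set name are assumptions). CTNMT plugs a
# pre-trained BERT into the source side, so `args` must carry "bert_mode"
# and "bert_config" on top of the usual encoder./decoder. hyperparameters:
#
#   args = get_hyper_parameters("ctnmt_toy")["model.params"]
#   args.update({"bert_mode": "bert_as_encoder", "bert_config": "bert_base"})
#   model = CtnmtModel.new(args, src_meta, trg_meta, name="ctnmt")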
def cli_main():
    if len(sys.argv) == 1 or (len(sys.argv) == 2
                              and (sys.argv[1] in ["help", "--help", "-h"])):
        print("Usage: ")
        print("  >> python3 -m neurst.cli.view_registry registry_name")
        print("     Show registered classes and their aliases.")
        print()
        print("  >> python3 -m neurst.cli.view_registry registry_name class_name")
        print("     Show detailed parameters of the class.")
        print()
        print("All registry names: ")
        for k in REGISTRIES:
            print(f"  - {k}")
        exit()
    registry_name = sys.argv[1].lower()
    if registry_name not in REGISTRIES:
        print(f"Unknown registry name: {registry_name}")
    elif len(sys.argv) == 2:
        print(f"All registered {registry_name}(s): ")
        clsname2alias = {}
        for name, cls in REGISTRIES[registry_name].items():
            clsname = cls.__name__
            if clsname not in clsname2alias:
                clsname2alias[clsname] = []
            clsname2alias[clsname].append(name)
        if registry_name == "hparams_set":
            for k in clsname2alias:
                print(f"  - {k}")
        else:
            print("  | Class | Aliases |")
            for k, v in clsname2alias.items():
                print("  | {} | {} |".format(k, ", ".join(v)))
    elif len(sys.argv) == 3:
        detail_name = sys.argv[2]
        if registry_name == "hparams_set":
            hparams = get_hyper_parameters(detail_name)
            if len(hparams) == 0:
                print(f"Unknown hparams_set: {detail_name}")
            else:
                print(f"Pre-defined hyperparameters set of `{detail_name}`: ")
                print(json.dumps(hparams, indent=4))
        elif detail_name not in REGISTRIES[registry_name]:
            print(f"Unknown class: {detail_name} under `{registry_name}`")
        else:
            if hasattr(REGISTRIES[registry_name][detail_name], "class_or_method_args"):
                flags = []
                module_flags = []
                for f in REGISTRIES[registry_name][detail_name].class_or_method_args():
                    if isinstance(f, ModuleFlag):
                        module_flags.append(f)
                    else:
                        flags.append(f)
                if len(flags) > 0:
                    print(f"Flags for {detail_name}:")
                    print("  | flag | type | default | help |")
                    for f in flags:
                        print(f"  | {f.name} | {str(f.dtype)} | {f.default} | {f.help} |")
                if len(module_flags) > 0:
                    print(f"Dependent modules for {detail_name}: ")
                    print("  | name | module | help |")
                    for f in module_flags:
                        print(f"  | {f.name} | {f.module_name} | {f.help} |")
            else:
                print(f"No flags defined for `{detail_name}` ({registry_name})")
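# Example invocations (illustrative; the actual registry contents and table
# rows depend on the installed NeurST version):
#
#   $ python3 -m neurst.cli.view_registry model
#   All registered model(s):
#     | Class       | Aliases     |
#     | Transformer | transformer |
#
#   $ python3 -m neurst.cli.view_registry hparams_set transformer_base
#   Pre-defined hyperparameters set of `transformer_base`:
#   { ... }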
def test_st():
    params = copy.deepcopy(
        get_hyper_parameters("speech_transformer_toy")["model.params"])
    params["modality.source.dim"] = None
    params["modality.target.dim"] = None
    params["modality.source.timing"] = None
    params["modality.target.timing"] = None
    params["encoder.num_layers"] = 1
    params["decoder.num_layers"] = 1
    src_vocab_meta = dict(audio_feature_dim=80, audio_feature_channels=1)
    trg_vocab_meta = dict(vocab_size=5, eos_id=4, bos_id=3, unk_id=2)
    fake_audio = numpy.random.rand(1, 11, 80, 1)
    pt_inps = {
        "src": torch.FloatTensor(fake_audio),
        "src_length": torch.LongTensor([11]),
        "trg_input": torch.LongTensor([[3, 0, 1]]),
    }
    tf_inps = {
        "src": tf.convert_to_tensor(fake_audio, tf.float32),
        "src_length": tf.convert_to_tensor([11], tf.int32),
        "trg_input": tf.convert_to_tensor([[3, 0, 1]], tf.int32),
    }
    pt_model: SpeechTransformer = build_model(
        {"model.class": "speech_transformer", "params": params},
        src_meta=src_vocab_meta, trg_meta=trg_vocab_meta)
    tf_model: TFSpeechTransformer = build_tf_model(
        {"model.class": "speech_transformer", "params": params},
        src_meta=src_vocab_meta, trg_meta=trg_vocab_meta)
    # Copy the TF weights into the PyTorch model. Conv2D kernels are stored as
    # (H, W, C_in, C_out) in TF but (C_out, C_in, H, W) in PyTorch, hence the
    # (3, 2, 0, 1) transpose; dense kernels need a plain (in, out) -> (out, in)
    # transpose; LayerNorm gamma/beta map to weight/bias.
    pt_model._src_modality.embedding_layer._conv_layer1.weight.data = torch.FloatTensor(
        tf_model._src_modality.embedding_layer._conv_layers[0].kernel.numpy(
        ).transpose((3, 2, 0, 1)))
    pt_model._src_modality.embedding_layer._conv_layer1.bias.data = torch.FloatTensor(
        tf_model._src_modality.embedding_layer._conv_layers[0].bias.numpy())
    pt_model._src_modality.embedding_layer._conv_layer2.weight.data = torch.FloatTensor(
        tf_model._src_modality.embedding_layer._conv_layers[1].kernel.numpy(
        ).transpose((3, 2, 0, 1)))
    pt_model._src_modality.embedding_layer._conv_layer2.bias.data = torch.FloatTensor(
        tf_model._src_modality.embedding_layer._conv_layers[1].bias.numpy())
    pt_model._src_modality.embedding_layer._norm_layer1.weight.data = torch.FloatTensor(
        tf_model._src_modality.embedding_layer._norm_layers[0].gamma.numpy())
    pt_model._src_modality.embedding_layer._norm_layer1.bias.data = torch.FloatTensor(
        tf_model._src_modality.embedding_layer._norm_layers[0].beta.numpy())
    pt_model._src_modality.embedding_layer._norm_layer2.weight.data = torch.FloatTensor(
        tf_model._src_modality.embedding_layer._norm_layers[1].gamma.numpy())
    pt_model._src_modality.embedding_layer._norm_layer2.bias.data = torch.FloatTensor(
        tf_model._src_modality.embedding_layer._norm_layers[1].beta.numpy())
    pt_model._src_modality.embedding_layer._dense_layer.weight.data = torch.FloatTensor(
        tf_model._src_modality.embedding_layer._dense_layer.kernel.numpy().transpose())
    pt_model._src_modality.embedding_layer._dense_layer.bias.data = torch.FloatTensor(
        tf_model._src_modality.embedding_layer._dense_layer.bias.numpy())
    pt_model._trg_modality.embedding_layer._shared_weights.data = torch.FloatTensor(
        tf_model._trg_modality.embedding_layer._shared_weights.numpy())
    pt_model._trg_modality.embedding_layer._bias.data = torch.FloatTensor(
        tf_model._trg_modality.embedding_layer._bias.numpy())
    pt_model._encoder._output_norm_layer.weight.data = torch.FloatTensor(
        tf_model._encoder._output_norm_layer.gamma.numpy())
    pt_model._encoder._output_norm_layer.bias.data = torch.FloatTensor(
        tf_model._encoder._output_norm_layer.beta.numpy())
    pt_model._encoder._stacking_layers[0][
        0]._layer._qkv_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]._selfatt_layer._layer.
            _qkv_transform_layer._kernel.numpy())
    pt_model._encoder._stacking_layers[0][
        0]._layer._qkv_transform_layer._bias.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]._selfatt_layer._layer.
            _qkv_transform_layer._bias.numpy())
    pt_model._encoder._stacking_layers[0][
        0]._layer._output_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]._selfatt_layer._layer.
            _output_transform_layer._kernel.numpy())
    pt_model._encoder._stacking_layers[0][
        0]._layer._output_transform_layer._bias.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]._selfatt_layer._layer.
            _output_transform_layer._bias.numpy())
    pt_model._encoder._stacking_layers[0][
        1]._layer._dense1.weight.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]._ffn_layer._layer._conv1.
            kernel.numpy().transpose([1, 0]))
    pt_model._encoder._stacking_layers[0][
        1]._layer._dense1.bias.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]._ffn_layer._layer._conv1.
            bias.numpy())
    pt_model._encoder._stacking_layers[0][
        1]._layer._dense2.weight.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]._ffn_layer._layer._conv2.
            kernel.numpy().transpose([1, 0]))
    pt_model._encoder._stacking_layers[0][
        1]._layer._dense2.bias.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]._ffn_layer._layer._conv2.
            bias.numpy())
    pt_model._encoder._stacking_layers[0][
        0]._norm_layer.weight.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]._selfatt_layer._norm_layer.
            gamma.numpy())
    pt_model._encoder._stacking_layers[0][
        0]._norm_layer.bias.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]._selfatt_layer._norm_layer.
            beta.numpy())
    pt_model._encoder._stacking_layers[0][
        1]._norm_layer.weight.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]._ffn_layer._norm_layer.gamma.
            numpy())
    pt_model._encoder._stacking_layers[0][
        1]._norm_layer.bias.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]._ffn_layer._norm_layer.beta.
            numpy())
    pt_model._decoder._output_norm_layer.weight.data = torch.FloatTensor(
        tf_model._decoder._output_norm_layer.gamma.numpy())
    pt_model._decoder._output_norm_layer.bias.data = torch.FloatTensor(
        tf_model._decoder._output_norm_layer.beta.numpy())
    pt_model._decoder._stacking_layers[0][
        0]._layer._qkv_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._selfatt_layer._layer.
            _qkv_transform_layer._kernel.numpy())
    pt_model._decoder._stacking_layers[0][
        0]._layer._qkv_transform_layer._bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._selfatt_layer._layer.
            _qkv_transform_layer._bias.numpy())
    pt_model._decoder._stacking_layers[0][
        0]._layer._output_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._selfatt_layer._layer.
            _output_transform_layer._kernel.numpy())
    pt_model._decoder._stacking_layers[0][
        0]._layer._output_transform_layer._bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._selfatt_layer._layer.
            _output_transform_layer._bias.numpy())
    pt_model._decoder._stacking_layers[0][
        0]._norm_layer.weight.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._selfatt_layer._norm_layer.
            gamma.numpy())
    pt_model._decoder._stacking_layers[0][
        0]._norm_layer.bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._selfatt_layer._norm_layer.
            beta.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._layer._q_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._crossatt_layer._layer.
            _q_transform_layer._kernel.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._layer._q_transform_layer._bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._crossatt_layer._layer.
            _q_transform_layer._bias.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._layer._kv_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._crossatt_layer._layer.
            _kv_transform_layer._kernel.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._layer._kv_transform_layer._bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._crossatt_layer._layer.
            _kv_transform_layer._bias.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._layer._output_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._crossatt_layer._layer.
            _output_transform_layer._kernel.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._layer._output_transform_layer._bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._crossatt_layer._layer.
            _output_transform_layer._bias.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._norm_layer.weight.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._crossatt_layer._norm_layer.
            gamma.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._norm_layer.bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._crossatt_layer._norm_layer.
            beta.numpy())
    pt_model._decoder._stacking_layers[0][
        2]._layer._dense1.weight.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._ffn_layer._layer._conv1.
            kernel.numpy().transpose([1, 0]))
    pt_model._decoder._stacking_layers[0][
        2]._layer._dense1.bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._ffn_layer._layer._conv1.
            bias.numpy())
    pt_model._decoder._stacking_layers[0][
        2]._layer._dense2.weight.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._ffn_layer._layer._conv2.
            kernel.numpy().transpose([1, 0]))
    pt_model._decoder._stacking_layers[0][
        2]._layer._dense2.bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._ffn_layer._layer._conv2.
            bias.numpy())
    pt_model._decoder._stacking_layers[0][
        2]._norm_layer.weight.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._ffn_layer._norm_layer.gamma.
            numpy())
    pt_model._decoder._stacking_layers[0][
        2]._norm_layer.bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]._ffn_layer._norm_layer.beta.
            numpy())
    assert_equal_numpy(
        tf_model(tf_inps, is_training=False).numpy(),
        pt_model(pt_inps, is_training=False).detach().numpy(), 5e-6)
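# The block above repeats the same three conversion rules many times over.
# As a sketch (hypothetical helpers, not part of the original test), the
# rules reduce to:
def _tf_dense_to_pt(pt_linear, tf_dense):
    # TF Dense stores its kernel as (in, out); torch.nn.Linear as (out, in).
    pt_linear.weight.data = torch.FloatTensor(tf_dense.kernel.numpy().T)
    pt_linear.bias.data = torch.FloatTensor(tf_dense.bias.numpy())

def _tf_conv2d_to_pt(pt_conv, tf_conv):
    # TF Conv2D: (H, W, C_in, C_out); torch.nn.Conv2d: (C_out, C_in, H, W).
    pt_conv.weight.data = torch.FloatTensor(
        tf_conv.kernel.numpy().transpose((3, 2, 0, 1)))
    pt_conv.bias.data = torch.FloatTensor(tf_conv.bias.numpy())

def _tf_layernorm_to_pt(pt_norm, tf_norm):
    # TF LayerNormalization gamma/beta map to torch LayerNorm weight/bias.
    pt_norm.weight.data = torch.FloatTensor(tf_norm.gamma.numpy())
    pt_norm.bias.data = torch.FloatTensor(tf_norm.beta.numpy())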
def test_seq2seq():
    params = copy.deepcopy(get_hyper_parameters("transformer_toy")["model.params"])
    params["modality.source.dim"] = None
    params["modality.target.dim"] = None
    params["modality.source.timing"] = None
    params["modality.target.timing"] = None
    src_vocab_meta = dict(vocab_size=8, eos_id=7, bos_id=6, unk_id=5)
    trg_vocab_meta = dict(vocab_size=5, eos_id=4, bos_id=3, unk_id=2)
    parsed_inputs = {
        "src": tf.convert_to_tensor([[0, 1, 1, 7], [1, 7, 7, 7]], tf.int64),
        "src_padding": tf.convert_to_tensor([[0, 0, 0, 0.], [0, 0, 1, 1.]], tf.float32),
        "trg_input": tf.convert_to_tensor([[3, 0, 1], [3, 2, 4]], tf.int32),
        "trg": tf.convert_to_tensor([[0, 1, 4], [2, 4, 4]], tf.int32),
        "trg_padding": tf.convert_to_tensor([[0, 0, 0.], [0, 0, 1.]], tf.float32),
    }
    model = build_model({"model.class": "transformer", "params": params},
                        src_meta=src_vocab_meta, trg_meta=trg_vocab_meta)
    _ = model(parsed_inputs, is_training=False)
    for w in model.trainable_weights:
        if "target_symbol_modality/shared/weights" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.29354253, -0.23483634, 0.25630027, -0.02696097, -0.5017841,
                  -0.01427859, 0.64076746, 0.10676116],
                 [-0.19711176, -0.20760003, -0.48422408, -0.0074994, -0.31429327,
                  0.00126553, -0.17251879, 0.29386985],
                 [0.38033593, -0.27076742, 0.2611575, 0.66763735, 0.5333196,
                  -0.52800345, -0.5451049, 0.5960151],
                 [-0.38007882, 0.47841036, 0.11322564, 0.3999585, -0.5566431,
                  -0.6169907, 0.5290351, -0.48975855],
                 [0.24198133, -0.1712935, -0.13487989, 0.03922045, -0.27576318,
                  0.15308863, 0.18018633, -0.49891895]]))
        elif "target_symbol_modality/shared/bias" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [-0.12844944, 0.70201373, 0.47467923, 0.17776501, -0.57099354]))
        elif "input_symbol_modality/emb/weights" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.28932106, 0.04174006, 0.32917994, -0.01771283, -0.32744384,
                  0.4569562, -0.4678616, 0.00129563],
                 [-0.4225411, -0.59086347, -0.0714885, 0.51049083, -0.5401395,
                  0.3862279, -0.53301275, 0.30440414],
                 [-0.19314134, 0.09168714, -0.5058322, -0.42353332, 0.5074443,
                  0.03560042, 0.26724458, 0.33088684],
                 [-0.5153856, -0.38528442, -0.20011288, 0.4713922, 0.13764167,
                  -0.18305543, -0.43612635, 0.5469119],
                 [-0.54713076, 0.32743508, 0.38312858, -0.5525645, 0.591134,
                  0.1707223, 0.15555906, -0.42832434],
                 [-0.5138424, -0.21375301, -0.46360433, -0.6103692, -0.50063866,
                  0.24583805, -0.5414497, -0.01820809],
                 [0.3424672, -0.38758308, 0.05292654, 0.10646945, -0.09475929,
                  0.5051289, 0.16801137, 0.03101033],
                 [-0.10960919, 0.20824891, -0.02183038, -0.06829894, 0.48780817,
                  -0.18522224, 0.22240955, -0.21551234]]))
        elif ("TransformerEncoder/layer_0/self_attention_prepost_wrapper/"
              "self_attention/output_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.31903958, 0.41097552, 0.35810417, 0.4822548, 0.5416022,
                  0.02170408, 0.32241964, -0.54333895],
                 [0.5172518, 0.14113712, 0.44610864, -0.43546906, 0.49923056,
                  0.23127198, 0.310534, 0.3501947],
                 [0.5763511, -0.4778806, 0.3984726, 0.13659805, -0.05111057,
                  0.4764889, 0.05881822, -0.37829816],
                 [-0.33052838, -0.3291011, -0.59498054, 0.2654276, -0.5715602,
                  0.01546502, 0.04336095, 0.13782066],
                 [-0.32840976, -0.37728345, -0.49385822, -0.49648887, 0.4832974,
                  0.07143259, -0.17042065, 0.43592864],
                 [0.31292784, 0.01520997, 0.40785295, -0.12775904, 0.03555053,
                  -0.35662168, -0.5096859, 0.33710766],
                 [-0.36864457, 0.30672514, -0.4093505, -0.4461822, -0.41201153,
                  0.12536913, -0.3134546, -0.110695],
                 [0.50774044, 0.25777447, -0.18048626, -0.30132556, 0.3435768,
                  0.49845392, -0.21432358, -0.05989999]]))
        elif ("TransformerEncoder/layer_0/self_attention_prepost_wrapper"
              "/self_attention/qkv_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.24556783, -0.10109329, -0.18614727, -0.35749245, 0.07600775,
                  -0.30707863, 0.11381295, -0.21648653, -0.32361317, 0.04083973,
                  0.00325903, 0.17453268, -0.38458756, -0.12808836, -0.30286443,
                  -0.28138128, 0.3906658, 0.2981322, 0.1857591, -0.10963717,
                  0.13652292, -0.42696893, -0.32537884, -0.17609134],
                 [0.00684109, 0.40689567, 0.22115704, -0.22863819, -0.22739726,
                  0.3783851, -0.37274942, -0.21842214, -0.22557294, -0.07110339,
                  0.3998916, -0.0190008, 0.27676454, -0.19919433, 0.2616723,
                  -0.41782314, -0.2811813, -0.3239204, 0.13037983, 0.10246852,
                  -0.14516768, -0.13455674, -0.20624177, 0.30381766],
                 [-0.36161476, 0.3910825, 0.11459449, -0.19012608, -0.1930628,
                  -0.09042051, 0.04295725, -0.09732714, -0.27065122, -0.1735073,
                  -0.11896703, -0.2472982, -0.24865237, 0.0597097, -0.23580097,
                  -0.402398, -0.04311767, -0.14832097, 0.25989994, -0.03256506,
                  -0.3376931, 0.35324004, 0.01395121, -0.28511477],
                 [0.33902344, -0.16730174, 0.2059339, -0.0727739, -0.24657604,
                  0.01062217, -0.21674432, 0.11485538, 0.23314235, -0.30125052,
                  0.32238856, -0.2450316, 0.03718695, -0.276408, 0.23392966,
                  -0.07773718, 0.3429754, -0.19731745, 0.37889633, 0.34160677,
                  0.05413216, 0.03037485, -0.3704696, 0.28774682],
                 [-0.41983247, 0.1209394, -0.03301042, 0.20576969, -0.28212637,
                  -0.25600716, -0.09135348, -0.19963133, -0.1577549, -0.13313296,
                  -0.02467829, 0.39583513, -0.21820472, 0.10990372, -0.42987105,
                  -0.3018305, -0.33682942, -0.04609847, -0.0978007, -0.35909522,
                  0.35906085, -0.38199574, -0.02560577, 0.4065493],
                 [-0.39747363, -0.21786559, 0.4050602, 0.29975984, -0.03308517,
                  -0.05114299, 0.23231843, -0.42908302, -0.09869319, -0.3929163,
                  0.14195767, -0.04656759, 0.2699246, 0.1801227, 0.14472279,
                  -0.4127182, -0.4004244, -0.10136119, 0.4069151, 0.3895177,
                  -0.15835935, -0.13569432, -0.38402212, -0.16429195],
                 [-0.1027582, 0.02577147, 0.39300737, -0.10241205, -0.4256417,
                  0.33153847, -0.0325374, -0.13393977, 0.05391803, -0.20058648,
                  -0.25471783, 0.08702543, -0.09722248, 0.02570912, -0.279415,
                  0.04044545, -0.27716812, 0.19806209, 0.22688219, -0.30685633,
                  0.00624642, 0.14048973, -0.2722684, 0.39918897],
                 [-0.19335268, 0.38261148, 0.30058286, 0.25313148, 0.27221575,
                  0.37937936, 0.1745182, 0.14772478, -0.27204615, 0.38106957,
                  0.36370513, 0.16695651, -0.40864846, -0.14278689, 0.34316894,
                  0.41350552, -0.42566204, -0.22474506, -0.18263665, 0.11183658,
                  -0.12859318, 0.02102521, -0.1425604, 0.11403349]]))
        elif "TransformerEncoder/layer_0/ffn_prepost_wrapper/ffn/dense1/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.38400275, 0.11049551, 0.19255298, 0.45194864, -0.02915239,
                  0.31835914, -0.3630433, 0.11081731, -0.02559841, 0.38685995],
                 [0.42969477, 0.2031151, 0.5144137, -0.07936049, 0.31766498,
                  0.5058452, 0.44898677, 0.16335446, 0.3953011, 0.4361714],
                 [0.04883695, -0.56701475, 0.09635973, -0.50472724, -0.1245037,
                  -0.37787604, -0.21818402, 0.16247958, -0.14578387, -0.41005552],
                 [0.13449967, 0.05132979, -0.5468524, -0.17919052, 0.01128888,
                  0.09902984, 0.23214585, -0.08920336, 0.55008626, 0.50717974],
                 [-0.1738911, -0.24616602, 0.18358463, -0.11349753, 0.15567136,
                  -0.45293823, 0.29155105, 0.49324703, 0.01795202, 0.255095],
                 [-0.23427847, -0.47127584, 0.47553408, 0.17752594, -0.4635463,
                  -0.05620468, -0.5232727, 0.39365137, -0.38289946, 0.05879569],
                 [0.25051618, 0.26999742, -0.24446961, 0.03792298, 0.01752973,
                  -0.41537094, 0.44205165, -0.11403576, -0.3807313, -0.23905703],
                 [-0.33319134, -0.47972375, 0.526567, 0.34260195, -0.01981884,
                  -0.02918285, -0.02829635, -0.5294999, 0.563005, 0.05829275]]))
        elif "TransformerEncoder/layer_0/ffn_prepost_wrapper/ffn/dense2/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.2340402, -0.10299325, 0.03826767, -0.00556576, 0.16777557,
                  -0.48395926, -0.21232244, 0.540642],
                 [-0.5568968, -0.24176422, 0.17467064, 0.3885694, 0.4655552,
                  -0.15393665, -0.4475953, -0.3920542],
                 [0.07647067, 0.2340278, -0.13460535, -0.34944105, 0.0448994,
                  0.35044646, -0.5451377, -0.39633614],
                 [0.16932797, 0.4503368, -0.48202705, -0.05000919, -0.3586144,
                  0.07879007, -0.47378975, -0.5153118],
                 [-0.4939471, -0.49206224, 0.33845508, -0.5155843, -0.07823312,
                  0.30778152, -0.14456016, -0.49705222],
                 [0.23529834, 0.39454746, -0.3392254, -0.31639364, 0.39075094,
                  0.55396605, 0.03435838, 0.3698709],
                 [-0.01985615, -0.14796564, -0.04773241, 0.1197027, 0.02213496,
                  0.24299401, 0.23960501, 0.45019186],
                 [-0.1280163, -0.11015153, 0.19618726, -0.55472195, -0.45635638,
                  -0.15839794, 0.28029287, 0.00874251],
                 [-0.18816125, -0.16009945, -0.14088362, 0.41544813, -0.20673174,
                  0.01065433, 0.03431308, -0.17323837],
                 [-0.30255532, 0.5155908, 0.23801541, 0.46748185, -0.42719585,
                  -0.49111396, 0.3950773, -0.27734205]]))
        elif ("TransformerEncoder/layer_1/self_attention_prepost_wrapper/"
              "self_attention/output_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.42618555, -0.09034979, -0.23231441, -0.43777925, 0.45706886,
                  -0.59829664, 0.4076385, 0.23851973],
                 [0.05634236, 0.17002487, -0.08434552, 0.31617373, 0.03625625,
                  0.5910465, -0.6076178, -0.2687951],
                 [-0.14819229, -0.27034125, 0.2064324, -0.19751346, 0.21064728,
                  0.29283345, 0.23406833, 0.10519284],
                 [0.31500018, -0.4173568, -0.00893188, -0.26349744, 0.15418595,
                  -0.399687, -0.22666007, -0.6096985],
                 [-0.1316917, -0.36008307, -0.43647486, 0.10060841, -0.16681895,
                  -0.35083786, 0.26369733, -0.12640283],
                 [0.5797457, -0.59191436, -0.57749504, -0.54847366, -0.20692074,
                  0.4509862, -0.01773721, 0.1577],
                 [0.4081785, 0.5246411, -0.5135473, -0.23788959, -0.26497075,
                  -0.23121881, 0.35329401, 0.42074102],
                 [-0.46347424, 0.56120163, -0.2939334, 0.2747522, 0.56474787,
                  0.5690356, 0.19718772, -0.09090984]]))
        elif ("TransformerEncoder/layer_1/self_attention_prepost_wrapper/"
              "self_attention/qkv_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.04137263, 0.4122521, 0.07474831, -0.42290825, 0.01918331,
                  -0.0367808, 0.20840707, -0.19495474, -0.36590886, 0.12961635,
                  -0.42065755, 0.21793994, 0.15142605, 0.05064282, 0.3728448,
                  0.4305556, -0.19640265, -0.13260049, 0.41600618, -0.30270132,
                  0.28347465, -0.2972833, -0.22339822, -0.4168277],
                 [-0.42739302, 0.0618836, 0.30369553, -0.01105291, -0.2725063,
                  0.18827173, -0.07787129, 0.29560563, 0.11015823, -0.2556733,
                  0.3800684, 0.20649257, -0.03591421, 0.35618058, -0.39821273,
                  0.0430806, -0.37791556, -0.05824929, 0.29839876, 0.06364432,
                  -0.28479278, 0.37887844, -0.19407392, -0.24432379],
                 [-0.2754909, 0.21458694, 0.2540948, -0.06881586, 0.2752199,
                  -0.42529625, -0.18034342, -0.2641306, 0.08662507, -0.19239433,
                  -0.01936874, -0.42879313, 0.2515919, 0.05828688, -0.35050425,
                  0.19613442, 0.10595468, -0.06380415, 0.14495179, -0.26701403,
                  0.33381835, 0.11836699, 0.10901466, -0.19060831],
                 [-0.08439368, -0.1435681, -0.38354927, 0.29710206, 0.39372167,
                  0.29005793, 0.22486511, 0.10090873, -0.27392572, 0.12495866,
                  -0.38597837, 0.37385282, -0.15801638, 0.34403047, 0.05333185,
                  -0.19141418, -0.43146238, -0.09826642, 0.39207748, 0.02903318,
                  -0.0447951, -0.140995, 0.12605539, -0.27343658],
                 [-0.14746845, 0.26028237, -0.14068425, -0.02098277, -0.34208745,
                  -0.36879313, 0.3709258, -0.18287906, -0.38343272, 0.01450509,
                  0.33475187, 0.19835839, -0.02770916, -0.19535396, 0.24291894,
                  0.40508488, 0.1228393, 0.35743287, -0.31064862, -0.2738737,
                  -0.08634344, 0.17820784, 0.2404854, -0.21379128],
                 [0.32416382, 0.23761937, -0.2714734, 0.01659575, 0.12218228,
                  0.08210799, 0.39640966, 0.04924238, -0.10259542, -0.42907375,
                  -0.0455032, -0.04837993, -0.25596887, -0.16206014, -0.40621698,
                  0.10435715, 0.2919118, -0.3757009, 0.12669042, -0.06276929,
                  0.08691922, 0.01388359, 0.2609237, 0.14391366],
                 [-0.37109214, 0.08338836, 0.41613457, 0.09220138, 0.14755598,
                  -0.3846822, -0.32047546, -0.11989969, 0.04941088, 0.3733643,
                  -0.22359593, 0.01040426, -0.13329476, 0.03873777, 0.25831434,
                  0.04679212, -0.34217292, -0.23983024, 0.36969563, 0.35033616,
                  0.05077001, 0.32096437, 0.2942368, -0.06438693],
                 [0.04559416, 0.3110021, 0.10469446, -0.09112707, -0.21549596,
                  -0.08703595, 0.19566664, -0.27119064, -0.31012705, -0.3460493,
                  0.20034257, 0.34390983, -0.30513322, 0.30294558, 0.15193626,
                  -0.13466576, -0.15653265, -0.04085603, -0.04187199, -0.3818181,
                  0.35413423, -0.11948714, 0.12659273, 0.33491793]]))
        elif "TransformerEncoder/layer_1/ffn_prepost_wrapper/ffn/dense1/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.16969907, 0.538725, -0.47220635, -0.39862955, 0.5590445,
                  -0.57381415, 0.55189013, -0.1241096, -0.1750552, 0.07282209],
                 [-0.04967839, -0.29894733, 0.48699057, -0.26354527, -0.11624891,
                  0.00518572, 0.06982511, 0.21453673, 0.52487314, 0.50849414],
                 [-0.29642364, -0.1552884, 0.37976956, -0.09915912, 0.21726537,
                  0.09865189, -0.3579256, 0.2882828, -0.5435448, 0.34120053],
                 [-0.16734263, -0.30591854, -0.48299694, 0.36032963, 0.3083346,
                  0.32025862, -0.0323239, -0.03540909, 0.19812691, 0.56041396],
                 [0.08146846, -0.4032659, 0.43548548, -0.505157, 0.29625255,
                  0.20229155, -0.2784496, -0.16810659, 0.00465661, -0.46176454],
                 [0.25855982, -0.44527876, -0.05630809, 0.44814825, 0.4672327,
                  0.07238638, 0.23067313, -0.31218028, 0.5251508, -0.46993703],
                 [0.36020505, 0.48421, 0.04297256, 0.07937276, 0.39654619,
                  0.08334208, -0.44477332, 0.15238297, -0.14505252, 0.5653666],
                 [0.17023551, 0.05648631, -0.5590816, -0.4013535, 0.00587964,
                  -0.41224653, -0.5178517, -0.44671488, -0.13213646, -0.16264695]]))
        elif "TransformerEncoder/layer_1/ffn_prepost_wrapper/ffn/dense2/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.08363676, 0.443043, -0.20048293, 0.5397774, -0.08774236,
                  0.51563346, 0.44048393, 0.05069989],
                 [-0.39923793, 0.27010256, 0.3120396, 0.15755522, 0.09888685,
                  0.09209388, 0.23463911, -0.20073885],
                 [0.39725387, 0.3083284, 0.04398292, -0.5214203, 0.1661511,
                  0.32843602, 0.535144, -0.30733716],
                 [-0.52302945, 0.09949869, -0.20001906, -0.4563232, 0.10634673,
                  -0.0867821, 0.2130729, 0.15544009],
                 [-0.16209882, 0.47079623, -0.36366975, -0.39391387, -0.13728681,
                  0.36896384, -0.1279692, -0.24792987],
                 [0.4540763, 0.43117046, 0.34526706, -0.44267043, -0.2801833,
                  0.09091371, 0.31143135, -0.46842438],
                 [-0.3841617, 0.3537798, -0.456631, -0.07963607, 0.18825197,
                  0.34253138, 0.00311643, -0.39619297],
                 [0.19681883, 0.02538323, 0.49230504, -0.54670614, -0.16814995,
                  0.26320857, -0.2583875, -0.45845556],
                 [0.10035574, -0.33199033, -0.06377029, -0.38322705, 0.18576187,
                  0.30481344, 0.30165493, -0.56413436],
                 [0.13095653, 0.5693759, -0.34928244, -0.00579017, 0.45523894,
                  0.45559692, -0.4755445, -0.5578483]]))
        elif ("TransformerDecoder/layer_0/self_attention_prepost_wrapper/"
              "self_attention/output_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.41402858, -0.2655511, 0.21687216, -0.05976683, -0.24678236,
                  -0.55986947, -0.10050869, 0.36443913],
                 [-0.31218863, -0.08026814, -0.3503775, -0.2830528, 0.19764078,
                  0.07665694, -0.22002375, 0.58338326],
                 [0.36593944, 0.47826117, -0.3155697, 0.22407556, -0.2367759,
                  0.5582003, -0.01308447, 0.02416301],
                 [-0.5932773, 0.54228276, 0.07887, -0.36850107, -0.57571995,
                  0.52597564, -0.12966257, -0.06494093],
                 [-0.5416004, -0.4324838, 0.5738513, 0.23318034, -0.5079873,
                  0.44698435, 0.1884408, -0.4100449],
                 [-0.41715717, -0.47995192, 0.27436692, 0.45396346, -0.32279193,
                  -0.52322745, -0.22139937, 0.46218258],
                 [0.04606843, -0.48210734, -0.09731799, 0.1566211, 0.3348605,
                  0.53798, 0.2066397, 0.17096424],
                 [0.5118193, -0.26824263, 0.0513528, -0.22810039, -0.02520913,
                  -0.25055912, -0.21125275, 0.01200509]]))
        elif ("TransformerDecoder/layer_0/self_attention_prepost_wrapper/"
              "self_attention/qkv_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.24635717, -0.35896713, 0.39586702, -0.03602478, 0.27512792,
                  0.23269245, 0.29596278, -0.13523233, 0.3122929, 0.01758271,
                  0.19535479, 0.42010358, 0.3058509, -0.27858323, -0.09621406,
                  -0.28900337, -0.13637415, 0.2554522, -0.13693246, 0.23890129,
                  0.22502461, -0.00342193, -0.37178487, 0.04001474],
                 [-0.06197342, 0.28338936, 0.10876206, 0.21770415, -0.2445885,
                  -0.37382, 0.24960616, -0.28366768, 0.33277413, 0.24190459,
                  0.28501043, 0.2390792, -0.21722354, -0.09839588, -0.07514569,
                  0.08434585, -0.17455393, -0.39285085, 0.3604456, -0.04403484,
                  0.17325982, 0.266789, 0.27641353, 0.2629675],
                 [0.31777444, -0.18994613, 0.07876977, 0.19285682, -0.3603885,
                  -0.07359949, 0.39663008, 0.12972179, 0.32373634, -0.28222823,
                  0.07523808, 0.06840143, 0.2784874, -0.32616594, -0.37903282,
                  0.11678198, -0.2441357, -0.15710688, -0.00175741, -0.40035915,
                  -0.09226942, 0.08680966, 0.25157234, 0.00786397],
                 [-0.06718335, -0.21293627, 0.23377934, -0.07398105, -0.04577821,
                  0.4012753, -0.36116257, 0.27832034, 0.20620236, -0.15069339,
                  0.16214707, -0.42465132, 0.25478825, -0.08184978, 0.35768852,
                  -0.12693104, -0.1273953, -0.3078432, 0.33522883, 0.34014687,
                  -0.08295268, -0.36013618, -0.08690733, -0.07324457],
                 [-0.0609462, 0.06251469, -0.04659629, 0.3167083, -0.02005619,
                  0.32234064, 0.35482922, -0.0772118, 0.3867505, 0.3833268,
                  -0.2319926, -0.417385, -0.38126078, 0.37261078, 0.0596388,
                  0.09162065, -0.23212992, -0.25532508, -0.3144799, 0.28181675,
                  0.01341996, 0.19811288, -0.21834192, -0.39427295],
                 [-0.13712531, 0.2572454, 0.2866812, 0.10211042, 0.06285053,
                  -0.3894317, -0.04404226, -0.39091605, -0.16874191, 0.08648756,
                  -0.30481267, 0.16437915, -0.23644, 0.07409009, -0.39548072,
                  0.35895494, 0.03730175, 0.4324384, -0.2938407, 0.38754657,
                  -0.3012539, -0.11363283, -0.28678095, -0.1598432],
                 [0.00581551, 0.14337441, -0.04939786, 0.11189356, 0.31094417,
                  0.01152644, 0.27642164, -0.09637818, -0.09211436, -0.16248363,
                  0.39744857, -0.4116622, -0.05383742, 0.36805126, 0.14875862,
                  0.1099014, 0.371321, -0.41085994, -0.18536153, 0.20604655,
                  -0.13384223, -0.14118773, -0.1283133, -0.39778396],
                 [-0.01566258, -0.4047187, -0.37664068, -0.19478449, 0.09347895,
                  -0.36023095, 0.21561489, -0.33089578, -0.2711009, -0.03610542,
                  -0.3796572, 0.306676, 0.27266768, 0.22641936, -0.30573982,
                  -0.18740533, -0.34311372, -0.22143514, -0.41552392, 0.42686227,
                  -0.1086936, 0.03383243, -0.15354112, -0.26625448]]))
        elif ("TransformerDecoder/layer_0/encdec_attention_prepost_wrapper/"
              "encdec_attention/output_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.1279256, -0.2419937, -0.5854874, 0.57889825, -0.5364065,
                  0.23631936, -0.49949092, -0.30174196],
                 [0.00957078, -0.49736997, -0.4237002, 0.0218854, -0.17279565,
                  -0.5768471, -0.18963015, 0.10355526],
                 [0.11799914, -0.292151, -0.36201292, -0.266887, 0.15741825,
                  -0.11333472, -0.03553617, 0.0177772],
                 [-0.39861536, 0.17891657, -0.22581154, 0.07609612, -0.34631196,
                  0.26317436, 0.41848058, 0.27004486],
                 [-0.37255478, -0.20311174, 0.5176136, -0.54658747, 0.23746693,
                  -0.03754926, 0.04889613, -0.41350323],
                 [0.2125783, -0.536155, -0.19549471, 0.36943835, 0.24639928,
                  0.07458866, 0.28700095, -0.36578485],
                 [-0.2657523, -0.2433975, -0.56110847, -0.2861476, -0.19445652,
                  0.21033949, -0.30730212, 0.40339154],
                 [0.31910568, 0.0055629, 0.03742898, -0.5246967, 0.35341913,
                  0.3554458, 0.5315719, 0.13093019]]))
        elif ("TransformerDecoder/layer_0/encdec_attention_prepost_wrapper/"
              "encdec_attention/q_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.51545686, 0.0990485, 0.29777205, -0.28110617, -0.26308733,
                  0.2853282, -0.31212774, 0.30727994],
                 [0.5417524, 0.12922692, 0.3285774, -0.02031326, 0.08855647,
                  -0.00454164, 0.02288318, 0.39679402],
                 [-0.09431475, -0.2857204, -0.29803967, 0.28193474, 0.26423824,
                  -0.31383288, -0.25300246, -0.01376557],
                 [0.12011659, 0.55608934, -0.01549584, -0.48516896, -0.44164532,
                  -0.16531923, 0.44081384, -0.54160094],
                 [-0.3235532, 0.55393785, 0.2136209, 0.08658487, 0.02760661,
                  -0.24593821, 0.23313332, -0.03452164],
                 [-0.3659288, -0.55161166, -0.5393511, -0.08154327, 0.47045785,
                  -0.2545886, 0.603108, 0.17091894],
                 [-0.41575676, -0.24764174, 0.33940715, -0.49895483, 0.14083397,
                  0.05251276, 0.09940594, 0.30034548],
                 [-0.5737393, -0.45933425, -0.02393657, -0.12469256, -0.24861848,
                  0.48773366, -0.38281965, 0.06820959]]))
        elif ("TransformerDecoder/layer_0/encdec_attention_prepost_wrapper/"
              "encdec_attention/kv_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.3608706, -0.16985774, 0.04648876, 0.17727554, -0.32050753,
                  0.15797412, 0.32923543, -0.19890809, -0.09514797, 0.09165347,
                  -0.08939207, 0.1240828, 0.12936771, -0.48354328, 0.09154546,
                  0.06640613],
                 [0.26706707, -0.07982218, -0.28840077, -0.15964293, 0.44048142,
                  0.10202003, -0.19224763, 0.4643935, -0.49145675, 0.28452814,
                  -0.28381097, -0.1886301, 0.3626212, 0.48149836, -0.40126383,
                  0.01182055],
                 [0.48325312, 0.13339198, 0.08147466, 0.01886415, 0.410465,
                  -0.24456823, -0.04810286, 0.3934772, -0.42655325, -0.12829137,
                  0.47660065, -0.3516115, -0.11145651, -0.02882326, -0.38462532,
                  0.16618061],
                 [0.28752756, -0.09809136, -0.06697667, -0.22326052, 0.33962095,
                  -0.06639445, -0.06673455, 0.03969002, 0.03658247, 0.2047621,
                  0.41957307, -0.27317607, -0.1286192, -0.1504153, -0.08790445,
                  -0.27503848],
                 [0.40700352, -0.13340664, 0.48895872, 0.2091173, -0.4158994,
                  0.42262292, 0.45204484, 0.31661832, -0.16831684, -0.43958127,
                  0.40800595, 0.4231466, 0.2662462, 0.4360491, -0.05090606,
                  0.41579437],
                 [-0.1475159, 0.05631268, 0.43667984, 0.22322762, 0.24188244,
                  -0.2558658, 0.05513358, -0.44220436, 0.47696745, 0.30288208,
                  0.35236907, -0.46022415, -0.2354449, -0.2824862, 0.1728853,
                  0.00242376],
                 [-0.19901407, -0.17316806, 0.34936786, 0.05637395, -0.08862174,
                  0.15412652, 0.14734995, -0.02360725, 0.20836592, 0.10715961,
                  0.21128082, -0.01028705, 0.27915657, 0.00645471, 0.34993672,
                  0.46311176],
                 [0.40358865, -0.12622762, 0.11518359, 0.18501854, 0.01984668,
                  0.45133805, 0.1628021, -0.17971015, -0.16342247, -0.22245312,
                  -0.26478374, 0.160591, 0.4486302, -0.19825566, 0.04753971,
                  0.12643707]]))
        elif "TransformerDecoder/layer_0/ffn_prepost_wrapper/ffn/dense1/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.06720757, 0.55263114, -0.37820417, -0.18817183, 0.4967841,
                  0.5301496, 0.44765162, 0.17229474, 0.02037746, -0.38267606],
                 [0.22507912, 0.08319503, -0.42931908, 0.21395624, 0.4883101,
                  0.02807504, -0.10768619, -0.47498938, 0.04546309, 0.51695967],
                 [-0.32582825, -0.15555033, -0.35707173, -0.00528497, 0.11157733,
                  -0.4079039, -0.20309281, -0.2786939, -0.00143158, -0.45975608],
                 [0.0592798, -0.297385, 0.35483736, 0.2347272, -0.3477485,
                  0.26017946, -0.17936438, 0.44473732, -0.28609666, -0.14807671],
                 [-0.3869655, -0.5571348, -0.38598603, -0.41803488, 0.43944812,
                  -0.3425563, 0.25616652, -0.0285089, -0.0508908, -0.54111296],
                 [-0.44107342, -0.5042058, 0.5217055, -0.34677118, 0.475623,
                  0.18002027, -0.44467062, 0.05279869, -0.30962384, -0.45696396],
                 [-0.11149651, 0.3705026, -0.5126401, 0.06722903, 0.22575969,
                  -0.23028824, 0.2056027, -0.39192414, -0.25298402, 0.4379238],
                 [0.14971024, 0.42451167, 0.37757248, -0.3726549, -0.17506334,
                  -0.46460786, -0.02499455, 0.13482589, -0.12902525, -0.19523734]]))
        elif "TransformerDecoder/layer_0/ffn_prepost_wrapper/ffn/dense2/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.44042823, 0.45197666, -0.23344472, 0.45998847, -0.17414865,
                  0.4641745, 0.4826498, -0.1315352],
                 [0.41060203, -0.211938, 0.08441406, 0.2431289, -0.38785285,
                  0.35918987, 0.07967973, -0.19248444],
                 [0.17039984, 0.01675391, -0.19650468, 0.10323095, -0.02209324,
                  -0.24919105, 0.16697949, 0.11663049],
                 [0.17856616, -0.20257097, 0.3182906, 0.1157276, -0.45809188,
                  -0.13065588, -0.5293646, -0.04682791],
                 [-0.19376227, -0.5453018, -0.0328182, -0.5452718, 0.26869357,
                  0.13249546, 0.08024281, 0.11003381],
                 [-0.23756227, -0.29575357, -0.50909173, -0.05765748, -0.0089184,
                  0.489527, 0.0540911, -0.20290643],
                 [-0.43088597, -0.03776497, -0.07004839, 0.3612193, 0.2700277,
                  0.3630551, -0.35514504, 0.0078786],
                 [-0.3577707, 0.5772364, -0.45408776, 0.04695731, 0.12955356,
                  0.08641922, -0.06749266, -0.22854668],
                 [0.3447554, -0.50018543, -0.4450423, -0.345627, 0.4853915,
                  -0.38487256, -0.23583022, 0.41968864],
                 [0.5223309, 0.34582454, 0.24228495, 0.4505279, 0.00524783,
                  0.33739161, 0.1729073, 0.46376586]]))
        elif ("TransformerDecoder/layer_1/self_attention_prepost_wrapper/"
              "self_attention/output_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.446121, 0.3940925, 0.49132103, 0.17713946, 0.5267928,
                  0.33675808, 0.44058722, -0.43157172],
                 [-0.23504972, 0.1617412, 0.2769773, -0.26133326, 0.24745297,
                  -0.0520584, 0.07277727, -0.5577672],
                 [-0.29327726, 0.2514521, 0.32843417, 0.5675153, -0.5442774,
                  -0.24685362, -0.3434327, 0.29523093],
                 [0.25270784, -0.20233193, -0.13284832, 0.28228354, -0.4794641,
                  0.12789321, -0.39262465, 0.04397899],
                 [-0.60009784, 0.45697302, -0.32597286, -0.03012645, 0.01654047,
                  -0.3432645, -0.52298236, -0.45876426],
                 [-0.19784635, 0.01058447, -0.58458495, -0.5126084, -0.5655494,
                  -0.41740847, -0.19458848, -0.10731643],
                 [-0.5258043, -0.61217636, -0.47019628, -0.3324889, -0.39158016,
                  0.36343306, -0.36333203, -0.22256723],
                 [0.24401158, -0.13122407, 0.5713683, -0.6086697, 0.12495714,
                  0.25823617, -0.09232122, 0.5900312]]))
        elif ("TransformerDecoder/layer_1/self_attention_prepost_wrapper/"
              "self_attention/qkv_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.27481452, 0.0116879, 0.36719075, -0.40440372, -0.1954606,
                  -0.2300574, -0.04979965, 0.15613547, 0.32280543, 0.3273132,
                  0.3912786, 0.4046168, -0.30568987, -0.33408988, 0.15435639,
                  -0.08106208, 0.32937118, -0.34070706, 0.0546439, -0.24983734,
                  0.0207603, 0.08601627, -0.27549195, 0.20412138],
                 [0.14348724, 0.18185094, 0.167887, -0.3021682, 0.2971051,
                  0.07907161, -0.37291273, -0.26329404, 0.24814805, -0.00783703,
                  -0.1134795, 0.25298938, -0.0403159, 0.09382078, -0.25310278,
                  0.42588016, -0.0232923, -0.23894715, 0.26872233, -0.3017637,
                  0.35517278, 0.4123756, 0.35715845, -0.2612683],
                 [0.251209, 0.30718777, -0.09743929, 0.37868705, -0.3782806,
                  -0.10440734, -0.20695278, -0.42843944, 0.11033848, 0.4274877,
                  0.21334943, 0.3301848, 0.31885192, 0.3971382, -0.09676668,
                  0.22961542, 0.28164133, 0.28870395, 0.24603716, 0.13049194,
                  -0.26271415, 0.3598245, 0.17889282, -0.09679371],
                 [0.18480167, -0.423978, 0.28147706, 0.20233068, 0.07700345,
                  0.3950176, 0.16953233, -0.2767653, -0.0351927, -0.3871778,
                  -0.10333872, -0.38401458, 0.08614203, -0.09418231, 0.1258482,
                  0.41503003, -0.23736389, 0.3829991, 0.20315519, -0.0506267,
                  0.02750155, 0.18088666, 0.32316545, 0.07156941],
                 [-0.3365289, 0.07633492, 0.18811491, 0.12218675, -0.01712888,
                  0.11047456, 0.36789885, 0.07453135, 0.35507998, 0.32413712,
                  0.06988475, -0.316629, -0.09560555, -0.3577586, 0.11743674,
                  -0.1154238, 0.40550312, -0.28373045, -0.28391486, 0.22130796,
                  0.19461158, 0.34828517, 0.3402731, 0.42168418],
                 [0.22959384, -0.09466672, 0.13875905, 0.06585011, -0.08454975,
                  -0.25139913, 0.24867311, -0.19710684, -0.38250047, 0.05279905,
                  0.09058633, 0.05691019, -0.43189391, -0.00754103, -0.42296854,
                  -0.17274147, -0.1439153, -0.16499841, 0.4218262, 0.27872702,
                  0.269519, -0.284347, 0.00676736, -0.24074432],
                 [-0.43105984, -0.18570966, -0.25307292, -0.19746126, 0.11514279,
                  0.101432, -0.12518859, 0.10440406, -0.42490405, 0.05715063,
                  -0.2929991, 0.2661244, -0.12404522, 0.06171378, -0.15130952,
                  0.29441395, -0.41733328, 0.08141616, -0.34677923, -0.05524972,
                  0.18937346, -0.41702378, -0.06657425, 0.27120963],
                 [0.07061633, 0.23987249, 0.22944674, 0.08817294, 0.22188488,
                  -0.37523416, -0.3636308, 0.26619443, 0.05310896, -0.3865527,
                  -0.0594418, 0.10325739, 0.14090309, -0.02832022, 0.09751496,
                  -0.0530881, -0.04750797, -0.32113245, 0.25775167, -0.2249531,
                  0.17214248, -0.20723793, 0.05858463, -0.1042015]]))
        elif ("TransformerDecoder/layer_1/encdec_attention_prepost_wrapper/"
              "encdec_attention/output_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.07463336, 0.0563764, 0.26746285, 0.58845574, 0.37224877,
                  0.22249967, -0.24321106, -0.48173416],
                 [-0.30540833, 0.24408221, -0.06326765, -0.11097288, 0.10069352,
                  -0.04288429, -0.44742495, 0.166543],
                 [0.14135772, -0.26862615, -0.50849557, 0.5784133, -0.40443277,
                  0.51631385, -0.07799548, 0.28732932],
                 [-0.09749961, 0.40039545, -0.06118071, -0.15212688, 0.34009832,
                  0.5772465, 0.48222512, -0.25559646],
                 [-0.37269944, -0.15007514, 0.11866188, -0.0120635, -0.0109489,
                  -0.60186726, -0.28244707, 0.32835752],
                 [0.559184, 0.29157156, -0.35879636, 0.24650383, 0.5976046,
                  -0.15556344, -0.11127496, -0.3011105],
                 [0.5442193, -0.20431828, 0.36724424, -0.4528572, 0.10426587,
                  0.11822385, -0.05441982, 0.07673579],
                 [-0.37118763, -0.24179482, -0.47427145, -0.17455658, 0.46202105,
                  0.24439615, -0.40861088, 0.2468313]]))
        elif ("TransformerDecoder/layer_1/encdec_attention_prepost_wrapper/"
              "encdec_attention/q_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.527518, -0.12114212, -0.40808892, 0.56731755, 0.2572146,
                  0.31378293, 0.20443302, -0.5630253],
                 [0.6023007, -0.08801287, -0.55323726, -0.49235207, 0.18328917,
                  -0.30462766, 0.4235236, -0.14947698],
                 [0.05836785, -0.32457548, -0.5583779, 0.17587304, 0.13842088,
                  -0.06220692, 0.05683714, -0.08522952],
                 [0.11454928, 0.57845205, 0.40677744, -0.32356766, -0.10824966,
                  0.5729895, 0.09953862, -0.49825168],
                 [-0.1325807, -0.5300193, -0.09281999, 0.23173773, -0.6103119,
                  -0.17548105, -0.40918946, -0.6055349],
                 [-0.26868924, -0.3843334, -0.14497796, 0.27963597, 0.38890153,
                  -0.36425418, 0.13343394, -0.17070243],
                 [-0.333827, 0.16035432, 0.17401373, -0.27310547, -0.23915032,
                  -0.3207253, -0.00749028, 0.4876346],
                 [0.3249125, -0.29519892, 0.49359602, -0.601942, -0.2753108,
                  -0.39890692, 0.04002428, 0.41897768]]))
        elif ("TransformerDecoder/layer_1/encdec_attention_prepost_wrapper/"
              "encdec_attention/kv_transform/kernel") in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[-0.00709212, -0.4091934, 0.26065922, 0.40150464, 0.26608098,
                  -0.3953911, -0.34422696, -0.06396389, -0.42655826, 0.35439622,
                  -0.20109999, -0.18769062, 0.0049336, -0.06693316, -0.4382484,
                  0.00183201],
                 [-0.02701962, 0.41023743, 0.02444375, 0.25569785, 0.04378641,
                  -0.37053585, 0.06267512, -0.06767642, -0.44424844, 0.2922008,
                  -0.44157362, -0.17749298, 0.17760682, -0.23238945, 0.3380952,
                  0.3164295],
                 [0.20117998, -0.13788939, 0.14445269, -0.31664026, 0.49193084,
                  0.08778274, -0.17864335, 0.16035259, -0.17492938, -0.04081237,
                  -0.4904747, -0.44932437, -0.19341111, -0.24871266, 0.38286912,
                  -0.06130087],
                 [0.2936057, -0.40730655, 0.18446267, 0.4097544, -0.0082581,
                  0.4734217, -0.46421993, -0.12871945, 0.22802174, 0.11106157,
                  0.26079726, -0.15126705, 0.40684378, -0.10213089, -0.24696314,
                  -0.02051508],
                 [-0.39994586, 0.16061008, 0.39812696, -0.3340621, -0.2076987,
                  0.20246327, -0.35409093, -0.4005847, -0.14170253, -0.21880937,
                  0.4408716, 0.22332358, -0.05699933, 0.17266095, 0.12294924,
                  0.38497412],
                 [-0.09543967, -0.34888685, -0.42740452, 0.1517607, -0.00862324,
                  -0.14572752, 0.47876465, -0.20919883, 0.32560217, 0.4249189,
                  -0.3933282, -0.22128391, -0.34623587, 0.14449048, -0.3857503,
                  -0.27833867],
                 [0.11869216, 0.05883706, -0.21212506, 0.49957561, 0.15783632,
                  -0.13721228, -0.21416295, -0.24007809, -0.294443, -0.16767824,
                  0.32042253, -0.31908023, 0.19871199, -0.43558514, -0.15620553,
                  0.11092794],
                 [-0.04378927, 0.35632384, 0.20292461, -0.27540374, 0.22871876,
                  -0.3632071, -0.40689313, 0.23316133, 0.37361324, -0.01663148,
                  -0.12638855, -0.32248807, -0.20867753, 0.2503358, -0.39324427,
                  -0.42774928]]))
        elif "TransformerDecoder/layer_1/ffn_prepost_wrapper/ffn/dense1/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.24261475, -0.18643704, -0.01811624, 0.50356495, 0.01885831,
                  -0.2399435, 0.23692662, -0.10759905, -0.38264602, 0.1351049],
                 [0.21200335, -0.38962328, 0.29363745, 0.33583325, -0.24011764,
                  0.3635068, 0.4376179, 0.22551686, 0.5667083, -0.32501143],
                 [-0.49261767, 0.1927172, -0.0046156, -0.56056315, 0.47630668,
                  -0.31453356, 0.42453694, -0.32902807, 0.14415932, -0.5471806],
                 [-0.3316853, 0.13726503, -0.40464914, 0.28158778, 0.47430885,
                  -0.2569832, -0.5204258, -0.06528652, -0.5178821, 0.14735901],
                 [0.5328666, -0.12720194, 0.5184237, 0.411116, -0.3576244,
                  0.34368336, 0.16382056, -0.33515644, 0.17608005, 0.26269817],
                 [0.15965605, -0.25152162, -0.14534956, -0.2822171, 0.21284288,
                  0.05559379, 0.00327557, -0.4569926, -0.41969606, -0.56579554],
                 [-0.43731868, 0.32843924, 0.29003292, 0.1792146, -0.33100158,
                  -0.14961275, 0.12364352, -0.24879637, -0.39719564, 0.18711275],
                 [0.05891687, 0.47468245, -0.20260152, -0.3408, 0.5017748,
                  0.1640119, 0.22170597, -0.34292257, -0.31018573, -0.07051545]]))
        elif "TransformerDecoder/layer_1/ffn_prepost_wrapper/ffn/dense2/kernel" in w.name:
            tf.compat.v1.assign(w, numpy.array(
                [[0.01111823, -0.50019276, -0.33186796, 0.52229214, -0.4700832,
                  0.5457233, -0.21241191, 0.37699038],
                 [-0.28677762, -0.51243806, 0.52265644, -0.29745945, -0.35470137,
                  -0.5047183, 0.18846446, -0.17220777],
                 [-0.46509957, -0.00087285, -0.22127637, 0.4205513, -0.46209753,
                  -0.11040562, -0.0872128, 0.34856063],
                 [0.33827233, -0.31306413, -0.49311733, -0.49154714, -0.43418467,
                  0.11416692, 0.46271265, -0.1998105],
                 [0.05865157, -0.19406608, 0.2172538, -0.2894684, 0.2942767,
                  0.19267291, -0.31736228, -0.04036039],
                 [-0.49561584, -0.22174796, 0.15456653, -0.3632484, -0.4434304,
                  -0.30227244, -0.4071117, 0.4257239],
                 [0.2923094, 0.52523994, 0.22059155, 0.22125322, -0.30496007,
                  -0.20421728, -0.5533153, 0.28908247],
                 [-0.01375407, -0.42056724, -0.42731434, 0.14045459, -0.10852379,
                  -0.14693105, 0.3797375, 0.5360898],
                 [0.01416886, 0.2641362, -0.55372095, -0.17806509, -0.43746334,
                  -0.39878494, -0.5338729, -0.50196886],
                 [0.5125271, -0.31531927, -0.4611238, 0.38278532, -0.05637842,
                  0.23722917, -0.11141762, 0.44730043]]))
    outputs = model(parsed_inputs, is_training=False)
    assert numpy.sum((outputs.numpy() - numpy.array(
        [[[0.5600359, 1.0880388, 0.18974903, 1.8916442, 0.8008492],
          [1.0519575, 1.1763976, 0.42835617, 0.5486565, 0.7540616],
          [-0.09629793, 1.9182953, 0.4154176, -0.09568319, 0.32058734]],
         [[0.68914187, 1.1119794, -0.5154613, 1.8321573, 0.93645334],
          [-0.93543077, 1.9193068, 1.5986707, -1.1064756, -0.1642181],
          [0.2821706, 1.199893, -1.3765914, 0.02889553, 1.045481]]])) ** 2) < 1e-9

    # test share / no share
    params["modality.share_embedding_and_softmax_weights"] = True
    params["modality.share_source_target_embedding"] = True
    model = build_model({"model.class": "transformer", "params": params},
                        src_meta=src_vocab_meta, trg_meta=src_vocab_meta)
    _ = model(parsed_inputs, is_training=False)
    assert len(model._src_modality.trainable_weights) == 2
    for w in model._src_modality.trainable_weights:
        if "weights" in w.name:
            assert "shared_symbol_modality" in w.name
    assert len(model._trg_modality.trainable_weights) == 2
    for w in model._trg_modality.trainable_weights:
        if "weights" in w.name:
            assert "shared_symbol_modality" in w.name
    assert model._output_linear_layer is None

    params["modality.share_embedding_and_softmax_weights"] = False
    params["modality.share_source_target_embedding"] = True
    model = build_model({"model.class": "transformer", "params": params},
                        src_meta=src_vocab_meta, trg_meta=src_vocab_meta)
    _ = model(parsed_inputs, is_training=False)
    assert len(model._trg_modality.trainable_weights) == 1
    assert "shared_symbol_modality" in model._trg_modality.trainable_weights[0].name
    assert len(model._src_modality.trainable_weights) == 1
    assert "shared_symbol_modality" in model._src_modality.trainable_weights[0].name
    assert model._output_linear_layer is not None

    params["modality.share_embedding_and_softmax_weights"] = True
    params["modality.share_source_target_embedding"] = False
    model = build_model({"model.class": "transformer", "params": params},
                        src_meta=src_vocab_meta, trg_meta=src_vocab_meta)
    _ = model(parsed_inputs, is_training=False)
    assert len(model._trg_modality.trainable_weights) == 2
    for w in model._trg_modality.trainable_weights:
        if "weights" in w.name:
            assert "target_symbol_modality" in w.name
    assert len(model._src_modality.trainable_weights) == 1
    assert "input_symbol_modality" in model._src_modality.trainable_weights[0].name
    assert model._output_linear_layer is None

    params["modality.share_embedding_and_softmax_weights"] = False
    params["modality.share_source_target_embedding"] = False
    model = build_model({"model.class": "transformer", "params": params},
                        src_meta=src_vocab_meta, trg_meta=src_vocab_meta)
    _ = model(parsed_inputs, is_training=False)
    assert len(model._trg_modality.trainable_weights) == 1
    assert "target_symbol_modality" in model._trg_modality.trainable_weights[0].name
    assert len(model._src_modality.trainable_weights) == 1
    assert "input_symbol_modality" in model._src_modality.trainable_weights[0].name
    assert model._output_linear_layer is not None
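# Reading aid for the four sharing combinations exercised above (not part of
# the original test): `share_embedding_and_softmax_weights` controls whether
# the softmax projection reuses the target embedding (so no separate
# _output_linear_layer is created), while `share_source_target_embedding`
# ties source and target embeddings into one "shared_symbol_modality" table.
#
#   emb/softmax shared | src/trg shared | embedding scope              | _output_linear_layer
#   True               | True           | shared_symbol_modality       | None
#   False              | True           | shared_symbol_modality       | present
#   True               | False          | input_/target_symbol_modality| None
#   False              | False          | input_/target_symbol_modality| present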
def test_seq2seq():
    params = copy.deepcopy(
        get_hyper_parameters("transformer_toy")["model.params"])
    params["modality.source.dim"] = None
    params["modality.target.dim"] = None
    params["modality.source.timing"] = None
    params["modality.target.timing"] = None
    params["encoder.num_layers"] = 1
    params["decoder.num_layers"] = 1
    src_vocab_meta = dict(vocab_size=8, eos_id=7, bos_id=6, unk_id=5)
    trg_vocab_meta = dict(vocab_size=5, eos_id=4, bos_id=3, unk_id=2)
    pt_inps = {
        "src": torch.LongTensor([[0, 1, 1, 7], [1, 7, 7, 7]]),
        "src_padding": torch.FloatTensor([[0, 0, 0, 0.], [0, 0, 1, 1.]]),
        "trg_input": torch.LongTensor([[3, 0, 1], [3, 2, 4]]),
        "trg": torch.LongTensor([[0, 1, 4], [2, 4, 4]]),
        "trg_padding": torch.FloatTensor([[0, 0, 0.], [0, 0, 1.]]),
    }
    tf_inps = {
        "src": tf.convert_to_tensor([[0, 1, 1, 7], [1, 7, 7, 7]], tf.int64),
        "src_padding": tf.convert_to_tensor([[0, 0, 0, 0.], [0, 0, 1, 1.]], tf.float32),
        "trg_input": tf.convert_to_tensor([[3, 0, 1], [3, 2, 4]], tf.int32),
        "trg": tf.convert_to_tensor([[0, 1, 4], [2, 4, 4]], tf.int32),
        "trg_padding": tf.convert_to_tensor([[0, 0, 0.], [0, 0, 1.]], tf.float32),
    }
    pt_model: Transformer = build_pt_model(
        {"model.class": "transformer", "params": params},
        src_meta=src_vocab_meta, trg_meta=trg_vocab_meta)
    tf_model: TFTransformer = build_model(
        {"model.class": "transformer", "params": params},
        src_meta=src_vocab_meta, trg_meta=trg_vocab_meta)
    pt_model._src_modality.embedding_layer._shared_weights.data = torch.FloatTensor(
        tf_model._src_modality.embedding_layer._shared_weights.numpy())
    pt_model._trg_modality.embedding_layer._shared_weights.data = torch.FloatTensor(
        tf_model._trg_modality.embedding_layer._shared_weights.numpy())
    pt_model._trg_modality.embedding_layer._bias.data = torch.FloatTensor(
        tf_model._trg_modality.embedding_layer._bias.numpy())
    pt_model._encoder._output_norm_layer.weight.data = torch.FloatTensor(
        tf_model._encoder._output_norm_layer.gamma.numpy())
    pt_model._encoder._output_norm_layer.bias.data = torch.FloatTensor(
        tf_model._encoder._output_norm_layer.beta.numpy())
    pt_model._encoder._stacking_layers[0][
        0]._layer._qkv_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]
            [0]._layer._qkv_transform_layer._kernel.numpy())
    pt_model._encoder._stacking_layers[0][
        0]._layer._qkv_transform_layer._bias.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]
            [0]._layer._qkv_transform_layer._bias.numpy())
    pt_model._encoder._stacking_layers[0][
        0]._layer._output_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]
            [0]._layer._output_transform_layer._kernel.numpy())
    pt_model._encoder._stacking_layers[0][
        0]._layer._output_transform_layer._bias.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]
            [0]._layer._output_transform_layer._bias.numpy())
    pt_model._encoder._stacking_layers[0][
        1]._layer._dense1.weight.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]
            [1]._layer._conv1.kernel.numpy().transpose([1, 0]))
    pt_model._encoder._stacking_layers[0][
        1]._layer._dense1.bias.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]
            [1]._layer._conv1.bias.numpy())
    pt_model._encoder._stacking_layers[0][
        1]._layer._dense2.weight.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]
            [1]._layer._conv2.kernel.numpy().transpose([1, 0]))
    pt_model._encoder._stacking_layers[0][
        1]._layer._dense2.bias.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0]
            [1]._layer._conv2.bias.numpy())
    pt_model._encoder._stacking_layers[0][
        0]._norm_layer.weight.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0][0]._norm_layer.gamma.numpy())
    pt_model._encoder._stacking_layers[0][
        0]._norm_layer.bias.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0][0]._norm_layer.beta.numpy())
    pt_model._encoder._stacking_layers[0][
        1]._norm_layer.weight.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0][1]._norm_layer.gamma.numpy())
    pt_model._encoder._stacking_layers[0][
        1]._norm_layer.bias.data = torch.FloatTensor(
            tf_model._encoder._stacking_layers[0][1]._norm_layer.beta.numpy())
    pt_model._decoder._output_norm_layer.weight.data = torch.FloatTensor(
        tf_model._decoder._output_norm_layer.gamma.numpy())
    pt_model._decoder._output_norm_layer.bias.data = torch.FloatTensor(
        tf_model._decoder._output_norm_layer.beta.numpy())
    pt_model._decoder._stacking_layers[0][
        0]._layer._qkv_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [0]._layer._qkv_transform_layer._kernel.numpy())
    pt_model._decoder._stacking_layers[0][
        0]._layer._qkv_transform_layer._bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [0]._layer._qkv_transform_layer._bias.numpy())
    pt_model._decoder._stacking_layers[0][
        0]._layer._output_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [0]._layer._output_transform_layer._kernel.numpy())
    pt_model._decoder._stacking_layers[0][
        0]._layer._output_transform_layer._bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [0]._layer._output_transform_layer._bias.numpy())
    pt_model._decoder._stacking_layers[0][
        0]._norm_layer.weight.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0][0]._norm_layer.gamma.numpy())
    pt_model._decoder._stacking_layers[0][
        0]._norm_layer.bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0][0]._norm_layer.beta.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._layer._q_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [1]._layer._q_transform_layer._kernel.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._layer._q_transform_layer._bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [1]._layer._q_transform_layer._bias.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._layer._kv_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [1]._layer._kv_transform_layer._kernel.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._layer._kv_transform_layer._bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [1]._layer._kv_transform_layer._bias.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._layer._output_transform_layer._kernel.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [1]._layer._output_transform_layer._kernel.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._layer._output_transform_layer._bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [1]._layer._output_transform_layer._bias.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._norm_layer.weight.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0][1]._norm_layer.gamma.numpy())
    pt_model._decoder._stacking_layers[0][
        1]._norm_layer.bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0][1]._norm_layer.beta.numpy())
    pt_model._decoder._stacking_layers[0][
        2]._layer._dense1.weight.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [2]._layer._conv1.kernel.numpy().transpose([1, 0]))
    pt_model._decoder._stacking_layers[0][
        2]._layer._dense1.bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [2]._layer._conv1.bias.numpy())
    pt_model._decoder._stacking_layers[0][
        2]._layer._dense2.weight.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [2]._layer._conv2.kernel.numpy().transpose([1, 0]))
    pt_model._decoder._stacking_layers[0][
        2]._layer._dense2.bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0]
            [2]._layer._conv2.bias.numpy())
    pt_model._decoder._stacking_layers[0][
        2]._norm_layer.weight.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0][2]._norm_layer.gamma.numpy())
    pt_model._decoder._stacking_layers[0][
        2]._norm_layer.bias.data = torch.FloatTensor(
            tf_model._decoder._stacking_layers[0][2]._norm_layer.beta.numpy())
    assert_equal_numpy(
        tf_model(tf_inps, is_training=False).numpy(),
        pt_model(pt_inps, is_training=False).detach().numpy(), 5e-6)