Example No. 1
def convert_vocab(args):
    print('converting vocab')
    merges_path = os.path.join(args.tf_model_path, 'vocab.bpe')
    vocab_path = os.path.join(args.tf_model_path, 'encoder.json')
    gluon_merges_path = os.path.join(args.save_dir, 'gpt2.merges')
    gluon_vocab_path = os.path.join(args.save_dir, 'gpt2.vocab')
    
    shutil.copy(merges_path, gluon_merges_path)
    with open(vocab_path, 'r', encoding='utf-8') as f_v:
        tf_vocab = json.load(f_v)
    tf_vocab = list(tf_vocab.items())
    tf_vocab = sorted(tf_vocab, key=lambda x: x[1])
    all_tokens = [e[0] for e in tf_vocab]
    eos_token = all_tokens[-1]
    assert eos_token == '<|endoftext|>'
    gluon_vocab = Vocab(all_tokens,
                        unk_token=None,
                        eos_token=eos_token)
    gluon_vocab.save(gluon_vocab_path)

    vocab_size = len(gluon_vocab)
    print('| converted dictionary: {} types'.format(vocab_size))
    return vocab_size
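A minimal driver sketch for the converter above, assuming it is wired up with argparse: only the attributes that convert_vocab() actually reads (tf_model_path and save_dir) are used, and the flag names themselves are hypothetical.

# Hypothetical CLI wiring (assumption): the argument names mirror the attributes
# read by convert_vocab(); the real script may define them differently.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--tf_model_path', required=True,
                    help='directory containing encoder.json and vocab.bpe')
parser.add_argument('--save_dir', default='.',
                    help='destination for gpt2.vocab and gpt2.merges')
args = parser.parse_args()
vocab_size = convert_vocab(args)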
Example No. 2
def test_vocab():
    def check_same_vocab(vocab1, vocab2):
        assert vocab1.all_tokens == vocab2.all_tokens
        assert len(vocab1._special_token_kv) == len(vocab2._special_token_kv)
        for k, v in vocab1._special_token_kv.items():
            assert v == vocab2._special_token_kv[k]
            assert getattr(vocab1, k) == getattr(vocab2, k)

    def check_consistency(vocab):
        for i, token in enumerate(vocab.all_tokens):
            assert vocab[token] == i
        if hasattr(vocab, 'unk_token'):
            assert vocab['some1234123dasf'] == vocab[vocab.unk_token]
        assert len(vocab) == len(vocab.all_tokens)
        if len(vocab.all_tokens) > 0:
            random_idx = [
                random.randint(0,
                               len(vocab.all_tokens) - 1) for _ in range(20)
            ]
            assert vocab.to_tokens(random_idx) == [
                vocab.all_tokens[i] for i in random_idx
            ]
            assert vocab.to_tokens(np.array(random_idx)) == [
                vocab.all_tokens[i] for i in random_idx
            ]
            random_tokens = vocab.to_tokens(random_idx)
            assert vocab[random_tokens] == random_idx
            if vocab.has_unk:
                assert vocab[random_tokens + ['213412hadhfk']]\
                       == random_idx + [vocab.unk_id]
            for k, v in vocab.special_tokens_kv.items():
                idx_property = k[:-6] + '_id'
                assert getattr(vocab, idx_property) == vocab[v]

        # Test for serialize/deserialize from JSON
        json_str = vocab.to_json()
        new_vocab = Vocab.from_json(json_str)
        check_same_vocab(new_vocab, vocab)
        # Test for save/load from file
        while True:
            fname = '{}.json'.format(uuid.uuid4())
            if os.path.exists(fname):
                continue
            vocab.save(path=fname)
            new_vocab = Vocab.load(fname)
            check_same_vocab(new_vocab, vocab)
            os.remove(fname)
            break

    words = ['a', 'a', 'b', 'd', 'c', 'b', 'a', 'c', 'd', 'd', 'd']
    random.shuffle(words)
    counter = collections.Counter(words)
    vocab = Vocab(counter, max_size=2, min_freq=None)
    check_consistency(vocab)
    assert vocab.all_tokens == ['d', 'a', '<unk>']
    # Test for unknown token
    vocab = Vocab(tokens=counter,
                  max_size=2,
                  min_freq=None,
                  unk_token='<unk2>')
    check_consistency(vocab)
    assert vocab.all_tokens == ['d', 'a', '<unk2>']

    vocab = Vocab(tokens=counter,
                  max_size=None,
                  min_freq=None,
                  pad_token=Vocab.PAD_TOKEN,
                  eos_token=Vocab.EOS_TOKEN,
                  bos_token=Vocab.BOS_TOKEN,
                  cls_token=Vocab.CLS_TOKEN,
                  sep_token=Vocab.SEP_TOKEN,
                  mask_token=Vocab.MASK_TOKEN)
    check_consistency(vocab)
    assert vocab.unk_token == Vocab.UNK_TOKEN
    assert vocab.pad_token == Vocab.PAD_TOKEN
    assert vocab.eos_token == Vocab.EOS_TOKEN
    assert vocab.bos_token == Vocab.BOS_TOKEN
    assert vocab.cls_token == Vocab.CLS_TOKEN
    assert vocab.sep_token == Vocab.SEP_TOKEN
    assert vocab.mask_token == Vocab.MASK_TOKEN
    assert vocab.special_token_keys == [
        'unk_token', 'bos_token', 'cls_token', 'eos_token', 'mask_token',
        'pad_token', 'sep_token'
    ]
    assert vocab.special_tokens == [
        '<unk>', '<bos>', '<cls>', '<eos>', '<mask>', '<pad>', '<sep>'
    ]
    assert vocab.all_tokens == [
        'd', 'a', 'c', 'b', '<unk>', '<bos>', '<cls>', '<eos>', '<mask>',
        '<pad>', '<sep>'
    ]

    vocab = Vocab(counter,
                  bos_token=Vocab.BOS_TOKEN,
                  eos_token=Vocab.EOS_TOKEN,
                  pad_token=Vocab.PAD_TOKEN)
    check_consistency(vocab)
    assert vocab.all_tokens == [
        'd', 'a', 'c', 'b', '<unk>', '<bos>', '<eos>', '<pad>'
    ]

    vocab = Vocab(counter,
                  max_size=None,
                  min_freq=None,
                  pad_token=Vocab.PAD_TOKEN,
                  eos_token=Vocab.EOS_TOKEN,
                  bos_token=Vocab.BOS_TOKEN,
                  mask_token='<mask2>',
                  other3_token='<other3>',
                  other2_token='<other2>')
    check_consistency(vocab)
    assert vocab.all_tokens == [
        'd', 'a', 'c', 'b', '<unk>', '<bos>', '<eos>', '<mask2>', '<other2>',
        '<other3>', '<pad>'
    ]
    assert vocab.mask_token == '<mask2>'
    assert vocab.other2_token == '<other2>'
    assert vocab.other3_token == '<other3>'
    assert vocab.special_token_keys == [
        'unk_token', 'bos_token', 'eos_token', 'mask_token', 'other2_token',
        'other3_token', 'pad_token'
    ]
    assert vocab.special_tokens == [
        '<unk>', '<bos>', '<eos>', '<mask2>', '<other2>', '<other3>', '<pad>'
    ]

    vocab = Vocab(counter, max_size=1, min_freq=10000, unk_token=None)
    check_consistency(vocab)
    assert vocab.all_tokens == []

    vocab = Vocab([],
                  pad_token=Vocab.PAD_TOKEN,
                  eos_token=Vocab.EOS_TOKEN,
                  bos_token=Vocab.BOS_TOKEN,
                  mask_token='<mask2>')
    check_consistency(vocab)
    assert vocab.all_tokens == ['<unk>', '<bos>', '<eos>', '<mask2>', '<pad>']
    vocab = Vocab(pad_token=Vocab.PAD_TOKEN,
                  eos_token=Vocab.EOS_TOKEN,
                  bos_token=Vocab.BOS_TOKEN,
                  mask_token='<mask2>')
    check_consistency(vocab)
    assert vocab.all_tokens == ['<unk>', '<bos>', '<eos>', '<mask2>', '<pad>']

    vocab = Vocab(['<unk2>', '<pad>', '<bos>', '<eos>', '<mask>', 'a'],
                  pad_token=Vocab.PAD_TOKEN,
                  eos_token=Vocab.EOS_TOKEN,
                  bos_token=Vocab.BOS_TOKEN,
                  mask_token='<mask>')
    check_consistency(vocab)
    assert vocab.all_tokens == [
        '<unk2>', '<pad>', '<bos>', '<eos>', '<mask>', 'a', '<unk>'
    ]
    assert vocab.special_tokens == [
        '<pad>', '<bos>', '<eos>', '<mask>', '<unk>'
    ]
    assert vocab.special_token_keys == [
        'pad_token', 'bos_token', 'eos_token', 'mask_token', 'unk_token'
    ]

    # Check errors
    with pytest.raises(ValueError):
        vocab = Vocab(['a', 'a', 'a'])
    with pytest.raises(ValueError):
        vocab = Vocab(['a', 'b', 'c'],
                      mask_token='<mask>',
                      another_mask_token='<mask>')
    vocab = Vocab(['a', 'b', 'c'])
    check_consistency(vocab)

    # Check emoji
    all_tokens = ['<unk>', '😁']
    vocab = Vocab(all_tokens, unk_token='<unk>')
    vocab_file = str(uuid.uuid4()) + '.vocab'
    vocab.save(vocab_file)
    vocab = Vocab.load(vocab_file)
    assert vocab.all_tokens == all_tokens
    os.remove(vocab_file)
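The test above exercises the core Vocab behaviors; below is a short standalone sketch of the same round trip, using only calls that appear in the test (construction from a Counter, index lookup, to_tokens, and JSON serialization).

# Minimal usage sketch based only on the behaviors asserted in test_vocab().
import collections

counter = collections.Counter(['a', 'a', 'b', 'd', 'd', 'd'])
vocab = Vocab(counter, max_size=2)            # keeps 'd' and 'a', plus '<unk>'
assert vocab.all_tokens == ['d', 'a', '<unk>']
assert vocab['d'] == 0                        # token -> index
assert vocab.to_tokens([0, 1]) == ['d', 'a']  # index -> tokens
restored = Vocab.from_json(vocab.to_json())   # serialize / deserialize round trip
assert restored.all_tokens == vocab.all_tokens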
Example No. 3
def convert_tf_model(hub_model_dir, save_dir, test_conversion, model_type):
    # set up the model type to be converted
    if model_type == 'bert':
        if args.torch:
            PretrainedModel, PretrainedMLMModel = ThBertModel, ThBertForMLM
        else:
            PretrainedModel, PretrainedMLMModel = BertModel, BertForMLM
    elif model_type == 'albert' and not args.torch:
        PretrainedModel, PretrainedMLMModel = AlbertModel, AlbertForMLM
    else:
        raise NotImplementedError

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    cfg, vocab_path, spm_model_path = convert_tf_assets(
        os.path.join(hub_model_dir, 'assets'), model_type)
    with open(os.path.join(save_dir, 'model.yml'), 'w') as of:
        of.write(cfg.dump())
    if spm_model_path:
        # Sentencepiece tokenizer used in the ALBERT model
        tokenizer = SentencepieceTokenizer(spm_model_path)
        new_vocab = Vocab(tokenizer.vocab.all_tokens,
                          unk_token='<unk>',
                          pad_token='<pad>',
                          cls_token='[CLS]',
                          sep_token='[SEP]',
                          mask_token='[MASK]')
        shutil.copy(spm_model_path, os.path.join(save_dir, 'spm.model'))
    elif vocab_path:
        # Wordpiece tokenizer used in the BERT and ELECTRA models

        # In this step, the vocabulary is converted with the help of the tokenizer,
        # so whether the tokenizer is case-sensitive does not matter.
        new_vocab = HuggingFaceWordPieceTokenizer(vocab_file=vocab_path,
                                                  unk_token='[UNK]',
                                                  pad_token='[PAD]',
                                                  cls_token='[CLS]',
                                                  sep_token='[SEP]',
                                                  mask_token='[MASK]',
                                                  lowercase=True).vocab

    new_vocab.save(os.path.join(save_dir, 'vocab.json'))

    # test input data
    batch_size = 2
    seq_length = 16
    num_mask = 5
    input_ids = np.random.randint(0, cfg.MODEL.vocab_size,
                                  (batch_size, seq_length))
    valid_length = np.random.randint(seq_length // 2, seq_length,
                                     (batch_size, ))
    input_mask = np.broadcast_to(np.arange(seq_length).reshape(1, -1), (batch_size, seq_length)) \
        < np.expand_dims(valid_length, 1)
    segment_ids = np.random.randint(0, 2, (batch_size, seq_length))
    mlm_positions = np.random.randint(0, seq_length // 2,
                                      (batch_size, num_mask))
    TF1_Hub_Modules = True
    try:
        tf_model = hub.Module(hub_model_dir, trainable=True)
        # see https://www.tensorflow.org/hub/tf1_hub_module for details
        logging.info('The model is loaded as the TF1 Hub Model')
        tf_input_ids = tf.constant(input_ids, dtype=np.int32)
        tf_input_mask = tf.constant(input_mask, dtype=np.int32)
        tf_segment_ids = tf.constant(segment_ids, dtype=np.int32)
        tf_mlm_positions = tf.constant(mlm_positions, dtype=np.int32)
        tf_mlm_outputs = tf_model(dict(input_ids=tf_input_ids,
                                       input_mask=tf_input_mask,
                                       segment_ids=tf_segment_ids,
                                       mlm_positions=tf_mlm_positions),
                                  signature="mlm",
                                  as_dict=True)
        tf_token_outputs = tf_model(dict(input_ids=tf_input_ids,
                                         input_mask=tf_input_mask,
                                         segment_ids=tf_segment_ids),
                                    signature="tokens",
                                    as_dict=True)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            tf_params = sess.run(tf_model.variable_map)
            tf_token_outputs_np = sess.run(tf_token_outputs)
            tf_mlm_outputs_np = sess.run(tf_mlm_outputs)
    except RuntimeError as _:
        logging.warning(
            'The provided model directory is not valid for TF1 Hub Modules. '
            'Now try to load as TF2 SavedModels')
        bert_layer = hub.KerasLayer(hub_model_dir, trainable=True)
        # see https://www.tensorflow.org/hub/tf2_saved_model for details
        logging.info('The model is loaded as the TF2 SavedModel')
        TF1_Hub_Modules = False
        input_word_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                               dtype=tf.int32,
                                               name="input_word_ids")
        input_word_mask = tf.keras.layers.Input(shape=(seq_length, ),
                                                dtype=tf.int32,
                                                name="input_mask")
        segment_type_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                                 dtype=tf.int32,
                                                 name="segment_ids")
        pooled_output, sequence_output = bert_layer(
            [input_word_ids, input_word_mask, segment_type_ids])
        tf_model = tf.keras.Model(
            inputs=[input_word_ids, input_word_mask, segment_type_ids],
            outputs=[pooled_output, sequence_output])
        tf_params = {}
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            pooled_output, sequence_output = tf_model.predict(
                [input_ids, input_mask, segment_ids])
            tf_token_outputs_np = {
                'pooled_output': pooled_output,
                'sequence_output': sequence_output
            }
            # The names of the parameters in a TF2 SavedModel end with ':0',
            # e.g., 'bert_model/word_embeddings/embeddings_2:0'
            tf_params = {
                v.name.split(":")[0]: v.read_value()
                for v in tf_model.variables
            }
            tf_params = sess.run(tf_params)

    if USE_TF_V1 and TF1_Hub_Modules:
        tf_params_by_read = read_tf_checkpoint(
            os.path.join(hub_model_dir, 'variables', 'variables'))
        for k in tf_params:
            assert_allclose(tf_params[k], tf_params_by_read[k])

    # Get parameter names for Tensorflow with unused parameters filtered out.
    tf_names = sorted(tf_params.keys())
    tf_names = filter(lambda name: not name.endswith('adam_m'), tf_names)
    tf_names = filter(lambda name: not name.endswith('adam_v'), tf_names)
    tf_names = filter(lambda name: name != 'Variable', tf_names)
    tf_names = filter(lambda name: name != 'global_step', tf_names)
    tf_names = list(tf_names)

    # Build gluon model and initialize
    # TODO leezu
    # cfg.defrost()
    # cfg.MODEL.hidden_dropout_prob = 0.0
    # cfg.MODEL.attention_dropout_prob = 0.0
    # cfg.freeze()
    gluon_model = PretrainedModel.from_cfg(cfg, use_pooler=True)
    if args.torch:
        gluon_model = gluon_model.to(args.device)
        gluon_model.eval()
    else:
        gluon_model.initialize(ctx=args.ctx)
        gluon_model.hybridize()
    gluon_mlm_model = PretrainedMLMModel(backbone_cfg=cfg)
    if args.torch:
        gluon_mlm_model = gluon_mlm_model.to(args.device)
        gluon_mlm_model.backbone_model.to(args.device)
        gluon_mlm_model.eval()
    else:
        gluon_mlm_model.initialize(ctx=args.ctx)
        gluon_mlm_model.hybridize()

    # Prepare the test data
    if args.torch:
        input_ids = th.from_numpy(input_ids).to(args.device)
        valid_length = th.from_numpy(valid_length).to(args.device)
        token_types = th.from_numpy(segment_ids).to(args.device)
        masked_positions = th.from_numpy(mlm_positions).to(args.device)
    else:
        input_ids = mx.np.array(input_ids, dtype=np.int32, ctx=args.ctx)
        valid_length = mx.np.array(valid_length, dtype=np.int32, ctx=args.ctx)
        token_types = mx.np.array(segment_ids, dtype=np.int32, ctx=args.ctx)
        masked_positions = mx.np.array(mlm_positions,
                                       dtype=np.int32,
                                       ctx=args.ctx)

    # Start converting the 'backbone' and 'mlm' models.
    # However, some TF2 SavedModels (e.g., BERT wwm large) have no MLM parameters.
    if any(['cls' in name for name in tf_names]):
        has_mlm = True
    else:
        has_mlm = False
        logging.info(
            'There is no mask language model parameter in this pretrained model'
        )
    name_map = get_name_map(tf_names, is_TF1=TF1_Hub_Modules)
    # go through the gluon model to infer the shape of parameters
    if has_mlm:
        model = gluon_mlm_model
        contextual_embedding, pooled_output, mlm_scores = \
            model(input_ids, token_types, valid_length, masked_positions)
    else:
        model = gluon_model
        contextual_embedding, pooled_output = model(input_ids, token_types,
                                                    valid_length)

    # replace tensorflow parameter names with gluon parameter names
    params = {n: p
              for n, p in model.named_parameters()
              } if args.torch else model.collect_params()
    all_keys = set(params.keys())
    for (src_name, dst_name) in name_map.items():
        tf_param_val = tf_params[src_name]
        if dst_name is None:
            continue
        if args.torch and dst_name == 'mlm_decoder.3.weight':  # shared weight
            continue
        all_keys.remove(dst_name)
        if 'self_attention/attention_output/kernel' in src_name:
            if args.torch:
                params[dst_name].data = th.from_numpy(
                    tf_param_val.reshape(
                        (cfg.MODEL.units, -1)).T).contiguous()
            else:
                params[dst_name].set_data(tf_param_val.T)
        elif src_name.endswith('kernel'):
            if args.torch:
                params[dst_name].data = th.from_numpy(
                    tf_param_val.T).contiguous()
            else:
                params[dst_name].set_data(tf_param_val.T)
        else:
            if args.torch:
                params[dst_name].data = th.from_numpy(
                    tf_param_val).contiguous()
            else:
                params[dst_name].set_data(tf_param_val)

    # Merge query/kernel, key/kernel, value/kernel to encoder.all_encoder_groups.0.attn_qkv.weight
    def convert_qkv_weights(tf_prefix, prefix, is_mlm):
        """
        To convert the qkv weights with different prefix.

        In tensorflow framework, the prefix of query/key/value for the albert model is
        'bert/encoder/transformer/group_0/inner_group_0/attention_1/self/query/kernel',
        and that for the bert model is 'bert/encoder/layer_{}/attention/self/key/bias'.
        In gluonnlp framework, the prefix is slightly different as
        'encoder.all_encoder_groups.0.attn_qkv.weight' for albert model and
        'encoder.all_layers.{}.attn_qkv.weight' for bert model, as the
        curly braces {} can be filled with the layer number.
        """
        query_weight = tf_params['{}/query/kernel'.format(tf_prefix)]
        key_weight = tf_params['{}/key/kernel'.format(tf_prefix)]
        value_weight = tf_params['{}/value/kernel'.format(tf_prefix)]
        query_bias = tf_params['{}/query/bias'.format(tf_prefix)]
        key_bias = tf_params['{}/key/bias'.format(tf_prefix)]
        value_bias = tf_params['{}/value/bias'.format(tf_prefix)]
        if 'self_attention' in tf_prefix:
            query_weight = query_weight.reshape((cfg.MODEL.units, -1))
            key_weight = key_weight.reshape((cfg.MODEL.units, -1))
            value_weight = value_weight.reshape((cfg.MODEL.units, -1))
            query_bias = query_bias.reshape((-1, ))
            key_bias = key_bias.reshape((-1, ))
            value_bias = value_bias.reshape((-1, ))
        # Merge query_weight, key_weight, value_weight to params
        weight_name = 'encoder.{}.attn_qkv.weight'.format(prefix)
        bias_name = 'encoder.{}.attn_qkv.bias'.format(prefix)
        if is_mlm:
            weight_name = 'backbone_model.' + weight_name
            bias_name = 'backbone_model.' + bias_name
        if args.torch:
            params[weight_name].data = th.from_numpy(
                np.concatenate([query_weight, key_weight, value_weight],
                               axis=1).T).contiguous()
        else:
            params[weight_name].set_data(
                np.concatenate([query_weight, key_weight, value_weight],
                               axis=1).T)
        all_keys.remove(weight_name)
        # Merge query_bias, key_bias, value_bias to params
        if args.torch:
            params[bias_name].data = th.from_numpy(
                np.concatenate([query_bias, key_bias, value_bias],
                               axis=0)).contiguous()
        else:
            params[bias_name].set_data(
                np.concatenate([query_bias, key_bias, value_bias], axis=0))
        all_keys.remove(bias_name)

    tf_prefix = None
    if not args.torch and has_mlm:
        all_keys.remove('mlm_decoder.3.weight')
    if model_type == 'bert':
        assert all([
            re.match(
                r'^(backbone_model\.){0,1}encoder\.all_layers\.[\d]+\.attn_qkv\.(weight|bias)$',
                key) is not None for key in all_keys
        ])
        for layer_id in range(cfg.MODEL.num_layers):
            prefix = 'all_layers.{}'.format(layer_id)
            if TF1_Hub_Modules:
                tf_prefix = 'bert/encoder/layer_{}/attention/self'.format(
                    layer_id)
            else:
                tf_prefix = 'transformer/layer_{}/self_attention'.format(
                    layer_id)
            convert_qkv_weights(tf_prefix, prefix, has_mlm)
    elif model_type == 'albert':
        assert all([
            re.match(
                r'^(backbone_model\.){0,1}encoder\.all_encoder_groups\.0\.attn_qkv\.(weight|bias)$',
                key) is not None for key in all_keys
        ])
        prefix = 'all_encoder_groups.0'
        assert TF1_Hub_Modules, 'Please download the albert model from TF1 Hub'
        tf_prefix = 'bert/encoder/transformer/group_0/inner_group_0/attention_1/self'
        convert_qkv_weights(tf_prefix, prefix, has_mlm)
    else:
        raise NotImplementedError

    tolerance = 5E-4 if cfg.MODEL.num_layers == 24 else 1E-4
    # The pooled_output of albert_large has about a 0.5% mismatch under a tolerance of 1E-2,
    # so we relax the tolerance here to pass the difference check.
    tolerance = 0.2 if 'albert_large' in args.tf_hub_model_path else tolerance

    assert len(
        all_keys
    ) == 0, f"The following parameters were not assigned: {all_keys}"

    def check_backbone(tested_model, tf_token_outputs_np):
        # test conversion results for backbone model
        tf_contextual_embedding = tf_token_outputs_np['sequence_output']
        tf_pooled_output = tf_token_outputs_np['pooled_output']
        contextual_embedding, pooled_output = \
            tested_model(input_ids, token_types, valid_length)
        if args.torch:
            assert_allclose(pooled_output.detach().cpu().numpy(),
                            tf_pooled_output, tolerance, tolerance)
        else:
            assert_allclose(pooled_output.asnumpy(), tf_pooled_output,
                            tolerance, tolerance)
        for i in range(batch_size):
            ele_valid_length = int(valid_length[i])
            if args.torch:
                assert_allclose(
                    contextual_embedding[
                        i, :ele_valid_length, :].detach().cpu().numpy(),
                    tf_contextual_embedding[i, :ele_valid_length, :],
                    tolerance, tolerance)
            else:
                assert_allclose(
                    contextual_embedding[i, :ele_valid_length, :].asnumpy(),
                    tf_contextual_embedding[i, :ele_valid_length, :],
                    tolerance, tolerance)

    if not has_mlm:
        if test_conversion:
            check_backbone(model, tf_token_outputs_np)
        th.save(model.state_dict(), os.path.join(save_dir, 'model.params'))
        logging.info('Convert the backbone model in {} to {}/{}'.format(
            hub_model_dir, save_dir, 'model.params'))
    else:
        # test conversion results for mlm model
        # TODO(zheyuye), figure out how to check the mlm model from TF2 SavedModel
        if test_conversion:
            backbone_model = model.backbone_model
            if args.torch:
                model = model.to(args.device)
                backbone_model = backbone_model.to(args.device)
            check_backbone(backbone_model, tf_mlm_outputs_np)
            if TF1_Hub_Modules:
                tf_contextual_embedding = tf_mlm_outputs_np['sequence_output']
                tf_pooled_output = tf_mlm_outputs_np['pooled_output']
                tf_mlm_scores = tf_mlm_outputs_np['mlm_logits'].reshape(
                    (batch_size, num_mask, -1))
                contextual_embedding, pooled_output, mlm_scores = \
                    model(input_ids, token_types, valid_length, masked_positions)
                if args.torch:
                    assert_allclose(pooled_output.detach().cpu().numpy(),
                                    tf_pooled_output, tolerance, tolerance)
                    assert_allclose(mlm_scores.detach().cpu().numpy(),
                                    tf_mlm_scores, tolerance, tolerance)
                else:
                    assert_allclose(pooled_output.asnumpy(), tf_pooled_output,
                                    tolerance, tolerance)
                    assert_allclose(mlm_scores.asnumpy(), tf_mlm_scores,
                                    tolerance, tolerance)
                for i in range(batch_size):
                    ele_valid_length = int(valid_length[i])
                    if args.torch:
                        assert_allclose(
                            contextual_embedding[i, :ele_valid_length, :].
                            detach().cpu().numpy(),
                            tf_contextual_embedding[i, :ele_valid_length, :],
                            tolerance, tolerance)
                    else:
                        assert_allclose(
                            contextual_embedding[
                                i, :ele_valid_length, :].asnumpy(),
                            tf_contextual_embedding[i, :ele_valid_length, :],
                            tolerance, tolerance)
        if args.torch:
            th.save(model.backbone_model.state_dict(),
                    os.path.join(save_dir, 'model.params'))
            th.save(model.state_dict(),
                    os.path.join(save_dir, 'model_mlm.params'))
        else:
            model.backbone_model.save_parameters(os.path.join(
                save_dir, 'model.params'),
                                                 deduplicate=True)
            model.save_parameters(os.path.join(save_dir, 'model_mlm.params'),
                                  deduplicate=True)
        logging.info('Convert the backbone model in {} to {}/{}'.format(
            hub_model_dir, save_dir, 'model.params'))
        logging.info('Convert the MLM model in {} to {}/{}'.format(
            hub_model_dir, save_dir, 'model_mlm.params'))

    # TODO(zheyuye) the gradient checking could be explored in further development

    logging.info('Conversion finished!')
    logging.info('Statistics:')

    old_names = os.listdir(save_dir)
    for old_name in old_names:
        new_name, long_hash = naming_convention(save_dir, old_name)
        old_path = os.path.join(save_dir, old_name)
        new_path = os.path.join(save_dir, new_name)
        shutil.move(old_path, new_path)
        file_size = os.path.getsize(new_path)
        logging.info('\t{}/{} {} {}'.format(save_dir, new_name, long_hash,
                                            file_size))
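A hedged invocation sketch for the converter above; the directory paths are placeholders, and the surrounding script is assumed to populate the global args namespace (args.torch, args.ctx / args.device, args.tf_hub_model_path) that convert_tf_model() also reads.

# Hypothetical call (assumption): paths are placeholders, and the global `args`
# used inside convert_tf_model() must already be set up by the surrounding script.
convert_tf_model(hub_model_dir='./tf_hub_model_dir',  # TF1 Hub Module or TF2 SavedModel
                 save_dir='./converted_model',
                 test_conversion=True,                 # compare TF and Gluon outputs
                 model_type='bert')                    # 'bert' or 'albert' (albert: non-torch only)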
Example No. 4
def main(args):
    # Download the data
    url = _URLS[args.dataset]
    file_hash = _URL_FILE_STATS[url]
    target_download_location = os.path.join(args.cache_path,
                                            os.path.basename(url))
    download(url, target_download_location, sha1_hash=file_hash)
    save_dir = args.dataset if args.save_dir is None else args.save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    # Extract and process the data
    if args.dataset == 'wikitext2':
        with zipfile.ZipFile(target_download_location) as zf:
            train_data = zf.read('wikitext-2/wiki.train.tokens')
            valid_data = zf.read('wikitext-2/wiki.valid.tokens')
            test_data = zf.read('wikitext-2/wiki.test.tokens')
            for filename, part in [('train.txt', train_data),
                                   ('valid.txt', valid_data),
                                   ('test.txt', test_data)]:
                filename = os.path.join(save_dir, filename)
                print('{} will have {} bytes'.format(filename, len(part)))
                if not path_exist_and_skip(filename, args.overwrite):
                    with open(filename, 'wb') as of:
                        of.write(part)
            vocab = build_vocab([
                os.path.join(save_dir, 'train.txt'),
                os.path.join(save_dir, 'valid.txt'),
                os.path.join(save_dir, 'test.txt')
            ])
            vocab.save(os.path.join(save_dir, 'vocab.json'))
    elif args.dataset == 'wikitext103':
        with zipfile.ZipFile(target_download_location) as zf:
            train_data = zf.read('wikitext-103/wiki.train.tokens')
            valid_data = zf.read('wikitext-103/wiki.valid.tokens')
            test_data = zf.read('wikitext-103/wiki.test.tokens')
            for filename, part in [('train.txt', train_data),
                                   ('valid.txt', valid_data),
                                   ('test.txt', test_data)]:
                filename = os.path.join(save_dir, filename)
                if not path_exist_and_skip(filename, args.overwrite):
                    print('{} will have {} bytes'.format(filename, len(part)))
                    with open(filename, 'wb') as of:
                        of.write(part)
            vocab = build_vocab([os.path.join(save_dir, 'train.txt')])
            vocab.save(os.path.join(save_dir, 'vocab.json'))
    elif args.dataset == 'text8':
        with zipfile.ZipFile(target_download_location) as zf:
            with zf.open('text8', 'r') as f:
                data = f.read().decode('utf-8')
                num_test_chars = 5000000
                train_data = data[:-2 * num_test_chars]
                valid_data = data[-2 * num_test_chars:-num_test_chars]
                test_data = data[-num_test_chars:]
                for filename, part in [('train.txt', train_data),
                                       ('valid.txt', valid_data),
                                       ('test.txt', test_data)]:
                    filename = os.path.join(save_dir, filename)
                    print('{} will have {} bytes'.format(filename, len(part)))
                    print('- Tokenizing...')
                    # Change space ' ' to underscore '_'
                    part_str = ' '.join(
                        ['_' if c == ' ' else c for c in part.strip()])
                    print('- Writing...')
                    if not path_exist_and_skip(filename, args.overwrite):
                        with open(filename, 'w', encoding='utf-8') as of:
                            of.write(part_str)
                    if not path_exist_and_skip(filename + '.raw',
                                               args.overwrite):
                        with open(filename + '.raw', 'w',
                                  encoding='utf-8') as of:
                            of.write(part)
            vocab = build_vocab([os.path.join(save_dir, 'train.txt')],
                                eos_token=None)
            vocab.save(os.path.join(save_dir, 'vocab.json'))
    elif args.dataset == 'enwik8':
        with zipfile.ZipFile(target_download_location) as zf:
            data = zf.read('enwik8')
            print('Length of enwik8: {}'.format(len(data)))
            num_test_chars = 5000000
            train_data = data[:-2 * num_test_chars]
            valid_data = data[-2 * num_test_chars:-num_test_chars]
            test_data = data[-num_test_chars:]

            for filename, part in [('train.txt', train_data),
                                   ('valid.txt', valid_data),
                                   ('test.txt', test_data)]:
                filename = os.path.join(save_dir, filename)
                print('{} will have {} bytes'.format(filename, len(part)))
                print('- Tokenizing...')
                part_str = ' '.join(
                    [str(c) if c != ord('\n') else '\n' for c in part])
                print('- Writing...')
                if not path_exist_and_skip(filename, args.overwrite):
                    with open(filename, 'w') as of:
                        of.write(part_str)
                if not path_exist_and_skip(filename + '.raw', args.overwrite):
                    with open(filename + '.raw', 'wb') as of:
                        of.write(part)
            vocab = build_vocab([os.path.join(save_dir, 'train.txt')],
                                eos_token=None)
            vocab.save(os.path.join(save_dir, 'vocab.json'))

    elif args.dataset == 'gbw':
        vocab_path = download(_URLS['gbw_vocab'],
                              os.path.join(args.cache_path,
                                           '1b_word_vocab.txt'),
                              sha1_hash=_URL_FILE_STATS[_URLS['gbw_vocab']])
        with tarfile.open(target_download_location) as f:
            os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True)
            os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True)
            for member in f.getmembers():
                if 'training-monolingual.tokenized.shuffled' in member.name \
                        and 'news.en' in member.name:
                    basename = os.path.basename(member.name)
                    with f.extractfile(member) as f_in:
                        with open(os.path.join(save_dir, 'train', basename),
                                  'wb') as f_out:
                            shutil.copyfileobj(f_in, f_out)
                elif 'heldout-monolingual.tokenized.shuffled' in member.name and \
                        '.heldout-' in member.name:
                    basename = os.path.basename(member.name)
                    with f.extractfile(member) as f_in:
                        with open(os.path.join(save_dir, 'test', basename),
                                  'wb') as f_out:
                            shutil.copyfileobj(f_in, f_out)
        all_tokens = []
        with open(vocab_path, 'r') as f:
            for token in f:
                token = token.strip().split()[0]
                all_tokens.append(token)
        vocab = Vocab(all_tokens, bos_token='<S>', unk_token='<UNK>')
        vocab.save(os.path.join(save_dir, 'vocab.json'))
        print('Saved Google-One-Billion-Word in {}'.format(save_dir))
        print('Vocab={}'.format(vocab))
    else:
        raise NotImplementedError
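
A hedged CLI sketch for driving main() above; only the attribute names it reads (dataset, save_dir, cache_path, overwrite) are used, and the dataset choices mirror the branches in the function.

# Hypothetical argument parser (assumption): attribute names mirror what main()
# reads above; defaults and flag spellings in the real script may differ.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True,
                    choices=['wikitext2', 'wikitext103', 'text8', 'enwik8', 'gbw'])
parser.add_argument('--save_dir', default=None)
parser.add_argument('--cache_path', default='./cache')
parser.add_argument('--overwrite', action='store_true')
main(parser.parse_args())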