Example #1
def __init__(self, optimizer, **kwargs):
    super(OptimizerWrapper, self).__init__(**kwargs)
    self.optimizer = optimizer  # the wrapped optimizer instance
    self._optimizer_attributes = []
    # Mirror the wrapped optimizer's attributes onto the wrapper, skipping
    # names the wrapper already defines, and remember which were copied.
    for k, v in get_all_attributes(self.optimizer).items():
        if k not in dir(self):
            setattr(self, k, v)
            self._optimizer_attributes.append(k)
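
For context, `get_all_attributes` comes from `bert4keras.snippets` and returns an object's attributes as a dict. A minimal sketch consistent with the usage above (assumed behavior, not necessarily the library's exact implementation):

def get_all_attributes(something):
    # Assumed behavior: collect every public attribute into a {name: value}
    # dict; the actual bert4keras.snippets implementation may differ.
    return {
        name: getattr(something, name)
        for name in dir(something)
        if not name.startswith('__')
    }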
Example #2
import json
import numpy as np
import codecs
from bert4keras.backend import keras, set_gelu
from bert4keras.tokenizer import SpTokenizer
from bert4keras.bert import build_bert_model
from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr
from bert4keras.snippets import sequence_padding, get_all_attributes

locals().update(get_all_attributes(keras.layers))  # equivalent to: from keras.layers import *
set_gelu('tanh')  # use the tanh-approximation variant of GeLU

maxlen = 256
config_path = 'models/albert_base/albert_config.json'
checkpoint_path = 'models/albert_base/variables/variables'
spm_path = 'models/albert_base/assets/30k-clean.model'


def load_data(filename):
    """Load a TSV file with one `text<TAB>label` pair per line."""
    D = []
    with codecs.open(filename, encoding='utf-8') as f:
        for l in f:
            text, label = l.strip().split('\t')
            D.append((text, int(label)))
    return D


train_data = load_data('datasets/IMDB_trainshuffle.data')
valid_data = load_data('datasets/IMDB_valshuffle.data')
test_data = load_data('datasets/IMDB_testshuffle.data')
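
With the SentencePiece model loaded, the (text, label) pairs can be turned into padded id arrays for the model. A minimal sketch using `SpTokenizer` and `sequence_padding` as imported above; `encode_data` is a hypothetical helper, not part of the original example:

# Hypothetical helper: encode the loaded pairs into padded arrays.
tokenizer = SpTokenizer(spm_path)

def encode_data(data):
    token_ids, segment_ids, labels = [], [], []
    for text, label in data:
        # encode() returns (token_ids, segment_ids); truncation is omitted
        # here since the max-length argument name varies across versions
        ids, segs = tokenizer.encode(text)
        token_ids.append(ids)
        segment_ids.append(segs)
        labels.append([label])
    return (sequence_padding(token_ids),
            sequence_padding(segment_ids),
            np.array(labels))

train_x_token, train_x_segment, train_y = encode_data(train_data)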
Example #3
#! -*- coding:utf-8 -*-
# Sentiment analysis example, loading albert_zh weights (https://github.com/brightmart/albert_zh)

import json
import numpy as np
import codecs
from bert4keras.backend import keras, set_gelu
from bert4keras.tokenizer import Tokenizer
from bert4keras.bert import build_bert_model
from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr
from bert4keras.snippets import sequence_padding, get_all_attributes

locals().update(get_all_attributes(keras.layers))  # from keras.layers import *
set_gelu('tanh')  # switch the GeLU version

maxlen = 128
config_path = '/root/kg/bert/albert_small_zh_google/albert_config.json'
checkpoint_path = '/root/kg/bert/albert_small_zh_google/albert_model.ckpt'
dict_path = '/root/kg/bert/albert_small_zh_google/vocab.txt'


def load_data(filename):
    """Load a TSV file with one `text<TAB>label` pair per line."""
    D = []
    with codecs.open(filename, encoding='utf-8') as f:
        for l in f:
            text, label = l.strip().split('\t')
            D.append((text, int(label)))
    return D


# Load the datasets
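
The excerpt is truncated here. In the full example the datasets are loaded with the same pattern as in Example #2; a sketch with placeholder file names, since the original paths are not shown in this excerpt:

# Placeholder paths (hypothetical); substitute the real dataset files.
train_data = load_data('path/to/train.data')
valid_data = load_data('path/to/valid.data')
test_data = load_data('path/to/test.data')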