def transformer(model: str = 'xlnet', size: str = 'base', **kwargs):
    """
    Load Transformer toxicity model.

    Parameters
    ----------
    model : str, optional (default='xlnet')
        Model architecture supported. Allowed values:

        * ``'bert'`` - BERT architecture from Google.
        * ``'xlnet'`` - XLNET architecture from Google.
        * ``'albert'`` - ALBERT architecture from Google.
    size : str, optional (default='base')
        Model size supported. Allowed values:

        * ``'base'`` - BASE size.
        * ``'small'`` - SMALL size.

    Returns
    -------
    result : malaya._models._bert_model.SIGMOID_BERT class
    """
    model = model.lower()
    size = size.lower()
    if model not in _availability:
        raise Exception(
            'model not supported, please check supported models from malaya.toxic.available_transformer_model()'
        )
    if size not in _availability[model]:
        raise Exception(
            'size not supported, please check supported models from malaya.toxic.available_transformer_model()'
        )

    # Download checkpoint files on first use, then load the frozen graph.
    check_file(PATH_TOXIC[model][size], S3_PATH_TOXIC[model][size], **kwargs)
    g = load_graph(PATH_TOXIC[model][size]['model'])

    if model in ['albert', 'bert']:
        if model == 'bert':
            from ._transformer._bert import _extract_attention_weights_import
        if model == 'albert':
            from ._transformer._albert import _extract_attention_weights_import

        tokenizer, cls, sep = sentencepiece_tokenizer_bert(
            PATH_TOXIC[model][size]['tokenizer'],
            PATH_TOXIC[model][size]['vocab'],
        )
        return SIGMOID_BERT(
            X=g.get_tensor_by_name('import/Placeholder:0'),
            segment_ids=None,
            input_masks=None,
            logits=g.get_tensor_by_name('import/logits:0'),
            logits_seq=g.get_tensor_by_name('import/logits_seq:0'),
            sess=generate_session(graph=g),
            tokenizer=tokenizer,
            label=_label_toxic,
            cls=cls,
            sep=sep,
            attns=_extract_attention_weights_import(bert_num_layers[size], g),
            class_name='toxic',
        )

    if model in ['xlnet']:
        from ._transformer._xlnet import _extract_attention_weights_import

        tokenizer = sentencepiece_tokenizer_xlnet(
            PATH_TOXIC[model][size]['tokenizer']
        )
        return SIGMOID_XLNET(
            X=g.get_tensor_by_name('import/Placeholder:0'),
            segment_ids=g.get_tensor_by_name('import/Placeholder_1:0'),
            input_masks=g.get_tensor_by_name('import/Placeholder_2:0'),
            logits=g.get_tensor_by_name('import/logits:0'),
            logits_seq=g.get_tensor_by_name('import/logits_seq:0'),
            sess=generate_session(graph=g),
            tokenizer=tokenizer,
            label=_label_toxic,
            attns=_extract_attention_weights_import(g),
            class_name='toxic',
        )
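
# Hedged usage sketch for the size-aware API above. It assumes this function
# is exposed as `malaya.toxic.transformer` and that the returned SIGMOID_BERT /
# SIGMOID_XLNET objects implement `predict_proba` as Malaya's other sigmoid
# classifiers do; both names are assumptions, not confirmed by this snippet.

import malaya

# First call downloads the checkpoint via check_file, then builds the graph.
model = malaya.toxic.transformer(model='albert', size='base')

# Multi-label toxicity: one dict of label -> probability per input string.
print(model.predict_proba(['benda ni memang teruk']))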
def transformer(model: str = 'xlnet', quantized: bool = False, **kwargs):
    """
    Load Transformer toxicity model.

    Parameters
    ----------
    model : str, optional (default='xlnet')
        Model architecture supported. Allowed values:

        * ``'bert'`` - Google BERT BASE parameters.
        * ``'tiny-bert'`` - Google BERT TINY parameters.
        * ``'albert'`` - Google ALBERT BASE parameters.
        * ``'tiny-albert'`` - Google ALBERT TINY parameters.
        * ``'xlnet'`` - Google XLNET BASE parameters.
        * ``'alxlnet'`` - Malaya ALXLNET BASE parameters.
    quantized : bool, optional (default=False)
        if True, will load an 8-bit quantized model.
        A quantized model is not necessarily faster; it depends on the machine.

    Returns
    -------
    result : malaya.model.bert.SIGMOID_BERT class for BERT-family models,
        malaya.model.xlnet.SIGMOID_XLNET class for XLNET-family models.
    """
    model = model.lower()
    if model not in _transformer_availability:
        raise Exception(
            'model not supported, please check supported models from `malaya.toxicity.available_transformer()`.'
        )

    check_file(
        PATH_TOXIC[model], S3_PATH_TOXIC[model], quantized = quantized, **kwargs
    )
    # Quantized checkpoints live under a separate key in PATH_TOXIC.
    if quantized:
        model_path = 'quantized'
    else:
        model_path = 'model'
    g = load_graph(PATH_TOXIC[model][model_path], **kwargs)
    path = PATH_TOXIC

    if model in ['albert', 'bert', 'tiny-albert', 'tiny-bert']:
        if model in ['bert', 'tiny-bert']:
            from malaya.transformers.bert import (
                _extract_attention_weights_import,
            )
            from malaya.transformers.bert import bert_num_layers

            tokenizer = sentencepiece_tokenizer_bert(
                path[model]['tokenizer'], path[model]['vocab']
            )
        if model in ['albert', 'tiny-albert']:
            from malaya.transformers.albert import (
                _extract_attention_weights_import,
            )
            from malaya.transformers.albert import bert_num_layers
            from albert import tokenization

            tokenizer = tokenization.FullTokenizer(
                vocab_file = path[model]['vocab'],
                do_lower_case = False,
                spm_model_file = path[model]['tokenizer'],
            )

        return SIGMOID_BERT(
            X = g.get_tensor_by_name('import/Placeholder:0'),
            segment_ids = None,
            input_masks = g.get_tensor_by_name('import/Placeholder_1:0'),
            logits = g.get_tensor_by_name('import/logits:0'),
            logits_seq = g.get_tensor_by_name('import/logits_seq:0'),
            vectorizer = g.get_tensor_by_name('import/dense/BiasAdd:0'),
            sess = generate_session(graph = g, **kwargs),
            tokenizer = tokenizer,
            label = label,
            attns = _extract_attention_weights_import(
                bert_num_layers[model], g
            ),
            class_name = 'toxic',
        )

    if model in ['xlnet', 'alxlnet']:
        if model in ['xlnet']:
            from malaya.transformers.xlnet import (
                _extract_attention_weights_import,
            )
        if model in ['alxlnet']:
            from malaya.transformers.alxlnet import (
                _extract_attention_weights_import,
            )

        tokenizer = sentencepiece_tokenizer_xlnet(path[model]['tokenizer'])

        return SIGMOID_XLNET(
            X = g.get_tensor_by_name('import/Placeholder:0'),
            segment_ids = g.get_tensor_by_name('import/Placeholder_1:0'),
            input_masks = g.get_tensor_by_name('import/Placeholder_2:0'),
            logits = g.get_tensor_by_name('import/logits:0'),
            logits_seq = g.get_tensor_by_name('import/logits_seq:0'),
            vectorizer = g.get_tensor_by_name('import/transpose_3:0'),
            sess = generate_session(graph = g, **kwargs),
            tokenizer = tokenizer,
            label = label,
            attns = _extract_attention_weights_import(g),
            class_name = 'toxic',
        )
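
# Hedged sketch of the quantized path above: `quantized = True` makes
# check_file fetch the 8-bit graph and load_graph read
# PATH_TOXIC[model]['quantized'] instead of ['model']. `predict_proba` and
# `vectorize` are assumptions about the returned class (the `vectorizer`
# tensor wired in above suggests the latter), not confirmed by this snippet.

import malaya

quantized_model = malaya.toxicity.transformer(
    model = 'tiny-bert', quantized = True
)

strings = ['kau memang tak guna']
# Multi-label probabilities, one dict of label -> score per input string.
print(quantized_model.predict_proba(strings))

# Sentence embeddings backed by import/dense/BiasAdd:0 for BERT-family models.
vectors = quantized_model.vectorize(strings)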
def transformer(model: str = 'xlnet', **kwargs):
    """
    Load Transformer toxicity model.

    Parameters
    ----------
    model : str, optional (default='xlnet')
        Model architecture supported. Allowed values:

        * ``'bert'`` - BERT architecture from Google.
        * ``'tiny-bert'`` - BERT architecture from Google with smaller parameters.
        * ``'albert'`` - ALBERT architecture from Google.
        * ``'tiny-albert'`` - ALBERT architecture from Google with smaller parameters.
        * ``'xlnet'`` - XLNET architecture from Google.
        * ``'alxlnet'`` - XLNET architecture from Google + Malaya.

    Returns
    -------
    result : malaya.model.bert.SIGMOID_BERT class for BERT-family models,
        malaya.model.xlnet.SIGMOID_XLNET class for XLNET-family models.
    """
    model = model.lower()
    if model not in _availability:
        raise Exception(
            'model not supported, please check supported models from malaya.toxicity.available_transformer()'
        )

    check_file(PATH_TOXIC[model], S3_PATH_TOXIC[model], **kwargs)
    g = load_graph(PATH_TOXIC[model]['model'])
    path = PATH_TOXIC

    if model in ['albert', 'bert', 'tiny-albert', 'tiny-bert']:
        if model in ['bert', 'tiny-bert']:
            from malaya.transformers.bert import (
                _extract_attention_weights_import,
            )
            from malaya.transformers.bert import bert_num_layers

            tokenizer = sentencepiece_tokenizer_bert(
                path[model]['tokenizer'], path[model]['vocab']
            )
        if model in ['albert', 'tiny-albert']:
            from malaya.transformers.albert import (
                _extract_attention_weights_import,
            )
            from malaya.transformers.albert import bert_num_layers
            from albert import tokenization

            tokenizer = tokenization.FullTokenizer(
                vocab_file=path[model]['vocab'],
                do_lower_case=False,
                spm_model_file=path[model]['tokenizer'],
            )

        return SIGMOID_BERT(
            X=g.get_tensor_by_name('import/Placeholder:0'),
            segment_ids=None,
            input_masks=g.get_tensor_by_name('import/Placeholder_1:0'),
            logits=g.get_tensor_by_name('import/logits:0'),
            logits_seq=g.get_tensor_by_name('import/logits_seq:0'),
            sess=generate_session(graph=g),
            tokenizer=tokenizer,
            label=label,
            attns=_extract_attention_weights_import(bert_num_layers[model], g),
            class_name='toxic',
        )

    if model in ['xlnet', 'alxlnet']:
        if model in ['xlnet']:
            from malaya.transformers.xlnet import (
                _extract_attention_weights_import,
            )
        if model in ['alxlnet']:
            from malaya.transformers.alxlnet import (
                _extract_attention_weights_import,
            )

        tokenizer = sentencepiece_tokenizer_xlnet(path[model]['tokenizer'])

        return SIGMOID_XLNET(
            X=g.get_tensor_by_name('import/Placeholder:0'),
            segment_ids=g.get_tensor_by_name('import/Placeholder_1:0'),
            input_masks=g.get_tensor_by_name('import/Placeholder_2:0'),
            logits=g.get_tensor_by_name('import/logits:0'),
            logits_seq=g.get_tensor_by_name('import/logits_seq:0'),
            sess=generate_session(graph=g),
            tokenizer=tokenizer,
            label=label,
            attns=_extract_attention_weights_import(g),
            class_name='toxic',
        )
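
# The availability check above runs before any download, so an unsupported
# name fails fast. A minimal sketch of that failure mode, assuming this
# function is exposed as `malaya.toxicity.transformer` (the module named in
# the error message above).

import malaya

try:
    malaya.toxicity.transformer(model='gpt2')  # not in _availability
except Exception as e:
    print(e)
    # model not supported, please check supported models from
    # malaya.toxicity.available_transformer()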