简化BERT模型的文本分类与预测工具
Ernie是一个基于BERT的Python库,为文本分类和预测任务提供简洁接口。它支持多种预训练模型,允许微调和自定义。Ernie具备灵活的文本分割和结果聚合策略,能够处理长文本,并提供模型保存、加载和自动保存功能。这个工具适用于情感分析、文本分类等多种自然语言处理任务,为NLP研究和开发提供了便捷的解决方案。
由 <a href="http://stickermule.com/supports/ernie20-sponsorship"><img src="https://yellow-cdn.veclightyear.com/0a4dffa0/979a9358-b1dc-42b2-bc97-aaab96b809ae.png" alt="Sticker Mule标志" width="80px"/></a> 赞助
Ernie需要Python 3.6或更高版本。
pip install ernie
<a href="https://colab.research.google.com/drive/10lmqZyAHFP_-x4LxIQxZCavYpPqcR28c"><img alt="在Colab中打开" src="https://yellow-cdn.veclightyear.com/0a4dffa0/4fbc9aaf-8519-402c-8ca9-edcc65802675.svg?style=flat-square"></a>
from ernie import SentenceClassifier, Models import pandas as pd tuples = [ ("这是一个积极的例子。我今天很开心。", 1), ("这是一个消极的句子。今天工作中一切都出错了。", 0) ] df = pd.DataFrame(tuples) classifier = SentenceClassifier( model_name=Models.BertBaseUncased, max_length=64, labels_no=2 ) classifier.load_dataset(df, validation_split=0.2) classifier.fine_tune( epochs=4, learning_rate=2e-5, training_batch_size=32, validation_batch_size=64 )
text = "哦,那太好了!" # 它返回一个包含预测结果的元组 probabilities = classifier.predict_one(text)
texts = ["哦,那太好了!", "那真是太糟糕了"] # 它返回一个包含预测结果的元组生成器 probabilities = classifier.predict(texts)
如果文本的标记长度超过模型微调时的 max_length
,它们将被截断。为避免信息丢失,你可以使用分割策略并以不同方式聚合预测结果。
SentencesWithoutUrls
。文本将被分割成句子。GroupedSentencesWithoutUrls
。文本将被分割成标记长度接近 max_length
的句子组。Mean
:文本的预测结果将是各分割部分预测结果的平均值。MeanTopFiveBinaryClassification
:仅对5个最高预测结果计算平均值。MeanTopTenBinaryClassification
:仅对10个最高预测结果计算平均值。MeanTopFifteenBinaryClassification
:仅对15个最高预测结果计算平均值 。MeanTopTwentyBinaryClassification
:仅对20个最高预测结果计算平均值。from ernie import SplitStrategies, AggregationStrategies texts = ["哦,那太棒了!", "那真是太糟糕了"] probabilities = classifier.predict( texts, split_strategy=SplitStrategies.GroupedSentencesWithoutUrls, aggregation_strategy=AggregationStrategies.Mean )
你可以通过 AggregationStrategy
和 SplitStrategy
类定义自定义策略。
from ernie import SplitStrategy, AggregationStrategy my_split_strategy = SplitStrategy( split_patterns: list, remove_patterns: list, remove_too_short_groups: bool, group_splits: bool ) my_aggregation_strategy = AggregationStrategy( method: function, max_items: int, top_items: bool, sorting_class_index: int )
classifier.dump('./model')
classifier = SentenceClassifier(model_path='./model')
由于执行可能在训练期间中断(尤其是在使用Google Colab时),你可以选择保存每个新训练的轮次,这样可以在不丢失所有进度的情况下恢复训练。
classifier = SentenceClassifier( model_name=Models.BertBaseUncased, max_length=64 ) classifier.load_dataset(df, validation_split=0.2) for epoch in range(1, 5): if epoch == 3: raise Exception("强制崩溃") classifier.fine_tune(epochs=1) classifier.dump(f'./my-model/{epoch}')
last_training_epoch = 2 classifier = SentenceClassifier(model_path=f'./my-model/{last_training_epoch}') classifier.load_dataset(df, validation_split=0.2) for epoch in range(last_training_epoch + 1, 5): classifier.fine_tune(epochs=1) classifier.dump(f'./my-model/{epoch}')
即使你没有显式地保存模型,每次成功执行 fine_tune
时,它都会自动保存到 ./ernie-autosave
中。
ernie-autosave/
└── model_family/
└── timestamp/
├── config.json
├── special_tokens_map.json
├── tf_model.h5
├── tokenizer_config.json
└── vocab.txt
你可以在结束一个会话或开始新会话时,通过调用 clean_autosave
轻松清理自动保存的模型。
from ernie import clean_autosave clean_autosave()
你可以通过 Models
类访问一些官方基础模型名称。但是,在实例化 SentenceClassifier
时,你可以直接输入HuggingFace的模型名称,如 bert-base-uncased
或 bert-base-chinese
。
在 huggingface.co/models 查看所有可用模型。