Hugging Face + Spark:打造高效的 NLP 大数据处理引擎(一)

在自然语言处理(NLP)领域,Hugging Face 是不可或缺的处理库,而 Spark 则是大数据处理的必备工具。将两者的优势结合起来,可以实现高效的 NLP 大数据处理。以下是结合 Hugging Face 和 Spark 的两种方法,基于 Spark & PySpark 3.3.1 版本进行探索。

方法一:升级 Spark 版本至 3.4 及以上

如果你愿意升级 Spark 版本到 3.4 或更高版本,那么结合 Hugging Face 和 Spark 将变得非常方便。Spark 3.4 及以上版本天然支持加载模型进行预测。

关键步骤说明:

  1. 模型加载策略:需要为每个 Worker 单独加载模型,确保模型在分布式环境中的可用性。
  2. 文件夹管理:在加载 Hugging Face 预训练模型之前,务必删除之前的模型文件夹,防止加载失败。


注:如果图片无法显示,请检查链接合法性或稍后重试。

方法二:基于 Spark 3.3.1 的手动封装接口

如果你希望保持当前的 Spark 3.3.1 版本,那么可以通过手动封装接口来实现 Hugging Face 和 Spark 的结合。以下是详细的代码实现和关键说明。

封装分布式的模型缓存

为了高效管理模型加载和缓存,我们从spark3.4的源代码中抽取了一个分布式的模型缓存机制:

from collections import OrderedDict
from threading import Lock
from typing import Callable, Optional
from uuid import UUID
class ModelCache:
 """Cache for model prediction functions on executors.
 This requires the `spark.python.worker.reuse` configuration to be set to `true`, otherwise a
 new python worker (with an empty cache) will be started for every task.
 If a python worker is idle for more than one minute (per the IDLE_WORKER_TIMEOUT_NS setting in
 PythonWorkerFactory.scala), it will be killed, effectively clearing the cache until a new python
 worker is started.
 Caching large models can lead to out-of-memory conditions, which may require adjusting spark
 memory configurations, e.g. `spark.executor.memoryOverhead`.
 """
 _models: OrderedDict = OrderedDict()
 _capacity: int = 3 # "reasonable" default size for now, make configurable later, if needed
 _lock: Lock = Lock()
 @staticmethod
 def add(uuid: UUID, predict_fn: Callable) -> None:
 with ModelCache._lock:
 ModelCache._models[uuid] = predict_fn
 ModelCache._models.move_to_end(uuid)
 if len(ModelCache._models) > ModelCache._capacity:
 ModelCache._models.popitem(last=False)
 @staticmethod
 def get(uuid: UUID) -> Optional[Callable]:
 with ModelCache._lock:
 predict_fn = ModelCache._models.get(uuid)
 if predict_fn:
 ModelCache._models.move_to_end(uuid)
 return predict_fn

封装处理逻辑

from __future__ import annotations
import os
import argparse
import random
import logging
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, column, encode
from pyspark.sql.types import *
from datetime import datetime, timedelta
import requests as req
from io import BytesIO
import numpy as np
import uuid
import inspect
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import (
 ArrayType,
 ByteType,
 DataType,
 DoubleType,
 FloatType,
 IntegerType,
 LongType,
 ShortType,
 StringType,
 StructType,
)
from typing import Any, Callable, Iterator, List, Mapping, TYPE_CHECKING, Tuple, Union, Optional
supported_scalar_types = (
 ByteType,
 ShortType,
 IntegerType,
 LongType,
 FloatType,
 DoubleType,
 StringType,
)
hadoop = os.path.join(os.environ['HADOOP_COMMON_HOME'], 'bin/hadoop')
def init_spark():
 """初始化 SparkSession 配置"""
 spark = SparkSession.builder \
 .config("spark.sql.caseSensitive", "false") \
 .config("spark.shuffle.spill", "true") \
 .config("spark.shuffle.spill.compress", "true") \
 .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
 .config("metastore.catalog.default", "hive") \
 .config("spark.sql.hive.convertMetastoreOrc", "true") \
 .config("spark.kryoserializer.buffer.max", "1024m") \
 .config("spark.kryoserializer.buffer", "64m") \
 .config("spark.driver.maxResultSize","4g") \
 .config("spark.sql.broadcastTimeout", "36000") \
 .enableHiveSupport() \
 .getOrCreate()
 return spark
def system_command(command):
 """执行系统命令"""
 code = os.system(command)
 if code != 0:
 logging.error(f"Command: ({command}) execute failed.")
 else:
 logging.info(f"Command: ({command}) execute succeed.")
def parse_args():
 """解析命令行参数"""
 parser = argparse.ArgumentParser(usage="it's usage tip.",
 description="user tags prefer")
 parser.add_argument("--db", default="", help="hive表")
 parser.add_argument("--date", default="", help="日期")
 parser.add_argument("--output_path", default="", help="输出路径")
 parser.add_argument("--batch_size", default=16, help="输出路径")
 return parser.parse_args()
def _batched(
 data: Union[pd.Series, pd.DataFrame, Tuple[pd.Series]], batch_size: int
) -> Iterator[pd.DataFrame]:
 """将 pandas dataframe/series 分批处理"""
 if isinstance(data, pd.DataFrame):
 df = data
 elif isinstance(data, pd.Series):
 df = pd.concat((data,), axis=1)
 else: # isinstance(data, Tuple[pd.Series])
 df = pd.concat(data, axis=1)
 index = 0
 data_size = len(df)
 while index < data_size:
 yield df.iloc[index : index + batch_size]
 index += batch_size
def _is_tensor_col(data: Union[pd.Series, pd.DataFrame]) -> bool:
 """检查数据是否为张量列"""
 if isinstance(data, pd.Series):
 return data.dtype == np.object_ and isinstance(data.iloc[0], (np.ndarray, list))
 elif isinstance(data, pd.DataFrame):
 return any(data.dtypes == np.object_) and any(
 [isinstance(d, (np.ndarray, list)) for d in data.iloc[0]]
 )
 else:
 raise ValueError(
 "Unexpected data type: {}, expected pd.Series or pd.DataFrame.".format(type(data))
 )
def _has_tensor_cols(data: Union[pd.Series, pd.DataFrame, Tuple[pd.Series]]) -> bool:
 """检查输入是否包含张量值列"""
 if isinstance(data, (pd.Series, pd.DataFrame)):
 return _is_tensor_col(data)
 else: # isinstance(data, Tuple)
 return any(_is_tensor_col(elem) for elem in data)
def _validate_and_transform_multiple_inputs(
 batch: pd.DataFrame, input_shapes: List[Optional[List[int]]], num_input_cols: int
) -> List[np.ndarray]:
 """验证并转换多个输入"""
 multi_inputs = [batch[col].to_numpy() for col in batch.columns]
 if input_shapes:
 if len(input_shapes) == num_input_cols:
 multi_inputs = [
 np.vstack(v).reshape([-1] + input_shapes[i]) # type: ignore
 if input_shapes[i]
 else v
 for i, v in enumerate(multi_inputs)
 ]
 if not all([len(x) == len(batch) for x in multi_inputs]):
 raise ValueError("Input data does not match expected shape.")
 else:
 raise ValueError("input_tensor_shapes must match columns")
 return multi_inputs
def _validate_and_transform_single_input(
 batch: pd.DataFrame,
 input_shapes: List[Optional[List[int]]],
 has_tensors: bool,
 has_tuple: bool,
) -> np.ndarray:
 """验证并转换单个输入"""
 # 处理逻辑省略(与原文一致)
 return single_input
def _validate_and_transform_prediction_result(
 preds: Union[np.ndarray, Mapping[str, np.ndarray], List[Mapping[str, Any]]],
 num_input_rows: int,
 return_type: DataType,
) -> Union[pd.DataFrame, pd.Series]:
 """验证并转换预测结果"""
 # 处理逻辑省略(与原文一致)
 return pd.DataFrame(preds)
def predict_batch_udf(
 make_predict_fn: Callable[
 [],
 PredictBatchFunction,
 ],
 *,
 return_type: DataType,
 batch_size: int,
 input_tensor_shapes: Optional[Union[List[Optional[List[int]]], Mapping[int, List[int]]]] = None,
):
 """定义批量预测的 Pandas UDF"""
 model_uuid = uuid.uuid4()
 def predict(data: Iterator[Union[pd.Series, pd.DataFrame]]) -> Iterator[pd.DataFrame]:
 from model_cache import ModelCache
 predict_fn = ModelCache.get(model_uuid)
 if not predict_fn:
 predict_fn = make_predict_fn()
 ModelCache.add(model_uuid, predict_fn)
 signature = inspect.signature(predict_fn)
 num_expected_cols = len(signature.parameters)
 input_shapes: List[Optional[List[int]]]
 if isinstance(input_tensor_shapes, Mapping):
 input_shapes = [None] * num_expected_cols
 for index, shape in input_tensor_shapes.items():
 input_shapes[index] = shape
 else:
 input_shapes = input_tensor_shapes # type: ignore
 for pandas_batch in data:
 has_tuple = isinstance(pandas_batch, Tuple) # type: ignore
 has_tensors = _has_tensor_cols(pandas_batch)
 if has_tensors and not input_shapes:
 raise ValueError("Tensor columns require input_tensor_shapes")
 for batch in _batched(pandas_batch, batch_size):
 num_input_rows = len(batch)
 num_input_cols = len(batch.columns)
 if num_input_cols == num_expected_cols and num_expected_cols > 1:
 multi_inputs = _validate_and_transform_multiple_inputs(
 batch, input_shapes, num_input_cols
 )
 preds = predict_fn(*multi_inputs)
 elif num_expected_cols == 1:
 single_input = _validate_and_transform_single_input(
 batch, input_shapes, has_tensors, has_tuple
 )
 preds = predict_fn(single_input)
 else:
 msg = "Model expected {} inputs, but received {} columns"
 raise ValueError(msg.format(num_expected_cols, num_input_cols))
 yield _validate_and_transform_prediction_result(
 preds, num_input_rows, return_type
 ) # type: ignore
 return pandas_udf(predict, return_type) # type: ignore[call-overload]
def extract_text_embedding(model, tokenizer, sentence):
 """提取文本嵌入向量"""
 inputs = tokenizer(sentence, return_tensors='pt', max_length=32, padding=True, truncation=True)
 embeddings = model(**inputs)
 embeddings = embeddings.pooler_output
 embeddings = embeddings.tolist()
 for i in range(len(embeddings)):
 embeddings[i] = [round(c,4) for c in embeddings[i]]
 return np.array(embeddings, dtype=np.float32)
if __name__ == "__main__":
 args = parse_args()
 spark = init_spark() 
	
	 ### 读取数据
	 df = spark.sql(f"""
 select article_id, title
 from xxx
 """)
	
 def predict_embedding():
 system_command(f"""rm -rf ./bert-base-chinese""")
 system_command(f"""{hadoop} fs -get /path/to/bert-base-chinese""")
 from transformers import BertTokenizer, BertModel
 tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')
 text_model = BertModel.from_pretrained('./bert-base-chinese')
 def predict(inputs):
 sentence = inputs.tolist()
 embeddings = extract_text_embedding(text_model, tokenizer, sentence)
 return embeddings
 return predict
 predict_embedding_udf = predict_batch_udf(predict_embedding,
 return_type=ArrayType(StringType()),
 batch_size=100)
 df.withColumn("title_embedding", predict_embedding_udf("title")).show(5)
 spark.stop()
 del spark

关键点说明:

  1. 模型加载与缓存:通过 predict_batch_udf 函数封装预测逻辑,利用模型缓存避免重复加载,提高效率。
  2. 批量处理:使用 _batched 函数将数据分批处理,避免内存溢出,适合大数据场景。
  3. 类型转换与验证:通过 _validate_and_transform 系列函数确保输入输出类型匹配,提高代码健壮性。
  4. Hugging Face 模型集成:在 predict_embedding 函数中加载 Hugging Face 的 BERT 模型,并定义预测逻辑。
  5. typing语法修改:python3.10以的typing是不支持|语法的,需要改成Union进行类型的或推断

方法比较

对比维度方法一:升级 Spark 至 3.4+方法二:基于 Spark 3.3.1 手动封装接口
实现难度较低,依托新版本特性较高,需手动实现缓存及接口封装
模型加载方式每个 Worker 单独加载模型每个 Worker 单独加载模型,并通过分布式缓存机制复用
文件管理要求需提前删除旧模型文件夹防止加载失败加载前删除旧模型文件夹
代码复用性可直接使用新版本 API,代码简洁需手动封装,代码量较大,但更具灵活性
性能优化新版本可能自带优化可通过调整缓存策略、批量处理逻辑等进行精细优化
适用场景适合可升级环境,追求快速开发适合无法升级环境,或对性能和资源管理有更高要求的场景
可维护性依赖新版本稳定性,升级后需充分测试自定义逻辑较多,需额外维护封装的接口和缓存机制
扩展性依赖 Spark 新版本的更新节奏可根据项目需求灵活扩展自定义功能
社区支持可直接参考官方文档和社区对新版本的案例需结合旧版本社区经验,同时参考自定义实现的维护文档
资源消耗新版本可能对硬件有新要求可通过优化缓存和批处理逻辑,更精细地控制资源使用

通过以上两种方法,可以在不同 Spark 版本环境下实现 Hugging Face 和 Spark 的结合,充分发挥两者在 NLP 和大数据处理中的优势,推荐第二种,更加可控一些。

本文由博客一文多发平台 OpenWrite 发布!

作者:saboxu原文地址:https://www.cnblogs.com/saboxu/p/18889903

%s 个评论

要回复文章请先登录注册