logo

模型保存和加载

王哲峰 / 2022-10-24


目录

模型持久化介绍

# python libraries
import os
import sys
from typing import List
import logging
import pickle
from sklearn.externals import joblib
from sklearn2pmml import sklearn2pmml, PMMLPipeline
from sklearn_pandas import DataFrameMapper
from pypmml import Model


# TODO
class ModelDeploy:
    """
    模型部署类
    """
    def __init__(self, save_file_path: str):
        self.save_file_path = save_file_path
    
    def ModelSave(self):
        """
        模型保存
        """
        pass

    def ModelLoad(self):
        """
        模型载入
        """
        pass

Pickle

class ModelDeployPkl(ModelDeploy):
    """
    模型离线部署类
    """
    def __init__(self, save_file_path: str):
        self.save_file_path = save_file_path  # 模型保存的目标路径
    
    def ModelSave(self, model):
        """
        模型保存

        Args:
            model (instance): 模型实例

        Raises:
            Exception: [description]
        """
        if not self.save_file_path.endswith(".pkl"):
            raise Exception("参数 save_file_path 后缀必须为 'pkl', 请检查.")
        
        with open(self.save_file_path, "wb") as f:
            pickle.dump(model, f, protocol = 2)
        # TODO
        f.close()
        logging.info(f"模型文件已保存至{self.save_file_path}")

    def ModelLoad(self):
        """
        模型载入

        Raises:
            Exception: [description]
        """
        if not os.path.exists(self.save_file_path):
            raise Exception("参数 save_file_path 指向的文件路径不存在,请检查.")
        
        self.model = joblib.load(self.save_file_path)


# 测试代码 main 函数
def main():
    save_file_path = None
    model_deploy_pkl = ModelDeployPkl(save_file_path)

if __name__ == "__main__":
    main()

PMML

class ModelDeployPmml(ModelDeploy):
    """
    模型在线部署类
    """
    def __init__(self, save_file_path: str):
        self.save_file_path = save_file_path  # 模型保存的目标路径
    
    def ModelSave(self, model, features_list: List):
        """
        模型保存

        Args:
            model (instance): 模型实例
            features_list (list): 模型特征名称列表

        Raises:
            Exception: [description]
        """
        if not self.save_file_path.endswith(".pmml"):
            raise Exception("参数 save_file_path 后缀必须为 'pmml', 请检查.")

        mapper = DataFrameMapper([([i], None) for i in features_list])
        pipeline = PMMLPipeline([
            ("mapper", mapper),
            ("classifier", model),
        ])
        sklearn2pmml(pipeline, pmml = self.save_file_path)
        logging.info(f"模型文件已保存至{self.save_file_path}")

    def ModelLoad(self):
        """
        模型载入

        Raises:
            Exception: [description]
        """
        if not os.path.exists(self.save_file_path):
            raise Exception("参数 save_file_path 指向的文件路径不存在,请检查.")

        self.model = Model.fromFile(self.save_file_path)




# 测试代码 main 函数
def main():
    save_file_path = None
    model_deploy_pmml = ModelDeployPmml(save_file_path)

if __name__ == "__main__":
    main()

参考