CatBoost API

王哲峰 / 2023-02-24


Contents

CatBoost Parameters

CatBoost API

CatBoost Installation

Install the dependencies

$ pip install numpy six

Install the CatBoost library

$ pip install catboost
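
To confirm that the packages installed above can be found by the interpreter, a small check like the following may help. It is a minimal sketch that uses only the standard library; the helper name `is_installed` is hypothetical and is not part of CatBoost.

```python
import importlib.util


def is_installed(package_name):
    """Return True if the import system can locate the top-level package."""
    return importlib.util.find_spec(package_name) is not None


# report the status of each package from the installation steps
for package in ("numpy", "six", "catboost"):
    status = "installed" if is_installed(package) else "missing"
    print(package, "->", status)
```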

Core Data Structures

Learning API

CatBoostClassifier

import numpy as np
from catboost import CatBoostClassifier, Pool

# initialize data
train_data = np.random.randint(0, 100, size = (100, 10))
train_labels = np.random.randint(0, 2, size = (100))
# reuse the training set, wrapped in a Pool, as the data to predict on
test_data = Pool(train_data, train_labels)

# build model
model = CatBoostClassifier(
    iterations = 2,
    depth = 2,
    learning_rate = 1,
    loss_function = "Logloss",
    verbose = True
)

# train model
model.fit(train_data, train_labels)

# prediction using model
y_pred = model.predict(test_data)
y_pred_proba = model.predict_proba(test_data)
print("class = ", y_pred)
print("proba = ", y_pred_proba)

CatBoostRegressor

import numpy as np
from catboost import CatBoostRegressor, Pool

# initialize data
train_data = np.random.randint(0, 100, size = (100, 10))
train_labels = np.random.randint(0, 100, size = (100))
test_data = np.random.randint(0, 100, size = (50, 10))

# initialize Pool
train_pool = Pool(train_data, train_labels, cat_features = [0, 2, 5])
test_pool = Pool(test_data, cat_features = [0, 2, 5])

# build model
model = CatBoostRegressor(
    iterations = 2, 
    depth = 2,
    learning_rate = 1, 
    loss_function = "RMSE"
)

# train model
model.fit(train_pool)

# prediction
y_pred = model.predict(test_pool)
print(y_pred)
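
The `loss_function = "RMSE"` setting above refers to root mean squared error. As a reminder of what that metric measures, here is a NumPy-only sketch with equal sample weights. The helper `rmse` and the sample arrays are made up for illustration and are not CatBoost output.

```python
import numpy as np


def rmse(y_true, y_pred):
    """Root mean squared error with equal weight on every sample."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


# hypothetical targets and predictions
y_true = np.array([10.0, 20.0, 30.0])
y_pred = np.array([13.0, 16.0, 30.0])
print(rmse(y_true, y_pred))
```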

CatBoost

import numpy as np
from catboost import CatBoost, Pool

# read the dataset
train_data = np.random.randint(0, 100, size = (100, 10))
train_labels = np.random.randint(0, 2, size = (100))
test_data = np.random.randint(0, 100, size = (50, 10))

# init pool
train_pool = Pool(train_data, train_labels)
test_pool = Pool(test_data)

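# build model
# Logloss is set explicitly so that the "Class" and "Probability"
# prediction types below are meaningful for the binary labels
param = {
    "iterations": 5,
    "loss_function": "Logloss"
}
model = CatBoost(param)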

# train model
model.fit(train_pool)

# prediction
y_pred_class = model.predict(test_pool, prediction_type = "Class")
y_pred_proba = model.predict(test_pool, prediction_type = "Probability")
y_pred_raw_vals = model.predict(test_pool, prediction_type = "RawFormulaVal")
print("Class", y_pred_class)
print("Proba", y_pred_proba)
print("Raw", y_pred_raw_vals)
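
For a binary model trained with Logloss, the three prediction types are related. The positive-class probability is the sigmoid of the raw value, and the class is 1 when that probability exceeds 0.5. The NumPy-only sketch below illustrates this on made-up raw values. The helper names are hypothetical, and the default 0.5 decision threshold is assumed.

```python
import numpy as np


def raw_to_probability(raw_values):
    """Map raw scores to positive-class probabilities with the sigmoid."""
    raw_values = np.asarray(raw_values, dtype=float)
    return 1.0 / (1.0 + np.exp(-raw_values))


def probability_to_class(probabilities, threshold=0.5):
    """Label a sample 1 when its positive-class probability exceeds the threshold."""
    return (np.asarray(probabilities) > threshold).astype(int)


# made-up raw scores, not real model output
raw = np.array([-2.0, -0.5, 0.3, 1.5])
proba = raw_to_probability(raw)
classes = probability_to_class(proba)
print("proba =", proba)
print("class =", classes)
```

With a fitted binary model, the second column of the "Probability" output should correspond to `proba` here.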

Data Visualization API

Install the ipywidgets visualization library:

$ pip install ipywidgets
$ jupyter nbextension enable --py widgetsnbextension

Introduction to CatBoost data visualization: