logo

TensorFlow 三阶 API

王哲峰 / 2022-09-20


目录

低阶 API

低阶 API 主要包括: 张量操作 (tf.Tensor / tf.Variable)、计算图 (tf.function) 和自动微分 (tf.GradientTape)。

中阶 API

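TensorFlow 中阶 API 主要包括: 数据管道 (tf.data)、激活函数 (tf.nn)、模型层 (tf.keras.layers)、损失函数 (tf.keras.losses)、评估指标 (tf.keras.metrics)、优化器 (tf.keras.optimizers) 和回调函数 (tf.keras.callbacks)。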

高阶 API

TensorFlow 的高阶 API 主要是 tf.keras.models 提供的模型类接口。

TensorFlow 高阶 API 主要包括: 模型的构建 (Sequential、函数式 API、Model 子类化)、模型的训练 (内置 fit 方法、train_on_batch、自定义训练循环) 以及模型的保存与部署。

线性回归模型

载入 Python 依赖

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import models, layers, losses, metrics, optimizers

数据准备

# Number of synthetic samples to generate.
num_samples = 400

# Build a synthetic linear-regression dataset: Y = X @ w0 + b0 + Gaussian noise.
# Both features are drawn uniformly from [-10, 10); the noise has stddev 2.0.
# (Previously referenced an undefined name `n` instead of `num_samples`.)
X = tf.random.uniform([num_samples, 2], minval = -10, maxval = 10)
w0 = tf.constant([[2.0], [-3.0]])  # ground-truth weights, shape (2, 1)
b0 = tf.constant([[3.0]])          # ground-truth bias, shape (1, 1)
Y = X@w0 + b0 + tf.random.normal([num_samples, 1], mean = 0.0, stddev = 2.0)
%matplotlib inline
%config InlineBackend.figure_format = "svg"

plt.figure(figsize = (12, 5))

ax1 = plt.subplot(121)
ax1.scatter(X[:, 0], Y[:, 0], c = "b")
plt.xlabel("x1")
plt.ylabel("y", rotation = 0)

ax2 = plt.subplot(122)
ax2.scatter(X[:, 1], Y[:, 0], c = "g")
plt.xlabel("x2")
plt.ylabel("y", rotation = 0)

plt.show()

模型构建

# Reset Keras' global state (e.g. auto-generated layer names) before building.
tf.keras.backend.clear_session()

# A single fully connected unit over 2 input features: y_hat = x @ kernel + bias.
model = models.Sequential([
    layers.Dense(1, input_shape = (2,)),
])

model.summary()

模型训练

# Regression setup: mean squared error as the training loss, with mean absolute
# error reported alongside as a more readable metric.
model.compile(
    optimizer = "adam",
    loss = "mse",
    metrics = ["mae"],
)
# Mini-batch training: 10 samples per gradient step, 200 passes over the data.
model.fit(X, Y, batch_size = 10, epochs = 200)

# Pass the variables to tf.print as separate arguments so it formats their
# values. An f-string would embed the Python repr of the tf.Variable
# ("<tf.Variable 'dense/kernel:0' ...>") instead.
tf.print("w =", model.layers[0].kernel)
tf.print("b =", model.layers[0].bias)

模型结果可视化

%matplotlib inline
%config InlineBackend.figure_format = "svg"

w, b = model.variables

plt.figure(figsize = (12, 5))

ax1 = plt.subplot(121)
ax1.scatter(X[:, 0], Y[:, 0], c = "b", label = "samples")
ax1.plot(X[:, 0], w[0] * X[:, 0] + b[0], "-r", linewidth = 5.0, label = "model")
ax1.legend()
plt.xlabel("x1")
plt.ylabel("y", rotation = 0)

ax2 = plt.subplot(122)
ax2.scatter(X[:, 1], Y[:, 0], c = "g", label = "samples")
ax2.plot(X[:, 1], w[1] * X[:, 1] + b[0], "-r", linewidth = 5.0, label = "model")
ax2.legend()
plt.xlabel("x2")
plt.ylabel("y", rotation = 0)

plt.show()