logo

TensorFlow 工具集

王哲峰 / 2022-07-15


目录

模型可视化

plot_model()


import tensorflow as tf

# Reference call to plot_model with every keyword argument written out.
# rankdir -- "TB": a vertical plot; "LR": a horizontal plot
tf.keras.utils.plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
)

model_to_dot()

import tensorflow as tf

# Reference call to model_to_dot with every keyword argument written out.
# rankdir -- "TB": a vertical plot; "LR": a horizontal plot
tf.keras.utils.model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    subgraph=False,
)

序列化工具(Serialization utilities)

CustomObjectScope class

import tensorflow as tf

# Signature of the scope helper.
tf.keras.utils.custom_object_scope(*args)
# A custom regularizer `my_regularizer`
my_regularizer = None

# A layer that uses the custom regularizer. Names are fully qualified
# because this snippet only imports `tf`.
layer = tf.keras.layers.Dense(3, kernel_regularizer=my_regularizer)

# Config contains a reference to "my_regularizer"
config = layer.get_config()
...

# Later: rebuild the layer inside the scope so that the name
# "my_regularizer" stored in the config can be resolved.
with tf.keras.utils.custom_object_scope({"my_regularizer": my_regularizer}):
    layer = tf.keras.layers.Dense.from_config(config)

get_custom_objects()

import tensorflow as tf

# Fetch the collection of registered custom objects.
tf.keras.utils.get_custom_objects()
# Remove every entry from it (fully qualified: only `tf` is imported here).
tf.keras.utils.get_custom_objects().clear()
# Store `MyObject` under the key "MyObject".
tf.keras.utils.get_custom_objects()["MyObject"] = MyObject

register_keras_serializable()

import tensorflow as tf

# The function name is `register_keras_serializable` (underscore, not a dot).
tf.keras.utils.register_keras_serializable(package="Custom", name=None)

serialize_keras_object()

import tensorflow as tf

# Reference call to serialize_keras_object with its single argument.
tf.keras.utils.serialize_keras_object(
    instance,
)

deserialize_keras_object()

import tensorflow as tf

# Reference call to deserialize_keras_object with every keyword argument
# written out.
tf.keras.utils.deserialize_keras_object(
    identifier,
    module_objects=None,
    custom_objects=None,
    printable_module_name="object",
)

Python & Numpy utilities

to_categorical()

import numpy as np
import tensorflow as tf

# Signature: the function lives in `tf.keras.utils`, and the keyword is
# `dtype` (singular).
tf.keras.utils.to_categorical(y,
                              num_classes=None,
                              dtype="float32")
# example 1: convert the class indices 0..3 with num_classes = 4
a = tf.keras.utils.to_categorical([0, 1, 2, 3], num_classes=4)
a = tf.constant(a, shape=[4, 4])
print(a)

# example 2: cross-entropy between `a` and a second 4x4 tensor `b`
b = tf.constant([.9, .04, .03, .03,
                 .3, .45, .15, .13,
                 .04, .01, .94, .05,
                 .12, .21, .5, .17],
                shape=[4, 4])
loss = tf.keras.backend.categorical_crossentropy(a, b)
print(np.around(loss, 5))

# example 3: cross-entropy of `a` with itself
loss = tf.keras.backend.categorical_crossentropy(a, a)
print(np.around(loss, 5))

normalize()

import tensorflow as tf

# Reference call to normalize with its keyword arguments written out.
tf.keras.utils.normalize(x, axis=-1, order=2)

get_file()


# The alias `tf` is required by every call below, so import it first and
# under that name.
import tensorflow as tf

# Reference call to get_file with every keyword argument written out.
tf.keras.utils.get_file(
    fname,
    origin,
    untar=False,
    md5_hash=None,
    file_hash=None,
    cache_subdir="datasets",
    hash_algorithm="auto",
    extract=False,
    archive_format="auto",
    cache_dir=None,
)

# Example: fetch the flower_photos archive with untar enabled.
path_to_downloaded_file = tf.keras.utils.get_file(
    "flower_photos",
    "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
    untar=True,
)

Progbar class

import tensorflow as tf

# Reference construction of Progbar with every keyword argument written out.
tf.keras.utils.Progbar(
    target,
    width=30,
    verbose=1,
    interval=0.05,
    stateful_metrics=None,
    unit_name="step",
)

Sequence class

import tensorflow as tf
# Reference call: the Sequence class exposed under tf.keras.utils.
tf.keras.utils.Sequence()
from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math

# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.

class CIFAR10Sequence(tf.keras.utils.Sequence):
    """Batch-wise image loader built on `tf.keras.utils.Sequence`.

    `x_set` is a list of image file paths and `y_set` holds the associated
    classes. Each batch is read from disk with `imread`, resized to
    200x200 with `resize`, and returned together with its labels as
    NumPy arrays.
    """

    def __init__(self, x_set, y_set, batch_size):
        """Store the paths, labels and batch size.

        Args:
            x_set: List of image file paths.
            y_set: Classes associated with `x_set`, in the same order.
            batch_size: Number of samples per batch.
        """
        # Let the base class set up its own state first.
        super().__init__()
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        """Return batch `idx` as an `(images, labels)` pair of NumPy arrays."""
        # Compute the slice bounds once and reuse them for both lists.
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_x = self.x[start:stop]
        batch_y = self.y[start:stop]

        images = np.array(
            [resize(imread(file_name), (200, 200)) for file_name in batch_x]
        )
        return images, np.array(batch_y)