logo

TensorFlow Estimator

wangzf / 2022-09-10


目录

预创建的 Estimator

预创建的 Estimator 程序的结构

依赖预创建的Estimator的TensorFlow程序通常包含下列四个步骤:

  1. 编写一个或多个数据集导入函数
    • 创建一个函数来导入训练集, 并创建另一个函数来导入测试集。每个数据集导入函数都必须返回两个对象:
      • 一个字典, 其中键是特征名称, 值是包含相应特征数据的张量(or Sparse Tensro);
      • 一个包含一个或多个标签的张量;
  2. 定义特征列
    • 每个 tf.feature_column 都标识了特征名称、特征类型和任何输入预处理操作
  3. 实例化相关的预创建的Estimator
    • LinearClassifier
  4. 调用训练、评估或推理方法
    • 所有Estimator都提供训练模型的 train 方法

上面步骤实现举例:

def input_fn_train(dataset):
   # manipulate dataset, extracting the feature dict and the label
   
   return feature_dict, label

def input_fn_test(dataset):
   # manipulate dataset, extracting the feature dict and the label
   
   return feature_dict, label


my_training_set = input_fn_train()
my_testing_set = input_fn_test()

population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
median_education = tf.feature_column.numeric_column('median_education', 
                                                   normalizer_fn = lambda x: x - global_education_mean)

estimator = tf.estimator.LinearClassifier(
   feature_columns = [population, crime_rate, median_education],
)

estimator.train(input_fn = my_training_set, setps = 2000)

预创建的 Estimator 的优势

自定义的 Estimator

从 Keras 模型创建 Estimator

keras_inception_v3 = tf.keras.applications.keras_inception_v3.InceptionV3(weights = None)

keras_inception_v3.compile(optimizer = tf.keras.optimizers.SGD(lr = 0.0001, momentum = 0.9),
                           loss = 'categorical_crossentropy',
                           metric = 'accuracy')

est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model = keras_inception_v3)

keras_inception_v3.input_names

train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
      x = {'input_1': train_data},
      y = train_labels,
      num_epochs = 1,
      shuffle = False
)

est_inception_v3.train(input_fn = train_input_fn, steps = 2000)

API:

从一个给定的Keras模型中构造一个Estimator实例

tf.keras.estimator.model_to_estimator(
      keras_model = None,
      keras_model_path = None,
      custom_objects = None,
      model_dir = None,
      config = None
)