[Tensorflow] Custom Training Loop

Custom Training Loop

tf.keras already provides convenient built-in training and evaluation loops: fit() and evaluate().

However, when we need lower-level control over training or evaluation, we have to write our own training and evaluation loops, or at least customize parts of them. For example:

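  • Customize the model's learning algorithm while still taking advantage of the convenience of fit().

    For example: using fit() to train a GAN.

    To do this, create the model with model subclassing and implement the train_step() method; model.fit() calls this method repeatedly, once per batch of data. For details, see: Customizing what happens in fit()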
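A minimal sketch of this pattern follows. It assumes a hypothetical one-layer regression model (`CustomFitModel`), a mean-squared-error loss and synthetic data y = 2x + 1, none of which come from the original text. The model overrides `train_step()` and tracks its loss with its own `keras.metrics.Mean`, while `fit()` still handles epochs, batching and logging:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class CustomFitModel(keras.Model):
    """Hypothetical model: fit() calls train_step() once per batch."""

    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1)
        # Assumed choices for this sketch: MSE loss and a Mean metric to report it.
        self.loss_fn = keras.losses.MeanSquaredError()
        self.loss_tracker = keras.metrics.Mean(name="loss")

    def call(self, inputs, training=False):
        return self.dense(inputs)

    @property
    def metrics(self):
        # Metrics listed here are reset by fit() at the start of every epoch.
        return [self.loss_tracker]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.loss_fn(y, y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.loss_tracker.update_state(loss)
        # The returned dict is what fit() logs (e.g. into the History object).
        return {"loss": self.loss_tracker.result()}


# Synthetic data: y = 2x + 1.
x = np.random.uniform(-1.0, 1.0, size=(64, 1)).astype("float32")
y = 2.0 * x + 1.0

model = CustomFitModel()
model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.1))
history = model.fit(x, y, batch_size=16, epochs=5, verbose=0)
print(history.history["loss"])
```

A GAN would follow the same structure, with two optimizers and two gradient updates inside `train_step()`.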

Using the GradientTape

TensorFlow provides a handy API, tf.GradientTape(), for automatic differentiation (AD); for a detailed introduction, see here.
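As a minimal illustration (the variable and the function are arbitrary examples), the tape records operations executed inside its `with` block and can then differentiate them. For y = x², dy/dx = 2x, so at x = 3 the gradient is 6:

```python
import tensorflow as tf

x = tf.Variable(3.0)

# Operations on trainable variables inside this block are recorded on the tape.
with tf.GradientTape() as tape:
    y = x * x  # y = x^2

# Reverse-mode AD through the recorded operations: dy/dx = 2x at x = 3.0.
dy_dx = tape.gradient(y, x)
print(dy_dx.numpy())  # 6.0
```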


  • We open a for loop that iterates over epochs
  • For each epoch, we open a for loop that iterates over the dataset, in batches
  • For each batch, we open a GradientTape() scope
  • Inside this scope, we call the model (forward pass) and compute the loss
  • Outside the scope, we retrieve the gradients of the weights of the model with regard to the loss
  • Finally, we use the optimizer to update the weights of the model based on the gradients


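import tensorflow as tf

# Assumes model, optimizer, loss_fn, train_dataset and EPOCHS are defined beforehand.
for epoch in range(EPOCHS):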
    # Iterate over the batches of the dataset.
    for batch_idx, (x_train, y_train) in enumerate(train_dataset):
        # Open a GradientTape to record the operations
        # run during the forward pass,
        # which enables auto-differentiation.
        with tf.GradientTape() as tape:
            # forward pass
            predictions = model(x_train, training=True)
            # Compute the loss value for this minibatch.
            loss = loss_fn(y_train, predictions)

        # backward pass
        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss, model.trainable_weights) # model.trainable_variables

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(
            zip(grads, model.trainable_weights))
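The loop above assumes `model`, `optimizer`, `loss_fn`, `train_dataset` and `EPOCHS` already exist. A self-contained sketch with hypothetical stand-ins (random MNIST-shaped inputs, a small Dense classifier producing logits, plain SGD) could look like this:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical stand-ins: 256 random "images" with 784 features and 10 classes.
x_data = np.random.rand(256, 784).astype("float32")
y_data = np.random.randint(0, 10, size=(256,))
train_dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)).batch(64)

model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(10),  # raw logits, hence from_logits=True below
])
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
EPOCHS = 2

for epoch in range(EPOCHS):
    for x_batch, y_batch in train_dataset:
        # Forward pass recorded on the tape.
        with tf.GradientTape() as tape:
            logits = model(x_batch, training=True)
            loss = loss_fn(y_batch, logits)
        # Backward pass, then one gradient-descent update.
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
    print("epoch %d, last batch loss: %.4f" % (epoch, float(loss)))
```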

Low-level handling of metrics

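Let's add metrics to monitor this training loop. Both built-in metrics and custom metrics can be used inside such a loop; the flow is:

  • Instantiate the metric at the start of the loop
  • Call metric.update_state() after each batch
  • Call metric.result() when you need to display the current value of the metric
  • Call metric.reset_state() when you need to clear the metric's state (typically at the end of an epoch)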
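The metric lifecycle can be seen in isolation on a single metric. The labels and predictions below are made-up values; `reset_state()` is the current method name, while older TensorFlow releases call it `reset_states()`:

```python
import tensorflow as tf
from tensorflow import keras

# 1. Instantiate the metric once, before the loop.
acc = keras.metrics.SparseCategoricalAccuracy()

# 2. update_state() after each "batch"; the predicted class is the argmax of each row.
acc.update_state([0, 1], [[0.9, 0.1], [0.2, 0.8]])  # 2 of 2 correct
acc.update_state([1], [[0.7, 0.3]])                  # 0 of 1 correct

# 3. result() returns the running value over everything seen so far (2 / 3).
print(float(acc.result()))  # 0.666...

# 4. reset_state() clears the accumulated counts, e.g. at the end of an epoch.
acc.reset_state()
print(float(acc.result()))  # 0.0
```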


Let’s use this knowledge to compute SparseCategoricalAccuracy on validation data at the end of each epoch:


import time

import tensorflow as tf
from tensorflow import keras

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Instantiate the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

for epoch in range(EPOCHS):
    start_time = time.time()
    # Iterate over the batches of the dataset.
    for batch_idx, (x_train, y_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            # forward pass
            pred = model(x_train, training=True)
            loss = loss_fn(y_train, pred)

        grads = tape.gradient(loss, model.trainable_weights) # model.trainable_variables
        optimizer.apply_gradients(
            zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_train, pred)

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_val, y_val in val_dataset:
        val_pred = model(x_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_val, val_pred)
    val_acc = val_acc_metric.result()
    # Reset validation metrics so each epoch is measured independently.
    val_acc_metric.reset_state()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

Speeding-up your training step with tf.function

In TensorFlow 2 the default runtime is eager execution. Eager mode is convenient for debugging but comparatively slow; adding the @tf.function decorator compiles a function into a static graph.

Graph compilation has a definite performance advantage, because TensorFlow can apply optimizations to the static graph.
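One way to see that `@tf.function` builds a graph: Python code in the function body runs only while TensorFlow traces the function, and later calls with the same input signature reuse the traced graph. The `traces` list is a hypothetical helper used only to make tracing visible:

```python
import tensorflow as tf

traces = []  # hypothetical log, appended to only while the function is traced


@tf.function
def square(x):
    traces.append("traced")  # Python side effect: runs at trace time, not on every call
    return x * x


print(square(tf.constant(2.0)).numpy())  # 4.0
print(square(tf.constant(3.0)).numpy())  # 9.0
print(len(traces))                       # 1: same dtype and shape, so the graph is reused
```

This is also why Python side effects (appending to lists, printing with `print`) inside a `tf.function` should not be relied on per call.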


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

Compared with the same function without @tf.function, this training step should run noticeably faster.
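To check the speed-up on your own machine, a rough benchmark sketch can time the same step eagerly and as a compiled graph. The model size, batch shape and step count are arbitrary assumptions; the measured ratio varies with hardware and model size, and for very small models the gain may be modest:

```python
import time

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical model and data, used only for timing.
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(10),
])
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
x = tf.constant(np.random.rand(64, 784).astype("float32"))
y = tf.constant(np.random.randint(0, 10, size=(64,)))


def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss


# tf.function can wrap an existing Python function as well as decorate it.
graph_train_step = tf.function(train_step)


def benchmark(step_fn, steps=50):
    step_fn(x, y)  # warm-up; for graph_train_step this first call also traces the graph
    start = time.perf_counter()
    for _ in range(steps):
        loss = step_fn(x, y)
    return time.perf_counter() - start, loss


eager_time, eager_loss = benchmark(train_step)
graph_time, graph_loss = benchmark(graph_train_step)
print("eager: %.3fs, tf.function: %.3fs" % (eager_time, graph_time))
```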

Low-level handling of losses tracked by the model

Layers & models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the property model.losses at the end of the forward pass.

If you want to use these loss components, you should sum them and add them to the main loss in your training step.


Consider this layer, which creates an activity regularization loss:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class ActivityRegularizationLayer(layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs
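Calling the layer once on its own shows what `add_loss()` records. The all-ones input is an arbitrary example, so the tracked loss is 1e-2 × 6 = 0.06. The sketch repeats the class (using `keras.layers.Layer`) so that it runs standalone:

```python
import tensorflow as tf
from tensorflow import keras


class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs


layer = ActivityRegularizationLayer()
outputs = layer(tf.ones((2, 3)))  # sum of the inputs = 6.0

# The scalar passed to add_loss() during the forward pass is now in layer.losses.
print(len(layer.losses))       # 1
print(float(layer.losses[0]))  # ~0.06
```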

Let’s build a really simple model that uses it:

inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = layers.Dense(64, activation="relu")(x)
outputs = layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

Here’s what our training step should look like now:

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = loss_fn(y, pred)
        # Add any extra losses created during the forward pass.
        loss += sum(model.losses)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, pred)
    return loss
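A self-contained sketch that runs this training step once follows. The random batch, the SGD optimizer, the loss function, the accuracy metric and the `@tf.function` decoration are assumptions chosen to match the earlier sections, not requirements:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs


inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Assumed training objects, of the same kinds as in the metrics section.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = loss_fn(y, pred)
        # The activity-regularization term is part of the loss being differentiated.
        loss += sum(model.losses)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, pred)
    return loss


# Hypothetical random batch of 32 flattened 28x28 images.
x_batch = tf.constant(np.random.rand(32, 784).astype("float32"))
y_batch = tf.constant(np.random.randint(0, 10, size=(32,)))
loss = train_step(x_batch, y_batch)
print("loss including the regularization term: %.4f" % float(loss))
```

Because the regularization term is computed inside the tape, its gradient flows back into the first Dense layer along with the cross-entropy gradient.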