【深度学习 走进tensorflow2.0】训练模型以及保存最佳模型

无意中发现了一个巨牛的人工智能教程,忍不住分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。点这里可以跳转到教程。人工智能教程

项目目录:
数据集:
下载 二分类数据集:cats_and_dogs_filtered文件夹

wget https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip

目录结构如下:

.
├── [drwxr-x---]  cats_and_dogs_filtered
│   ├── [drwxr-x---]  train
│   │   ├── [drwxr-x---]  cats
│   │   └── [drwxr-x---]  dogs
│   └── [drwxr-x---]  validation
│       ├── [drwxr-x---]  cats
│       └── [drwxr-x---]  dogs

1、数据预处理
2、数据增强
2、创建模型
4、编译模型
5、模型自动保存定义
6、模型训练
7、模型保存回调以及自动调节学习率回调

模型保存:

├── [-rw-rw-r--]  model
├── [-rw-rw-r--]  model_class.json
├── [-rw-rw-r--]  model_ex-001_acc-0.500000.h5
├── [-rw-rw-r--]  model_ex-002_acc-0.501116.h5
├── [-rw-rw-r--]  model_ex-004_acc-0.527902.h5
├── [-rw-rw-r--]  model_ex-005_acc-0.540179.h5
├── [-rw-rw-r--]  model_ex-006_acc-0.549107.h5
├── [-rw-rw-r--]  model_ex-007_acc-0.600446.h5
├── [-rw-rw-r--]  model_ex-008_acc-0.646205.h5
├── [-rw-rw-r--]  model_ex-010_acc-0.648438.h5
├── [-rw-rw-r--]  model_ex-013_acc-0.650670.h5
└── [-rw-rw-r--]  model_ex-014_acc-0.672991.h5

关键点:
1、训练量很少时,通常会发生过度拟合。解决此问题的一种方法是扩充数据集,使其具有足够数量的训练示例。数据增强采用通过使用产生真实感图像的随机变换增强样本来从现有训练样本生成更多训练数据的方法。目标是模型在训练期间永远不会看到两次完全相同的图片。这有助于使模型暴露于数据的更多方面,并且可以更好地进行概括。
tf.keras使用ImageDataGenerator类来实现这一点。将不同的转换传递给数据集,它将在训练过程中加以应用。

2、tf.keras.callbacks.ModelCheckpoint函数 自动保存模型

tf.keras.callbacks.ModelCheckpoint
参数:

filepath:string,保存模型文件的路径。
monitor:要监测的数量。
verbose:详细信息模式,0或1。
save_best_only:如果save_best_only=True,被监测数量的最佳型号不会被覆盖。
mode:{auto,min,max}之一。如果save_best_only=True,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。对于val_acc,这应该是max,对于val_loss这应该是min,等等。在auto模式中,方向是从监测数量的名称自动推断出来的。
save_weights_only:如果为True,则仅保存模型的权重(model.save_weights(filepath)),否则保存完整模型(model.save(filepath))。
period:检查点之间的间隔(epoch数)。

3、TensorFlow回调函数:tf.keras.callbacks.ReduceLROnPlateau
当指标停止提升时,降低学习速率。一旦学习停止,模型通常会将学习率降低2-10倍。该回调监测数量,如果没有看到epoch的 ‘patience’ 数量的改善,那么学习率就会降低。


reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                              patience=5, min_lr=0.001)
model.fit(X_train, Y_train, callbacks=[reduce_lr])


参数:

monitor:要监测的数量。
factor:学习速率降低的因素。new_lr = lr * factor
patience:没有提升的epoch数,之后学习率将降低。
verbose:int。0:安静,1:更新消息。
mode:{auto,min,max}之一。在min模式下,当监测量停止下降时,lr将减少;在max模式下,当监测数量停止增加时,它将减少;在auto模式下,从监测数量的名称自动推断方向。
min_delta:对于测量新的最优化的阀值,仅关注重大变化。
cooldown:在学习速率被降低之后,重新恢复正常操作之前等待的epoch数量。
min_lr:学习率的下限。

4、模型训练

使用fit_generator方法ImageDataGenerator来训练网络。

history = model.fit_generator(
    train_data_gen,
    steps_per_epoch=int(num_train / batch_size),
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=int(num_test / batch_size),
    callbacks=[checkpoint,lr_scheduler])

完整训练代码:

# -*- coding: utf-8 -*-


from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import json


import os
import numpy as np


batch_size = 128
epochs = 15
IMG_HEIGHT = 150
IMG_WIDTH = 150



PATH = os.path.join('/home/dongli/tensorflow2.0/corpus/', 'cats_and_dogs_filtered')


train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

train_cats_dir = os.path.join(train_dir, 'cats')  # directory with our training cat pictures
train_dogs_dir = os.path.join(train_dir, 'dogs')  # directory with our training dog pictures
validation_cats_dir = os.path.join(validation_dir, 'cats')  # directory with our validation cat pictures
validation_dogs_dir = os.path.join(validation_dir, 'dogs')  # directory with our validation dog pictures

num_cats_tr = len(os.listdir(train_cats_dir))
num_dogs_tr = len(os.listdir(train_dogs_dir))

num_cats_val = len(os.listdir(validation_cats_dir))
num_dogs_val = len(os.listdir(validation_dogs_dir))

total_train = num_cats_tr + num_dogs_tr
total_val = num_cats_val + num_dogs_val


print('total training cat images:', num_cats_tr)
print('total training dog images:', num_dogs_tr)

print('total validation cat images:', num_cats_val)
print('total validation dog images:', num_dogs_val)
print("--")
print("Total training images:", total_train)
print("Total validation images:", total_val)




# 训练集
# 对训练图像应用了重新缩放,45度旋转,宽度偏移,高度偏移,水平翻转和缩放增强。
image_gen_train = ImageDataGenerator(
                    rescale=1./255,
                    rotation_range=45,
                    width_shift_range=.15,
                    height_shift_range=.15,
                    horizontal_flip=True,
                    zoom_range=0.5
                    )

train_data_gen = image_gen_train.flow_from_directory(batch_size=batch_size,
                                                     directory=train_dir,
                                                     shuffle=True,
                                                     target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                     class_mode='binary')

# 验证集

image_gen_val = ImageDataGenerator(rescale=1./255)

val_data_gen = image_gen_val.flow_from_directory(batch_size=batch_size,
                                                 directory=validation_dir,
                                                 target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                 class_mode='binary')




# 创建模型

model = Sequential([
    Conv2D(16, 3, padding='same', activation='relu',
           input_shape=(IMG_HEIGHT, IMG_WIDTH ,3)),
    MaxPooling2D(),
    Dropout(0.2),
    Conv2D(32, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Conv2D(64, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Dropout(0.2),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1, activation='sigmoid')
])



# 编译模型

model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])

# 模型总结
model.summary()


# 模型保存格式定义

model_class_dir='./model/'
class_indices = train_data_gen.class_indices
class_json = {}
for eachClass in class_indices:
    class_json[str(class_indices[eachClass])] = eachClass

with open(os.path.join(model_class_dir, "model_class.json"), "w+") as json_file:
    json.dump(class_json, json_file, indent=4, separators=(",", " : "),ensure_ascii=True)
    json_file.close()
print("JSON Mapping for the model classes saved to ", os.path.join(model_class_dir, "model_class.json"))



model_name = 'model_ex-{epoch:03d}_acc-{val_accuracy:03f}.h5'

trained_model_dir='./model/'
model_path = os.path.join(trained_model_dir, model_name)

checkpoint = tf.keras.callbacks.ModelCheckpoint(
             filepath=model_path,
             monitor='val_accuracy',
            verbose=1,
            save_weights_only=True,
            save_best_only=True,
            mode='max',
            period=1)


def lr_schedule(epoch):
    # Learning Rate Schedule

    lr =1e-3
    total_epochs =epochs
    check_1 = int(total_epochs * 0.9)
    check_2 = int(total_epochs * 0.8)
    check_3 = int(total_epochs * 0.6)
    check_4 = int(total_epochs * 0.4)

    if epoch > check_1:
        lr *= 1e-4
    elif epoch > check_2:
        lr *= 1e-3
    elif epoch > check_3:
        lr *= 1e-2
    elif epoch > check_4:
        lr *= 1e-1

    return lr



#lr_scheduler =tf.keras.callbacks.LearningRateScheduler(lr_schedule)

lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2,patience=5, min_lr=0.001)


num_train = len(train_data_gen.filenames)
num_test = len(val_data_gen.filenames)

print(num_train,num_test)

# 模型训练
# 使用fit_generator方法ImageDataGenerator来训练网络。

history = model.fit_generator(
    train_data_gen,
    steps_per_epoch=int(num_train / batch_size),
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=int(num_test / batch_size),
    callbacks=[checkpoint,lr_scheduler]
)



控制台输出:

Epoch 1/15
2019-10-31 15:07:47.955641: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudnn.so.7
2019-10-31 15:07:49.213954: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10.0
14/15 [===========================>..] - ETA: 0s - loss: 1.1185 - accuracy: 0.5206 
15/15 [==============================] - 14s 924ms/step - loss: 1.0906 - accuracy: 0.5246 - val_loss: 0.6935 - val_accuracy: 0.5000
Epoch 2/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6936 - accuracy: 0.4911
15/15 [==============================] - 13s 884ms/step - loss: 0.6935 - accuracy: 0.4948 - val_loss: 0.6932 - val_accuracy: 0.5011
Epoch 3/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6910 - accuracy: 0.5230
15/15 [==============================] - 12s 809ms/step - loss: 0.6926 - accuracy: 0.5175 - val_loss: 0.6894 - val_accuracy: 0.5000
Epoch 4/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6914 - accuracy: 0.5039
15/15 [==============================] - 12s 830ms/step - loss: 0.6915 - accuracy: 0.5031 - val_loss: 0.6919 - val_accuracy: 0.5279
Epoch 5/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6893 - accuracy: 0.5472
15/15 [==============================] - 11s 764ms/step - loss: 0.6891 - accuracy: 0.5482 - val_loss: 0.6893 - val_accuracy: 0.5402
Epoch 6/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6788 - accuracy: 0.5603
15/15 [==============================] - 12s 802ms/step - loss: 0.6784 - accuracy: 0.5641 - val_loss: 0.6866 - val_accuracy: 0.5491
Epoch 7/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6789 - accuracy: 0.5608
15/15 [==============================] - 12s 801ms/step - loss: 0.6778 - accuracy: 0.5641 - val_loss: 0.6702 - val_accuracy: 0.6004
Epoch 8/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6691 - accuracy: 0.5991
15/15 [==============================] - 12s 781ms/step - loss: 0.6681 - accuracy: 0.5976 - val_loss: 0.6632 - val_accuracy: 0.6462
Epoch 9/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6635 - accuracy: 0.6021
15/15 [==============================] - 12s 812ms/step - loss: 0.6606 - accuracy: 0.6052 - val_loss: 0.6562 - val_accuracy: 0.5982
Epoch 10/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6541 - accuracy: 0.6124
15/15 [==============================] - 12s 811ms/step - loss: 0.6512 - accuracy: 0.6165 - val_loss: 0.6373 - val_accuracy: 0.6484
Epoch 11/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6329 - accuracy: 0.6439
15/15 [==============================] - 12s 799ms/step - loss: 0.6334 - accuracy: 0.6437 - val_loss: 0.6395 - val_accuracy: 0.6406
Epoch 12/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6477 - accuracy: 0.6085
15/15 [==============================] - 11s 758ms/step - loss: 0.6498 - accuracy: 0.6047 - val_loss: 0.6416 - val_accuracy: 0.6183
Epoch 13/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6280 - accuracy: 0.6507
15/15 [==============================] - 12s 825ms/step - loss: 0.6277 - accuracy: 0.6505 - val_loss: 0.6114 - val_accuracy: 0.6507
Epoch 14/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6160 - accuracy: 0.6548
15/15 [==============================] - 12s 795ms/step - loss: 0.6147 - accuracy: 0.6565 - val_loss: 0.6012 - val_accuracy: 0.6730
Epoch 15/15
14/15 [===========================>..] - ETA: 0s - loss: 0.6163 - accuracy: 0.6680
15/15 [==============================] - 12s 793ms/step - loss: 0.6124 - accuracy: 0.6709 - val_loss: 0.6073 - val_accuracy: 0.6629

已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页