Overview
In the previous section we looked at optimizing resource deployment, using single-GPU and distributed setups to make model training more efficient. In this section we continue to expand the "horizontal-then-vertical" approach, as shown in Figure 1, and explore how to debug and optimize the model training stage of the handwritten digit recognition task so as to guarantee the model's real effectiveness.

Optimizing the training process involves five key steps:
1. Compute the classification accuracy to observe the training effect.
The cross-entropy loss can only serve as the optimization objective; it cannot directly and accurately measure how well the model is trained. Accuracy measures the training effect directly, but because it is discrete it is not suitable as a loss function for optimizing a neural network.
2. Inspect the training process to identify potential problems.
When the model's loss or evaluation metrics look abnormal, it usually helps to print the inputs and outputs of every layer to locate the problem and analyze each layer's contents to find the cause of the error.
3. Add validation or testing to evaluate the model more reliably.
Ideally a trained model has high accuracy on both the training set and the validation set. If training accuracy is lower than validation accuracy, the network is undertrained; if training accuracy is higher than validation accuracy, overfitting may have occurred. Overfitting is addressed by adding a regularization term to the optimization objective.
4. Add a regularization term to avoid overfitting.
PaddlePaddle supports adding a regularization term over all parameters at once, which is the usual practice. It also supports adding a regularization term to a single layer or part of the network, for fine-grained control over parameter training.
5. Visual analysis.
Besides printing values or plotting with the matplotlib library, PaddlePaddle provides the more specialized visualization tool VisualDL, which offers convenient visual analysis.
Computing the Model's Classification Accuracy
Accuracy is an intuitive metric for a classification model, but because it is discrete it is not suitable as a loss to optimize. In general, the smaller a model's cross-entropy loss, the higher its classification accuracy. Classification accuracy also lets us compare two loss functions fairly, as in the comparison of mean squared error and cross-entropy in the loss-function chapter of "Handwritten Digit Recognition".
Paddle provides an API that computes classification accuracy directly:

```python
class paddle.metric.Accuracy
```

The API's input argument input is the predicted classification result predict, and the input argument label is the ground-truth label of the data. Paddle provides many more metrics for measuring model performance; see the APIs under the paddle.metric package for details.
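As a quick illustration, here is a minimal, self-contained sketch of the functional form paddle.metric.accuracy; the scores and labels below are made-up values for illustration only:

```python
import paddle

# Hypothetical per-class scores for a batch of two samples (3 classes)
predict = paddle.to_tensor([[0.1, 0.7, 0.2],
                            [0.8, 0.1, 0.1]])
# Ground-truth labels, shape [N, 1], dtype int64
label = paddle.to_tensor([[1], [2]])

acc = paddle.metric.accuracy(input=predict, label=label)
print(acc.numpy())  # expect 0.5: only the first sample is predicted correctly
```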
In the code below, we compute the classification accuracy inside the model's forward function and print each batch's accuracy during training.
```python
# Load the required libraries
import os
import random
import gzip
import json
import numpy as np
import paddle
from PIL import Image

# Define the dataset reader
def load_data(mode='train'):
    # Read the data file
    datafile = './work/mnist.json.gz'
    print('loading mnist dataset from {} ......'.format(datafile))
    data = json.load(gzip.open(datafile))
    # Split the dataset into the training, validation, and test sets
    train_set, val_set, eval_set = data

    # Dataset parameters: image height IMG_ROWS and image width IMG_COLS
    IMG_ROWS = 28
    IMG_COLS = 28

    # Select the training, validation, or test set according to `mode`
    if mode == 'train':
        imgs = train_set[0]
        labels = train_set[1]
    elif mode == 'valid':
        imgs = val_set[0]
        labels = val_set[1]
    elif mode == 'eval':
        imgs = eval_set[0]
        labels = eval_set[1]

    # Total number of images
    imgs_length = len(imgs)
    # Verify that the numbers of images and labels match
    assert len(imgs) == len(labels), \
        "length of train_imgs({}) should be the same as train_labels({})".format(
            len(imgs), len(labels))

    index_list = list(range(imgs_length))
    # Batch size used when reading the data
    BATCHSIZE = 100

    # Define the data generator
    def data_generator():
        # In training mode, shuffle the training data
        if mode == 'train':
            random.shuffle(index_list)
        imgs_list = []
        labels_list = []
        # Read samples by index
        for i in index_list:
            # Read an image and its label, converting shape and dtype
            img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')
            label = np.reshape(labels[i], [1]).astype('int64')
            imgs_list.append(img)
            labels_list.append(label)
            # Once the buffer reaches the batch size, yield one batch
            if len(imgs_list) == BATCHSIZE:
                yield np.array(imgs_list), np.array(labels_list)
                # Clear the buffers
                imgs_list = []
                labels_list = []

        # If fewer than BATCHSIZE samples remain, yield them together
        # as a final mini-batch of size len(imgs_list)
        if len(imgs_list) > 0:
            yield np.array(imgs_list), np.array(labels_list)
    return data_generator
```
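Before training, it can be worth sanity-checking the reader. The snippet below (not part of the original notebook) draws one mini-batch from the generator and prints its shapes:

```python
# Draw a single mini-batch and confirm its shapes
train_loader = load_data('train')
imgs, labels = next(train_loader())
print(imgs.shape, labels.shape)   # expect (100, 1, 28, 28) and (100, 1)
```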
```python
# Define the model structure
import paddle.nn.functional as F
from paddle.nn import Conv2D, MaxPool2D, Linear

# Multi-layer convolutional neural network
class MNIST(paddle.nn.Layer):
    def __init__(self):
        super(MNIST, self).__init__()
        # Convolution layer: out_channels=20, kernel_size=5, stride=1, padding=2
        self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)
        # Pooling layer: kernel_size=2, stride=2
        self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
        # Convolution layer: out_channels=20, kernel_size=5, stride=1, padding=2
        self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)
        # Pooling layer: kernel_size=2, stride=2
        self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
        # One fully connected layer with output dimension 10
        self.fc = Linear(in_features=980, out_features=10)

    # Forward pass: each convolution is followed by pooling, and the fully
    # connected layer computes the final output. The convolution layers use
    # ReLU; softmax is applied to the final scores for classification.
    def forward(self, inputs, label=None):
        x = self.conv1(inputs)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max_pool2(x)
        x = paddle.reshape(x, [x.shape[0], 980])
        x = self.fc(x)
        if label is not None:
            acc = paddle.metric.accuracy(input=x, label=label)
            return x, acc
        else:
            return x

# Call the data-loading function
train_loader = load_data('train')

# On a GPU machine, use_gpu can be set to True
use_gpu = True
paddle.set_device('gpu:0') if use_gpu else paddle.set_device('cpu')

# Only the optimizer settings differ between runs
def train(model):
    model = MNIST()
    model.train()

    # Four optimizer options; try each one in turn to compare the results
    # opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    # opt = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9, parameters=model.parameters())
    # opt = paddle.optimizer.Adagrad(learning_rate=0.01, parameters=model.parameters())
    opt = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())

    EPOCH_NUM = 5
    for epoch_id in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            # Prepare the data
            images, labels = data
            images = paddle.to_tensor(images)
            labels = paddle.to_tensor(labels)

            # Forward pass
            predicts, acc = model(images, labels)

            # Compute the loss, averaged over the batch
            loss = F.cross_entropy(predicts, labels)
            avg_loss = paddle.mean(loss)

            # Print the current loss every 200 batches
            if batch_id % 200 == 0:
                print("epoch: {}, batch: {}, loss is: {}, acc is {}".format(
                    epoch_id, batch_id, avg_loss.numpy(), acc.numpy()))

            # Backpropagate, update the parameters, and clear the gradients
            avg_loss.backward()
            opt.step()
            opt.clear_grad()

    # Save the model parameters
    paddle.save(model.state_dict(), 'mnist.pdparams')

# Create the model
model = MNIST()
# Start training
train(model)
```
```
loading mnist dataset from ./work/mnist.json.gz ......
epoch: 0, batch: 0, loss is: [2.3825366], acc is [0.13]
epoch: 0, batch: 200, loss is: [0.07315266], acc is [0.97]
epoch: 0, batch: 400, loss is: [0.08943536], acc is [0.97]
epoch: 1, batch: 0, loss is: [0.02650153], acc is [0.99]
epoch: 1, batch: 200, loss is: [0.00596725], acc is [1.]
epoch: 1, batch: 400, loss is: [0.08135935], acc is [0.97]
epoch: 2, batch: 0, loss is: [0.04334498], acc is [0.99]
epoch: 2, batch: 200, loss is: [0.0617303], acc is [0.98]
epoch: 2, batch: 400, loss is: [0.02098986], acc is [0.99]
epoch: 3, batch: 0, loss is: [0.04585475], acc is [0.98]
epoch: 3, batch: 200, loss is: [0.02449889], acc is [0.99]
epoch: 3, batch: 400, loss is: [0.0548584], acc is [0.99]
epoch: 4, batch: 0, loss is: [0.11403488], acc is [0.98]
epoch: 4, batch: 200, loss is: [0.05665577], acc is [0.99]
epoch: 4, batch: 400, loss is: [0.02090033], acc is [0.99]
```
```python
import paddle.nn.functional as F

# Define the model structure
class MNIST(paddle.nn.Layer):
    def __init__(self):
        super(MNIST, self).__init__()
        # Convolution layer: out_channels=20, kernel_size=5, stride=1, padding=2
        self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)
        # Pooling layer: kernel_size=2, stride=2
        self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
        # Convolution layer: out_channels=20, kernel_size=5, stride=1, padding=2
        self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)
        # Pooling layer: kernel_size=2, stride=2
        self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
        # One fully connected layer with output dimension 10
        self.fc = Linear(in_features=980, out_features=10)

    # Add printing of each layer's input/output shapes and contents; the
    # check flags decide whether to print. Convolution layers use ReLU, and
    # softmax is applied to the final scores when computing accuracy.
    def forward(self, inputs, label=None, check_shape=False, check_content=False):
        # Give each layer's output its own name for easier debugging
        outputs1 = self.conv1(inputs)
        outputs2 = F.relu(outputs1)
        outputs3 = self.max_pool1(outputs2)
        outputs4 = self.conv2(outputs3)
        outputs5 = F.relu(outputs4)
        outputs6 = self.max_pool2(outputs5)
        outputs6 = paddle.reshape(outputs6, [outputs6.shape[0], -1])
        outputs7 = self.fc(outputs6)

        # Optionally print each layer's parameter shapes and output shapes,
        # to verify that the network structure is configured correctly
        if check_shape:
            # Print the hyperparameters set for each layer:
            # convolution kernel size, stride, and padding
            print("\n########## print network layer's superparams ##############")
            print("conv1-- kernel_size:{}, padding:{}, stride:{}".format(
                self.conv1.weight.shape, self.conv1._padding, self.conv1._stride))
            print("conv2-- kernel_size:{}, padding:{}, stride:{}".format(
                self.conv2.weight.shape, self.conv2._padding, self.conv2._stride))
            print("fc-- weight_size:{}, bias_size_{}".format(
                self.fc.weight.shape, self.fc.bias.shape))
            # Print each layer's output shape
            print("\n########## print shape of features of every layer ###############")
            print("inputs_shape: {}".format(inputs.shape))
            print("outputs1_shape: {}".format(outputs1.shape))
            print("outputs2_shape: {}".format(outputs2.shape))
            print("outputs3_shape: {}".format(outputs3.shape))
            print("outputs4_shape: {}".format(outputs4.shape))
            print("outputs5_shape: {}".format(outputs5.shape))
            print("outputs6_shape: {}".format(outputs6.shape))
            print("outputs7_shape: {}".format(outputs7.shape))

        # Optionally print parameters and layer outputs during training,
        # which is useful for debugging the training process
        if check_content:
            # Print the convolution kernels; there are many weights,
            # so only part of each kernel is printed
            print("\n########## print convolution layer's kernel ###############")
            print("conv1 params -- kernel weights:", self.conv1.weight[0][0])
            print("conv2 params -- kernel weights:", self.conv2.weight[0][0])
            # Pick random channels and print their outputs,
            # for the first image in the batch only
            idx1 = np.random.randint(0, outputs1.shape[1])
            idx2 = np.random.randint(0, outputs4.shape[1])
            print("\nThe {}th channel of conv1 layer: ".format(idx1), outputs1[0][idx1])
            print("The {}th channel of conv2 layer: ".format(idx2), outputs4[0][idx2])
            print("The output of last layer:", outputs7[0], '\n')

        # If a label is given, compute and return the classification accuracy
        if label is not None:
            acc = paddle.metric.accuracy(input=F.softmax(outputs7), label=label)
            return outputs7, acc
        else:
            return outputs7

# On a GPU machine, use_gpu can be set to True
use_gpu = True
paddle.set_device('gpu:0') if use_gpu else paddle.set_device('cpu')

def train(model):
    model = MNIST()
    model.train()

    # Four optimizer options; try each one in turn to compare the results
    opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    # opt = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9, parameters=model.parameters())
    # opt = paddle.optimizer.Adagrad(learning_rate=0.01, parameters=model.parameters())
    # opt = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())

    EPOCH_NUM = 1
    for epoch_id in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            # Prepare the data
            images, labels = data
            images = paddle.to_tensor(images)
            labels = paddle.to_tensor(labels)

            # Forward pass, returning both predictions and accuracy
            if batch_id == 0 and epoch_id == 0:
                # Print the model parameters and the output shape of each layer
                predicts, acc = model(images, labels, check_shape=True, check_content=False)
            elif batch_id == 401:
                # Print the model parameters and the output values of each layer
                predicts, acc = model(images, labels, check_shape=False, check_content=True)
            else:
                predicts, acc = model(images, labels)

            # Compute the loss, averaged over the batch
            loss = F.cross_entropy(predicts, labels)
            avg_loss = paddle.mean(loss)

            # Print the current loss every 200 batches
            if batch_id % 200 == 0:
                print("epoch: {}, batch: {}, loss is: {}, acc is {}".format(
                    epoch_id, batch_id, avg_loss.numpy(), acc.numpy()))

            # Backpropagate and update the parameters
            avg_loss.backward()
            opt.step()
            opt.clear_grad()

    # Save the model parameters
    paddle.save(model.state_dict(), 'mnist_test.pdparams')

# Create the model
model = MNIST()
# Start training
train(model)
print("Model has been saved.")
```
```
########## print network layer's superparams ##############
conv1-- kernel_size:[20, 1, 5, 5], padding:2, stride:[1, 1]
conv2-- kernel_size:[20, 20, 5, 5], padding:2, stride:[1, 1]
fc-- weight_size:[980, 10], bias_size_[10]

########## print shape of features of every layer ###############
inputs_shape: [100, 1, 28, 28]
outputs1_shape: [100, 20, 28, 28]
outputs2_shape: [100, 20, 28, 28]
outputs3_shape: [100, 20, 14, 14]
outputs4_shape: [100, 20, 14, 14]
outputs5_shape: [100, 20, 14, 14]
outputs6_shape: [100, 980]
outputs7_shape: [100, 10]
epoch: 0, batch: 0, loss is: [2.4391866], acc is [0.06]
epoch: 0, batch: 200, loss is: [0.29858145], acc is [0.9]
epoch: 0, batch: 400, loss is: [0.36437622], acc is [0.88]

########## print convolution layer's kernel ###############
conv1 params -- kernel weights: Tensor(shape=[5, 5], dtype=float32, ...)
conv2 params -- kernel weights: Tensor(shape=[5, 5], dtype=float32, ...)
The 13th channel of conv1 layer: Tensor(shape=[28, 28], dtype=float32, ...)
The 11th channel of conv2 layer: Tensor(shape=[14, 14], dtype=float32, ...)
[kernel weight and feature-map values truncated for readability]
The output of last layer: Tensor(shape=[10], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
       [ 1.45627880, -0.73663777, 1.47225618, 8.28572273, -4.86152220, 0.67528987, -3.78966165, -3.68850040, 3.13669801, -2.26959062])
Model has been saved.
```
Adding Validation or Testing to Better Evaluate the Model
During training we see the model's loss on the training samples steadily decreasing. But does that mean the model will remain effective in its future application scenario? To verify a model's effectiveness, the sample set is usually divided into three parts: a training set, a validation set, and a test set.
- Training set: used to fit the model's parameters, which is the main work of the training process.
- Validation set: used to choose the model's hyperparameters, such as adjusting the network structure or the weight of the regularization term.
- Test set: used to simulate the model's real performance after deployment. Because the test set plays no part in model optimization or parameter training, it is a completely unseen sample set for the model. When the validation data are not used to tune the network structure or hyperparameters, results on validation data and test data are similar: both reflect the model's true performance. (A minimal sketch of such a three-way split follows this list.)
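MNIST ships already split into these three subsets, so the loader above simply selects one of them. For a dataset that arrives as a single pool of samples, a split might look like the following sketch (the 80/10/10 ratio and the stand-in data are arbitrary illustrative choices):

```python
import random

samples = list(range(1000))            # stand-in for a full sample set
random.shuffle(samples)                # shuffle before splitting
n = len(samples)
train_set = samples[:int(0.8 * n)]               # 80%: fit parameters
val_set = samples[int(0.8 * n):int(0.9 * n)]     # 10%: choose hyperparameters
test_set = samples[int(0.9 * n):]                # 10%: simulate deployment
```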
The following program loads the model parameters saved by the previous training step, reads the validation dataset, and tests the model's performance on it.
```python
def evaluation(model):
    print('start evaluation .......')
    # Define the evaluation process
    params_file_path = 'mnist.pdparams'
    # Load the model parameters
    param_dict = paddle.load(params_file_path)
    model.load_dict(param_dict)

    model.eval()
    eval_loader = load_data('eval')

    acc_set = []
    avg_loss_set = []
    for batch_id, data in enumerate(eval_loader()):
        images, labels = data
        images = paddle.to_tensor(images)
        labels = paddle.to_tensor(labels)
        predicts, acc = model(images, labels)
        loss = F.cross_entropy(input=predicts, label=labels)
        avg_loss = paddle.mean(loss)
        acc_set.append(float(acc.numpy()))
        avg_loss_set.append(float(avg_loss.numpy()))

    # Average the loss and accuracy over all batches
    acc_val_mean = np.array(acc_set).mean()
    avg_loss_val_mean = np.array(avg_loss_set).mean()
    print('loss={}, acc={}'.format(avg_loss_val_mean, acc_val_mean))

model = MNIST()
evaluation(model)
```
```
start evaluation .......
loading mnist dataset from ./work/mnist.json.gz ......
loss=0.05001451548640034, acc=0.9861000066995621
```
Judging from these results, the model still achieves 98.6% accuracy on the held-out set, demonstrating that it has real predictive power.
Adding a Regularization Term to Avoid Overfitting
The Overfitting Phenomenon
For complex tasks that call for a powerful model but offer only limited samples, models easily overfit: the loss on the training set is small, while the loss on the validation or test set is large, as shown in Figure 2.

Figure 2: Overfitting: training error keeps falling, while test error first falls and then rises
Conversely, if the model's loss is large on both the training set and the test set, the model underfits. Overfitting means the model is overly sensitive: it has learned noise in the training data that does not represent the true generalizable patterns (patterns that carry over to the test set). Underfitting means the model is not yet powerful enough to fit the known training samples well, let alone the test samples. Because underfitting is easy to observe and fix (whenever the training loss is unsatisfactory, keep switching to a more powerful model), in practice most of the effort goes into handling overfitting.
Causes of Overfitting
Overfitting occurs when the model is overly sensitive while the training data are too scarce or too noisy.
As shown in Figure 3, the ideal regression model is a gently sloping parabola. The underfitted model fits only a straight line and clearly fails to capture the true pattern, while the overfitted model fits a curve with many inflection points; it is overly sensitive and fails to express the true pattern as well.

Figure 3: Overfitted, ideal, and underfitted regression models
As shown in Figure 4, the ideal classification model has a semicircular decision boundary. Underfitting draws a straight line as the boundary and clearly misses the true one, while the overfitted model produces a highly contorted boundary: it classifies all the training data correctly, but the concessions it makes to a few idiosyncratic samples are very unlikely to reflect the true pattern.

Figure 4: Underfitted, ideal, and overfitted classification models
Causes and Prevention of Overfitting
To better understand how overfitting arises, consider the logic of a detective identifying a culprit, as shown in Figure 5.

Figure 5: Analogy between a detective identifying a culprit and model hypothesis selection
Supposing the detective can also make mistakes, analysis suggests two possible causes:
- Case 1: The evidence is flawed; hunting for the culprit based on wrong evidence is bound to fail.
- Case 2: The search scope is too large while the evidence is too scarce, so too many candidates (suspects) fit the conditions and the culprit cannot be pinned down.
The detective then has two remedies: either narrow the search scope (for example, assume the case must be the work of an acquaintance) or gather more evidence.
Mapping this onto deep learning, supposing the model can also make mistakes, the possible causes are:
- Case 1: The training data contain noise, so the model learns the noise instead of the true patterns.
- Case 2: A powerful model (with a large hypothesis space) is trained on too little data, so too many candidate hypotheses perform well on the training data, and the model locks onto a "spuriously correct" one.
Case 1 is addressed by cleaning and correcting the data. Case 2 is addressed either by limiting the model's representational capacity or by collecting more training data.
But "clean the errors in the training data" or "collect more training data" is often a correct but unhelpful truism: at any time we would like more data of higher quality. In a real project, the faster and cheaper controllable way to curb overfitting is to limit the model's representational capacity.
Regularization Term
To prevent overfitting when the sample size cannot be expanded, the only option is to reduce model complexity, which can be done by limiting the number of parameters or their possible values (keeping parameter values small).
Concretely, a penalty on the scale of the parameters is added to the model's optimization objective (the loss). The more parameters there are, or the larger their values, the larger the penalty. By tuning the penalty's weight coefficient, the model strikes a balance between "minimizing the training loss" and "preserving generalization ability", i.e. remaining effective on unseen samples. The price of the regularization term is a higher loss on the training set.
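To make the arithmetic concrete, the sketch below (illustrative only; in practice the optimizer's weight_decay setting shown next does this for you) spells out an L2 penalty, reusing the model and avg_loss names from the training loops above:

```python
# total objective = data loss + coeff * sum of squared parameter values
l2_coeff = 1e-5
l2_penalty = paddle.to_tensor(0.)
for p in model.parameters():
    l2_penalty += paddle.sum(paddle.square(p))
total_loss = avg_loss + l2_coeff * l2_penalty
```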
Paddle supports adding a uniform regularization term over all parameters, and also supports adding regularization terms to specific parameters. The former is shown in the code below: simply set the weight_decay argument in the optimizer. The coeff parameter adjusts the weight of the regularization term; the larger it is, the stronger the penalty on model complexity.
```python
def train(model):
    model.train()

    # A regularization term can be added to any of the optimizers to avoid
    # overfitting; the coeff parameter adjusts its weight
    opt = paddle.optimizer.Adam(learning_rate=0.01,
                                weight_decay=paddle.regularizer.L2Decay(coeff=1e-5),
                                parameters=model.parameters())

    EPOCH_NUM = 5
    for epoch_id in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            # Prepare the data
            images, labels = data
            images = paddle.to_tensor(images)
            labels = paddle.to_tensor(labels)

            # Forward pass, returning both predictions and accuracy
            predicts, acc = model(images, labels)

            # Compute the loss, averaged over the batch
            loss = F.cross_entropy(predicts, labels)
            avg_loss = paddle.mean(loss)

            # Print the current loss every 200 batches
            if batch_id % 200 == 0:
                print("epoch: {}, batch: {}, loss is: {}, acc is {}".format(
                    epoch_id, batch_id, avg_loss.numpy(), acc.numpy()))

            # Backpropagate and update the parameters
            avg_loss.backward()
            opt.step()
            opt.clear_grad()

    # Save the model parameters
    paddle.save(model.state_dict(), 'mnist_regul.pdparams')

model = MNIST()
train(model)
```
```
epoch: 0, batch: 0, loss is: [2.625702], acc is [0.14]
epoch: 0, batch: 200, loss is: [0.19216107], acc is [0.96]
epoch: 0, batch: 400, loss is: [0.10491196], acc is [0.96]
epoch: 1, batch: 0, loss is: [0.04200032], acc is [0.98]
epoch: 1, batch: 200, loss is: [0.03362055], acc is [1.]
epoch: 1, batch: 400, loss is: [0.10716258], acc is [0.98]
epoch: 2, batch: 0, loss is: [0.06904294], acc is [0.97]
epoch: 2, batch: 200, loss is: [0.05971422], acc is [0.98]
epoch: 2, batch: 400, loss is: [0.04597992], acc is [0.99]
epoch: 3, batch: 0, loss is: [0.04722928], acc is [0.97]
epoch: 3, batch: 200, loss is: [0.01793179], acc is [0.99]
epoch: 3, batch: 400, loss is: [0.11060252], acc is [0.98]
epoch: 4, batch: 0, loss is: [0.01327007], acc is [1.]
epoch: 4, batch: 200, loss is: [0.07705414], acc is [0.98]
epoch: 4, batch: 400, loss is: [0.02693538], acc is [0.99]
```
```python
def evaluation(model):
    print('start evaluation .......')
    # Define the evaluation process
    params_file_path = 'mnist_regul.pdparams'
    # Load the model parameters
    param_dict = paddle.load(params_file_path)
    model.load_dict(param_dict)

    model.eval()
    eval_loader = load_data('eval')

    acc_set = []
    avg_loss_set = []
    for batch_id, data in enumerate(eval_loader()):
        images, labels = data
        images = paddle.to_tensor(images)
        labels = paddle.to_tensor(labels)
        predicts, acc = model(images, labels)
        loss = F.cross_entropy(input=predicts, label=labels)
        avg_loss = paddle.mean(loss)
        acc_set.append(float(acc.numpy()))
        avg_loss_set.append(float(avg_loss.numpy()))

    # Average the loss and accuracy over all batches
    acc_val_mean = np.array(acc_set).mean()
    avg_loss_val_mean = np.array(avg_loss_set).mean()
    print('loss={}, acc={}'.format(avg_loss_val_mean, acc_val_mean))

model = MNIST()
evaluation(model)
```
```
start evaluation .......
loading mnist dataset from ./work/mnist.json.gz ......
loss=1.4758862841129303, acc=0.9866000074148178
```
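The run above applied a uniform weight_decay to every parameter. As noted at the start of this section, Paddle can also regularize only part of the network. One way to do this (a sketch based on the paddle.ParamAttr API, not code from the original notebook) is to attach a regularizer to an individual layer's parameters when constructing it; as we understand the API, a regularizer set on a parameter takes precedence over the optimizer-level weight_decay:

```python
import paddle

# Regularize only the fully connected layer's weight via ParamAttr;
# the convolution layers are left unregularized
fc = paddle.nn.Linear(
    in_features=980,
    out_features=10,
    weight_attr=paddle.ParamAttr(
        regularizer=paddle.regularizer.L2Decay(coeff=1e-4)))
```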
Visual Analysis
When training a model, you often need to watch its evaluation metrics and analyze how optimization proceeds to make sure training is effective. Two tools are available: the Matplotlib library and VisualDL.
- Matplotlib: the most widely used 2D plotting library in Python, with a set of plotting functions closely modeled on MATLAB's; making charts with the lightweight plt interface is very simple.
- VisualDL: for a more specialized charting tool, try VisualDL, PaddlePaddle's visualization toolkit. VisualDL effectively presents the computation graph, trends in various metrics, and data information produced while Paddle runs.
Plotting the Training Loss Curve with Matplotlib
Use the training batch number as the X coordinate and that batch's training loss as the Y coordinate.
- Before training starts, declare two list variables to store the batch numbers (iters=[]) and the training losses (losses=[]).
```python
iters = []
losses = []
for epoch_id in range(EPOCH_NUM):
    """start to training"""
```
- As training proceeds, fill the iters and losses lists.
```python
import paddle.nn.functional as F

iters = []
losses = []
for epoch_id in range(EPOCH_NUM):
    for batch_id, data in enumerate(train_loader()):
        images, labels = data
        predicts, acc = model(images, labels)
        loss = F.cross_entropy(predicts, label=labels.astype('int64'))
        avg_loss = paddle.mean(loss)
        # Accumulate the iteration count and the corresponding loss
        iters.append(batch_id + epoch_id * len(list(train_loader())))
        losses.append(avg_loss)
```
- After training ends, pass the two lists to plt as the horizontal and vertical coordinates.
```python
plt.xlabel("iter", fontsize=14)
plt.ylabel("loss", fontsize=14)
```
- Finally, call plt.plot() to complete the chart.
```python
plt.plot(iters, losses, color='red', label='train loss')
```
The full code is as follows:
```python
# Import matplotlib
import matplotlib.pyplot as plt

def train(model):
    model.train()
    opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

    EPOCH_NUM = 10
    iter = 0
    iters = []
    losses = []
    for epoch_id in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            # Prepare the data
            images, labels = data
            images = paddle.to_tensor(images)
            labels = paddle.to_tensor(labels)

            # Forward pass, returning both predictions and accuracy
            predicts, acc = model(images, labels)

            # Compute the loss, averaged over the batch
            loss = F.cross_entropy(predicts, labels)
            avg_loss = paddle.mean(loss)

            # Every 100 batches, print the current loss and record a data point
            if batch_id % 100 == 0:
                print("epoch: {}, batch: {}, loss is: {}, acc is {}".format(
                    epoch_id, batch_id, avg_loss.numpy(), acc.numpy()))
                iters.append(iter)
                losses.append(avg_loss.numpy())
                iter = iter + 100

            # Backpropagate and update the parameters
            avg_loss.backward()
            opt.step()
            opt.clear_grad()

    # Save the model parameters
    paddle.save(model.state_dict(), 'mnist.pdparams')
    return iters, losses

model = MNIST()
iters, losses = train(model)
```
```
epoch: 0, batch: 0, loss is: [2.4713075], acc is [0.23]
epoch: 0, batch: 100, loss is: [0.28999496], acc is [0.9]
epoch: 0, batch: 200, loss is: [0.1286069], acc is [0.95]
epoch: 0, batch: 300, loss is: [0.07993691], acc is [0.99]
epoch: 0, batch: 400, loss is: [0.0870489], acc is [0.97]
... (epochs 1-8 omitted) ...
epoch: 9, batch: 0, loss is: [0.05837032], acc is [0.99]
epoch: 9, batch: 100, loss is: [0.00985126], acc is [0.99]
epoch: 9, batch: 200, loss is: [0.01810642], acc is [0.99]
epoch: 9, batch: 300, loss is: [0.01097528], acc is [1.]
epoch: 9, batch: 400, loss is: [0.00586931], acc is [1.]
```
```python
# Plot how the loss changed during training
plt.figure()
plt.title("train loss", fontsize=24)
plt.xlabel("iter", fontsize=14)
plt.ylabel("loss", fontsize=14)
plt.plot(iters, losses, color='red', label='train loss')
plt.grid()
plt.show()
```
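One small practical note (not in the original notebook): when running in a headless environment where plt.show() has no display attached, the figure can be written to disk instead:

```python
# Save the figure to a file instead of showing it interactively
plt.savefig('train_loss.png', dpi=120, bbox_inches='tight')
```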

Visual Analysis with VisualDL
VisualDL is PaddlePaddle's visualization toolkit. It presents trends in training parameters, model structure, data samples, high-dimensional data distributions, and more in rich charts, helping users understand the training process and the model structure clearly and intuitively so they can tune models efficiently. The concrete steps are as follows.
- Step 1: Import the VisualDL library and define where the chart data will be stored (used in step 3); in this example the path is "log".
```python
from visualdl import LogWriter
log_writer = LogWriter("./log")
```
- Step 2: Insert logging statements into the training process. Every time 100 batches finish training, store the current accuracy and loss as new data points (mappings from iter to the value) in the file configured in step 1. The variable iter records the number of batches already trained and serves as the X coordinate.
```python
log_writer.add_scalar(tag='acc', step=iter, value=avg_acc.numpy())
log_writer.add_scalar(tag='loss', step=iter, value=avg_loss.numpy())
iter = iter + 100
```
```python
# Install VisualDL
!pip install --upgrade --pre visualdl
```
```python
# Import VisualDL and set the directory where the chart data are saved
from visualdl import LogWriter
log_writer = LogWriter(logdir="./log")

def train(model):
    model.train()
    opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

    EPOCH_NUM = 10
    iter = 0
    for epoch_id in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            # Prepare the data
            images, labels = data
            images = paddle.to_tensor(images)
            labels = paddle.to_tensor(labels)

            # Forward pass, returning both predictions and accuracy
            predicts, avg_acc = model(images, labels)

            # Compute the loss, averaged over the batch
            loss = F.cross_entropy(predicts, labels)
            avg_loss = paddle.mean(loss)

            # Every 100 batches, print the current loss and log a data point
            if batch_id % 100 == 0:
                print("epoch: {}, batch: {}, loss is: {}, acc is {}".format(
                    epoch_id, batch_id, avg_loss.numpy(), avg_acc.numpy()))
                log_writer.add_scalar(tag='acc', step=iter, value=avg_acc.numpy())
                log_writer.add_scalar(tag='loss', step=iter, value=avg_loss.numpy())
                iter = iter + 100

            # Backpropagate and update the parameters
            avg_loss.backward()
            opt.step()
            opt.clear_grad()

    # Save the model parameters
    paddle.save(model.state_dict(), 'mnist.pdparams')

model = MNIST()
train(model)
```
```
epoch: 0, batch: 0, loss is: [2.4428766], acc is [0.11]
epoch: 0, batch: 100, loss is: [0.14379957], acc is [0.98]
epoch: 0, batch: 200, loss is: [0.14659742], acc is [0.97]
epoch: 0, batch: 300, loss is: [0.15617883], acc is [0.95]
epoch: 0, batch: 400, loss is: [0.08011218], acc is [0.97]
... (epochs 1-8 omitted) ...
epoch: 9, batch: 0, loss is: [0.01949414], acc is [0.99]
epoch: 9, batch: 100, loss is: [0.00271552], acc is [1.]
epoch: 9, batch: 200, loss is: [0.00252313], acc is [1.]
epoch: 9, batch: 300, loss is: [0.01054896], acc is [1.]
epoch: 9, batch: 400, loss is: [0.01838213], acc is [1.]
```
- Step 3: Launch VisualDL from the command line.
Start VisualDL with the command "visualdl --logdir [path to the folder containing the log data]". Once VisualDL is running, the command line prints a URL where the charts can be viewed in a browser.

```
$ visualdl --logdir ./log --port 8080
```

- Step 4: Open a browser and view the charts, as shown in Figure 6.
The URL to visit is printed after the launch command from step 3 (for example http://127.0.0.1:8080/). Entering it into the browser's address bar and refreshing produces the page shown below. Besides the plotted data points on the right, there is a control panel on the left that adjusts many details of the charts.

Figure 6: Example VisualDL chart