PaddlePaddle Open-Source Deep Learning Platform
Chapter 2: Mastering Deep Learning Through One Case (Part 2) - 8. [Handwritten Digit Recognition] Training Debugging and Optimization


Overview

In the previous section we studied ways to optimize resource deployment, using a single GPU and distributed deployment to improve training efficiency. In this section we continue to unfold the "horizontal" axis of the "horizontal-vertical" approach, as shown in Figure 1, and discuss how to debug and optimize the model-training stage of the handwritten digit recognition task so as to ensure the model's real-world effectiveness.

Figure 1: The "horizontal-vertical" teaching approach: the training process

There are five key aspects to optimizing the training process:

1. Compute classification accuracy to monitor training quality.

The cross-entropy loss can only serve as the optimization objective; it does not directly measure how well the model is training. Accuracy measures training quality directly, but because it is discrete it is unsuitable as a loss function for optimizing a neural network.

2. Inspect the training process to identify potential problems.

If the model's loss or evaluation metrics behave abnormally, it is usually necessary to print each layer's inputs and outputs to locate the problem, and to analyze their contents to find the cause of the error.

3. Add validation or testing to better evaluate the model.

Ideally, a trained model has high accuracy on both the training set and the validation set. If training accuracy is lower than validation accuracy, the network is undertrained; if training accuracy is higher than validation accuracy, overfitting may have occurred. Adding a regularization term to the optimization objective addresses the overfitting problem.

4. Add regularization to avoid overfitting.

PaddlePaddle supports adding a regularization term over all parameters, which is the usual approach. It also supports adding a regularization term to a single layer or part of the network, for fine-grained control over parameter training.

5. Visual analysis.

Besides printing values or plotting with the matplotlib library, users can turn to VisualDL, PaddlePaddle's dedicated visualization tool, which offers convenient visual analysis.

Computing the Model's Classification Accuracy

Accuracy is an intuitive metric for judging a classification model. Because it is discrete, it is unsuitable as a loss to optimize. In general, the smaller a model's cross-entropy loss, the higher its classification accuracy. Classification accuracy also lets us compare two loss functions fairly, for example the comparison between mean squared error and cross-entropy in the chapter [Handwritten Digit Recognition] Loss Functions.

Paddle provides an API that computes classification accuracy directly:

class paddle.metric.Accuracy

The API's input argument is the predicted classification result predict, and its label argument is the ground-truth label of the data. Paddle offers many more metrics for evaluating models; see the APIs under the paddle.metric package for details.
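
Alongside the Accuracy class, Paddle also exposes a functional form, paddle.metric.accuracy, which the code below uses inside forward. A minimal sketch with made-up tensor values:

import paddle

# Predicted scores for a batch of 3 samples over 2 classes (made-up values)
logits = paddle.to_tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
# Ground-truth labels with shape [N, 1] and dtype int64
labels = paddle.to_tensor([[1], [0], [0]])

# Fraction of samples whose top-1 prediction matches the label
acc = paddle.metric.accuracy(input=logits, label=labels)
print(acc.numpy())  # [0.6666667] -- two of the three predictions are correct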

In the code below, we compute the classification accuracy inside the model's forward function and print each batch's accuracy during training.

# Load the required libraries
import os
import random
import gzip
import json
import paddle
import numpy as np
from PIL import Image

# Define the dataset reader
def load_data(mode='train'):
    # Read the data file
    datafile = './work/mnist.json.gz'
    print('loading mnist dataset from {} ......'.format(datafile))
    data = json.load(gzip.open(datafile))
    # Split the data into training, validation, and test sets
    train_set, val_set, eval_set = data
    # Dataset parameters: image height IMG_ROWS and image width IMG_COLS
    IMG_ROWS = 28
    IMG_COLS = 28
    # Choose the training, validation, or test set based on the mode argument
    if mode == 'train':
        imgs = train_set[0]
        labels = train_set[1]
    elif mode == 'valid':
        imgs = val_set[0]
        labels = val_set[1]
    elif mode == 'eval':
        imgs = eval_set[0]
        labels = eval_set[1]
    imgs_length = len(imgs)  # total number of images
    # Check that the numbers of images and labels agree
    assert len(imgs) == len(labels), \
        "length of train_imgs({}) should be the same as train_labels({})".format(
            len(imgs), len(labels))

    index_list = list(range(imgs_length))
    BATCHSIZE = 100  # batch size used when reading the data

    # Define the data generator
    def data_generator():
        # In training mode, shuffle the training data
        if mode == 'train':
            random.shuffle(index_list)
        imgs_list = []
        labels_list = []
        # Read samples by index
        for i in index_list:
            # Read one image and its label; convert their shape and dtype
            img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')
            label = np.reshape(labels[i], [1]).astype('int64')
            imgs_list.append(img)
            labels_list.append(label)
            # Once the buffer holds a full batch, yield it
            if len(imgs_list) == BATCHSIZE:
                yield np.array(imgs_list), np.array(labels_list)
                # Clear the buffers
                imgs_list = []
                labels_list = []
        # If fewer than BATCHSIZE samples remain, yield them together
        # as a final mini-batch of size len(imgs_list)
        if len(imgs_list) > 0:
            yield np.array(imgs_list), np.array(labels_list)
    return data_generator


# Define the model structure
import paddle.nn.functional as F
from paddle.nn import Conv2D, MaxPool2D, Linear

# Multi-layer convolutional neural network
class MNIST(paddle.nn.Layer):
    def __init__(self):
        super(MNIST, self).__init__()
        # Convolution layer: out_channels=20, kernel_size=5, stride=1, padding=2
        self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)
        # Pooling layer: kernel_size=2, stride=2
        self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
        # Convolution layer: out_channels=20, kernel_size=5, stride=1, padding=2
        self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)
        # Pooling layer: kernel_size=2, stride=2
        self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
        # Fully connected layer with output dimension 10
        self.fc = Linear(in_features=980, out_features=10)

    # Forward pass: each convolution is followed by a pooling layer, and a fully
    # connected layer computes the final output. The convolution layers use the
    # ReLU activation; the softmax over the output is folded into F.cross_entropy.
    def forward(self, inputs, label=None):
        x = self.conv1(inputs)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max_pool2(x)
        x = paddle.reshape(x, [x.shape[0], 980])
        x = self.fc(x)
        if label is not None:
            acc = paddle.metric.accuracy(input=x, label=label)
            return x, acc
        else:
            return x

# Call the data-loading function
train_loader = load_data('train')

# On a GPU machine, set use_gpu to True
use_gpu = True
paddle.set_device('gpu:0') if use_gpu else paddle.set_device('cpu')

# Only the optimizer settings differ between runs
def train(model):
    model.train()
    # Four optimizer options; try each one and compare the results
    # opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    # opt = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9, parameters=model.parameters())
    # opt = paddle.optimizer.Adagrad(learning_rate=0.01, parameters=model.parameters())
    opt = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())

    EPOCH_NUM = 5
    for epoch_id in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            # Prepare the data
            images, labels = data
            images = paddle.to_tensor(images)
            labels = paddle.to_tensor(labels)
            # Forward pass
            predicts, acc = model(images, labels)
            # Compute the loss; average it over the batch
            loss = F.cross_entropy(predicts, labels)
            avg_loss = paddle.mean(loss)
            # Print the current loss every 200 batches
            if batch_id % 200 == 0:
                print("epoch: {}, batch: {}, loss is: {}, acc is {}".format(epoch_id, batch_id, avg_loss.numpy(), acc.numpy()))
            # Backward pass: update the parameters and clear the gradients
            avg_loss.backward()
            opt.step()
            opt.clear_grad()
    # Save the model parameters
    paddle.save(model.state_dict(), 'mnist.pdparams')

# Create the model
model = MNIST()
# Start training
train(model)
loading mnist dataset from ./work/mnist.json.gz ......
epoch: 0, batch: 0, loss is: [2.3825366], acc is [0.13]
epoch: 0, batch: 200, loss is: [0.07315266], acc is [0.97]
epoch: 0, batch: 400, loss is: [0.08943536], acc is [0.97]
epoch: 1, batch: 0, loss is: [0.02650153], acc is [0.99]
epoch: 1, batch: 200, loss is: [0.00596725], acc is [1.]
epoch: 1, batch: 400, loss is: [0.08135935], acc is [0.97]
epoch: 2, batch: 0, loss is: [0.04334498], acc is [0.99]
epoch: 2, batch: 200, loss is: [0.0617303], acc is [0.98]
epoch: 2, batch: 400, loss is: [0.02098986], acc is [0.99]
epoch: 3, batch: 0, loss is: [0.04585475], acc is [0.98]
epoch: 3, batch: 200, loss is: [0.02449889], acc is [0.99]
epoch: 3, batch: 400, loss is: [0.0548584], acc is [0.99]
epoch: 4, batch: 0, loss is: [0.11403488], acc is [0.98]
epoch: 4, batch: 200, loss is: [0.05665577], acc is [0.99]
epoch: 4, batch: 400, loss is: [0.02090033], acc is [0.99]

Checking the Training Process to Identify Potential Problems

Dynamic-graph programming in Paddle makes it easy to inspect and debug the execution of training. In the network's forward function you can print the shape of each layer's input and output as well as each layer's parameters. Examining this information not only deepens your understanding of the training process but can also expose potential problems and suggest further optimizations.

In the program below, the check_shape variable controls whether shapes are printed, to verify that the network structure is correct, and the check_content variable controls whether values are printed, to verify that the data distribution is reasonable. If, for example, part of an intermediate layer's output stays at zero throughout training, that part of the network structure is poorly designed and is not being fully used.

import paddle.nn.functional as F

# Define the model structure
class MNIST(paddle.nn.Layer):
    def __init__(self):
        super(MNIST, self).__init__()
        # Convolution layer: out_channels=20, kernel_size=5, stride=1, padding=2
        self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)
        # Pooling layer: kernel_size=2, stride=2
        self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
        # Convolution layer: out_channels=20, kernel_size=5, stride=1, padding=2
        self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)
        # Pooling layer: kernel_size=2, stride=2
        self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
        # Fully connected layer with output dimension 10
        self.fc = Linear(in_features=980, out_features=10)

    # Print the shapes and contents of each layer's inputs and outputs,
    # controlled by the check_shape and check_content flags.
    # The convolution layers use ReLU; softmax is applied before the accuracy metric.
    def forward(self, inputs, label=None, check_shape=False, check_content=False):
        # Give each layer's output its own name, which makes debugging easier
        outputs1 = self.conv1(inputs)
        outputs2 = F.relu(outputs1)
        outputs3 = self.max_pool1(outputs2)
        outputs4 = self.conv2(outputs3)
        outputs5 = F.relu(outputs4)
        outputs6 = self.max_pool2(outputs5)
        outputs6 = paddle.reshape(outputs6, [outputs6.shape[0], -1])
        outputs7 = self.fc(outputs6)

        # Optionally print each layer's parameter shapes and output shapes,
        # to verify that the network structure is set up correctly
        if check_shape:
            # Print each layer's hyperparameters: kernel size, stride, padding
            print("\n########## print network layer's superparams ##############")
            print("conv1-- kernel_size:{}, padding:{}, stride:{}".format(self.conv1.weight.shape, self.conv1._padding, self.conv1._stride))
            print("conv2-- kernel_size:{}, padding:{}, stride:{}".format(self.conv2.weight.shape, self.conv2._padding, self.conv2._stride))
            print("fc-- weight_size:{}, bias_size_{}".format(self.fc.weight.shape, self.fc.bias.shape))
            # Print each layer's output shape
            print("\n########## print shape of features of every layer ###############")
            print("inputs_shape: {}".format(inputs.shape))
            print("outputs1_shape: {}".format(outputs1.shape))
            print("outputs2_shape: {}".format(outputs2.shape))
            print("outputs3_shape: {}".format(outputs3.shape))
            print("outputs4_shape: {}".format(outputs4.shape))
            print("outputs5_shape: {}".format(outputs5.shape))
            print("outputs6_shape: {}".format(outputs6.shape))
            print("outputs7_shape: {}".format(outputs7.shape))

        # Optionally print parameters and output values during training,
        # which helps with debugging
        if check_content:
            # Print convolution kernel weights; there are many parameters,
            # so only part of each kernel is printed
            print("\n########## print convolution layer's kernel ###############")
            print("conv1 params -- kernel weights:", self.conv1.weight[0][0])
            print("conv2 params -- kernel weights:", self.conv2.weight[0][0])
            # Pick a random channel and print its output values
            idx1 = np.random.randint(0, outputs1.shape[1])
            idx2 = np.random.randint(0, outputs4.shape[1])
            # Print the post-convolution results, only for the first image in the batch
            print("\nThe {}th channel of conv1 layer: ".format(idx1), outputs1[0][idx1])
            print("The {}th channel of conv2 layer: ".format(idx2), outputs4[0][idx2])
            print("The output of last layer:", outputs7[0], '\n')

        # If label is not None, compute and return the classification accuracy
        if label is not None:
            acc = paddle.metric.accuracy(input=F.softmax(outputs7), label=label)
            return outputs7, acc
        else:
            return outputs7

# On a GPU machine, set use_gpu to True
use_gpu = True
paddle.set_device('gpu:0') if use_gpu else paddle.set_device('cpu')

def train(model):
    model.train()
    # Four optimizer options; try each one and compare the results
    opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    # opt = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9, parameters=model.parameters())
    # opt = paddle.optimizer.Adagrad(learning_rate=0.01, parameters=model.parameters())
    # opt = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())

    EPOCH_NUM = 1
    for epoch_id in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            # Prepare the data
            images, labels = data
            images = paddle.to_tensor(images)
            labels = paddle.to_tensor(labels)
            # Forward pass, returning both predictions and accuracy
            if batch_id == 0 and epoch_id == 0:
                # Print the parameter shapes and each layer's output shapes
                predicts, acc = model(images, labels, check_shape=True, check_content=False)
            elif batch_id == 401:
                # Print the parameters and each layer's output values
                predicts, acc = model(images, labels, check_shape=False, check_content=True)
            else:
                predicts, acc = model(images, labels)
            # Compute the loss; average it over the batch
            loss = F.cross_entropy(predicts, labels)
            avg_loss = paddle.mean(loss)
            # Print the current loss every 200 batches
            if batch_id % 200 == 0:
                print("epoch: {}, batch: {}, loss is: {}, acc is {}".format(epoch_id, batch_id, avg_loss.numpy(), acc.numpy()))
            # Backward pass and parameter update
            avg_loss.backward()
            opt.step()
            opt.clear_grad()
    # Save the model parameters
    paddle.save(model.state_dict(), 'mnist_test.pdparams')

# Create the model
model = MNIST()
# Start training
train(model)

print("Model has been saved.")

########## print network layer's superparams ##############
conv1-- kernel_size:[20, 1, 5, 5], padding:2, stride:[1, 1]
conv2-- kernel_size:[20, 20, 5, 5], padding:2, stride:[1, 1]
fc-- weight_size:[980, 10], bias_size_[10]

########## print shape of features of every layer ###############
inputs_shape: [100, 1, 28, 28]
outputs1_shape: [100, 20, 28, 28]
outputs2_shape: [100, 20, 28, 28]
outputs3_shape: [100, 20, 14, 14]
outputs4_shape: [100, 20, 14, 14]
outputs5_shape: [100, 20, 14, 14]
outputs6_shape: [100, 980]
outputs7_shape: [100, 10]
epoch: 0, batch: 0, loss is: [2.4391866], acc is [0.06]
epoch: 0, batch: 200, loss is: [0.29858145], acc is [0.9]
epoch: 0, batch: 400, loss is: [0.36437622], acc is [0.88]

########## print convolution layer's kernel ###############
conv1 params -- kernel weights: Tensor(shape=[5, 5], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
       [[-0.31532317,  0.60181510, -0.38462862,  0.52080828, -0.37432989],
        [ 0.20288531, -0.56747001,  0.53627175, -0.04442852, -0.41307986],
        [ 0.42562252, -0.15216339,  0.33505937,  0.12434796,  0.50957596],
        [ 0.03782330, -0.29627943,  0.44766814, -0.04806038,  0.57602686],
        [ 0.50651461,  0.07721443, -0.27299264,  0.09201541,  0.49806190]])
conv2 params -- kernel weights: Tensor(shape=[5, 5], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
       [[ 0.06146792,  0.01054612,  0.02672482,  0.04353787,  0.11069481],
        [-0.06992761,  0.06845246,  0.01267519,  0.01942314, -0.09324966],
        [ 0.03526127, -0.13740280, -0.01524128, -0.04532520,  0.07689699],
        [-0.11563610, -0.02698515, -0.13878933, -0.07280355, -0.09811222],
        [ 0.01492384, -0.00623472,  0.06594540, -0.14806092,  0.09876055]])

The 13th channel of conv1 layer:  Tensor(shape=[28, 28], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
       [[ 0.00016711,  0.00016711,  0.00016711, ...,  0.00015430,  0.00010933,  0.00016711],
        [ 0.00016711,  0.00016711,  0.00016711, ...,  0.00015349,  0.00017223,  0.00016711],
        ...,
        [ 0.00016711,  0.00016711,  0.00016711, ...,  0.00016711,  0.00016711,  0.00016711]])
The 11th channel of conv2 layer:  Tensor(shape=[14, 14], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
       [[ 0.00312751,  0.00586835,  0.00110976, ..., -0.08951769, -0.02691137,  0.05222606],
        [-0.00647034, -0.00965514, -0.01638876, ..., -0.70091587, -0.42589089,  0.01768003],
        ...,
        [ 1.05995369,  1.05843043,  1.06410861, ..., -0.18151908, -0.06200687, -0.01186676]])
The output of last layer: Tensor(shape=[10], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
       [ 1.45627880, -0.73663777,  1.47225618,  8.28572273, -4.86152220,  0.67528987, -3.78966165, -3.68850040,  3.13669801, -2.26959062])

Model has been saved.

Adding Validation or Testing to Better Evaluate the Model

During training, we see the model's loss on the training samples keep decreasing. But does that mean the model will still work in its future application scenario? To verify the model's validity, the sample collection is usually split into three parts: a training set, a validation set, and a test set.

  • Training set: used to train the model's parameters, i.e., the main work done during training.
  • Validation set: used to choose the model's hyperparameters, such as adjusting the network structure or the regularization weight.
  • Test set: used to simulate the model's real-world performance after deployment. Because the test set takes no part in model optimization or parameter training, it is entirely unseen by the model. When the validation data is not used to tune the network structure or hyperparameters, validation and test results behave similarly, and both reflect the model's true performance.

The following program loads the model parameters saved in the previous step, reads the evaluation dataset, and measures the model's performance on it.

def evaluation(model):
    print('start evaluation .......')
    # Define the evaluation process
    params_file_path = 'mnist.pdparams'
    # Load the model parameters
    param_dict = paddle.load(params_file_path)
    model.load_dict(param_dict)

    model.eval()
    eval_loader = load_data('eval')

    acc_set = []
    avg_loss_set = []
    for batch_id, data in enumerate(eval_loader()):
        images, labels = data
        images = paddle.to_tensor(images)
        labels = paddle.to_tensor(labels)
        predicts, acc = model(images, labels)
        loss = F.cross_entropy(input=predicts, label=labels)
        avg_loss = paddle.mean(loss)
        acc_set.append(float(acc.numpy()))
        avg_loss_set.append(float(avg_loss.numpy()))
    # Average the loss and accuracy over all batches
    acc_val_mean = np.array(acc_set).mean()
    avg_loss_val_mean = np.array(avg_loss_set).mean()

    print('loss={}, acc={}'.format(avg_loss_val_mean, acc_val_mean))

model = MNIST()
evaluation(model)
start evaluation .......
loading mnist dataset from ./work/mnist.json.gz ......
loss=0.05001451548640034, acc=0.9861000066995621

The evaluation shows that the model still achieves 98.6% accuracy on the validation set, confirming that it has genuine predictive power.

Adding Regularization to Avoid Overfitting

The Overfitting Phenomenon

For complex tasks with limited samples but a powerful model, overfitting appears readily: the loss is small on the training set but large on the validation or test set, as shown in Figure 2.


Figure 2: Overfitting: training error keeps decreasing while test error first falls, then rises


Conversely, if the model's loss is large on both the training and test sets, it is underfitting. Overfitting means the model is overly sensitive: it has learned idiosyncrasies of the training data that are not real generalizable patterns (patterns that carry over to the test set). Underfitting means the model is not yet powerful enough: it has not fit even the known training samples well, let alone the test samples. Underfitting is easy to observe and fix: as long as the training loss is poor, keep moving to a stronger model. In practice, therefore, the harder problem we need to handle is overfitting.

Causes of Overfitting

Overfitting arises when the model is overly sensitive while the training data is too scarce or too noisy.

As shown in Figure 3, the ideal regression model is a gently sloped parabola. The underfitted model fits only a straight line and clearly misses the true pattern, while the overfitted model fits a curve with many inflection points; it is overly sensitive and also fails to express the true pattern.


Figure 3: A regression model in overfitted, ideal, and underfitted states


As shown in Figure 4, the ideal classification model is a semicircular curve. The underfitted model uses a straight line as the decision boundary and clearly misses the true boundary, while the overfitted model produces a highly contorted boundary: although it classifies all training data correctly, the concessions it makes to a few outlier samples are very unlikely to reflect the true pattern.


Figure 4: A classification model in underfitted, ideal, and overfitted states


Understanding and Controlling Overfitting

To better understand why overfitting occurs, consider the logic of a detective identifying a criminal, as shown in Figure 5.


Figure 5: The detective-locating-a-criminal analogy for model hypothesis selection


In this analogy, suppose the detective can also err; analysis suggests two possible causes:

  1. Case 1: the evidence is wrong, and searching for the criminal based on false evidence is a fool's errand.

  2. Case 2: the search scope is too broad while the evidence is too scarce, so too many candidates (suspects) fit the evidence and the criminal cannot be pinned down.

The detective then has two remedies: narrow the search scope (for example, assume the crime must have been committed by an acquaintance), or gather more evidence.

Carrying this over to deep learning, suppose the model can also err; analysis suggests two possible causes:

  1. Case 1: the training data contains noise, and the model learns the noise rather than the true pattern.

  2. Case 2: a powerful model (with a large hypothesis space) is trained on too little data, so too many candidate hypotheses perform well on the training data and the model locks onto a "spuriously correct" one.

Case 1 is addressed by cleaning and correcting the data. Case 2 is addressed by either limiting the model's representational capacity or collecting more training data.

Cleaning errors out of the training data, or collecting more of it, is often a "correct but useless" prescription: at all times we want more and higher-quality data. In a real project, the faster, cheaper, and more controllable way to curb overfitting is to limit the model's representational capacity.

The Regularization Term

To prevent overfitting when the sample size cannot be enlarged, the only option is to reduce the model's complexity, which can be done by limiting the number of parameters or their possible values (keeping parameter values small).

Concretely, a penalty on parameter scale is added by hand to the model's optimization objective (the loss): the more parameters there are, or the larger their values, the larger the penalty. By tuning the penalty's weight coefficient, the model can balance "minimizing the training loss" against "preserving generalization". Generalization means the model remains effective on samples it has never seen. The regularization term necessarily increases the model's loss on the training set.
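
To make the penalty concrete, here is a minimal hand-rolled sketch (not the tutorial's code) that adds an L2 penalty to the averaged loss; loss_with_l2 and lambda_coeff are hypothetical names:

import paddle

def loss_with_l2(avg_loss, model, lambda_coeff=1e-5):
    # Sum of squared parameter values; grows with the number and size of parameters
    l2_penalty = paddle.zeros([1])
    for param in model.parameters():
        l2_penalty += paddle.sum(paddle.square(param))
    # The penalized objective trades training loss against parameter scale
    return avg_loss + lambda_coeff * l2_penalty

In practice the same effect is achieved through the optimizer's weight_decay setting, as shown next.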

PaddlePaddle supports adding a uniform regularization term over all parameters, and it also supports adding regularization terms to specific parameters. The former is implemented as in the code below, simply by setting the weight_decay argument of the optimizer. The coeff parameter adjusts the weight of the regularization term: the larger it is, the more heavily model complexity is penalized.

def train(model):
    model.train()
    # Each of the optimizers can take a regularization term to avoid overfitting;
    # the coeff parameter adjusts the weight of the regularization term
    opt = paddle.optimizer.Adam(learning_rate=0.01,
                                weight_decay=paddle.regularizer.L2Decay(coeff=1e-5),
                                parameters=model.parameters())

    EPOCH_NUM = 5
    for epoch_id in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            # Prepare the data
            images, labels = data
            images = paddle.to_tensor(images)
            labels = paddle.to_tensor(labels)
            # Forward pass, returning both predictions and accuracy
            predicts, acc = model(images, labels)
            # Compute the loss; average it over the batch
            loss = F.cross_entropy(predicts, labels)
            avg_loss = paddle.mean(loss)
            # Print the current loss every 200 batches
            if batch_id % 200 == 0:
                print("epoch: {}, batch: {}, loss is: {}, acc is {}".format(epoch_id, batch_id, avg_loss.numpy(), acc.numpy()))
            # Backward pass and parameter update
            avg_loss.backward()
            opt.step()
            opt.clear_grad()
    # Save the model parameters
    paddle.save(model.state_dict(), 'mnist_regul.pdparams')

model = MNIST()
train(model)
epoch: 0, batch: 0, loss is: [2.625702], acc is [0.14]
epoch: 0, batch: 200, loss is: [0.19216107], acc is [0.96]
epoch: 0, batch: 400, loss is: [0.10491196], acc is [0.96]
epoch: 1, batch: 0, loss is: [0.04200032], acc is [0.98]
epoch: 1, batch: 200, loss is: [0.03362055], acc is [1.]
epoch: 1, batch: 400, loss is: [0.10716258], acc is [0.98]
epoch: 2, batch: 0, loss is: [0.06904294], acc is [0.97]
epoch: 2, batch: 200, loss is: [0.05971422], acc is [0.98]
epoch: 2, batch: 400, loss is: [0.04597992], acc is [0.99]
epoch: 3, batch: 0, loss is: [0.04722928], acc is [0.97]
epoch: 3, batch: 200, loss is: [0.01793179], acc is [0.99]
epoch: 3, batch: 400, loss is: [0.11060252], acc is [0.98]
epoch: 4, batch: 0, loss is: [0.01327007], acc is [1.]
epoch: 4, batch: 200, loss is: [0.07705414], acc is [0.98]
epoch: 4, batch: 400, loss is: [0.02693538], acc is [0.99]

def evaluation(model):
    print('start evaluation .......')
    # Define the evaluation process
    params_file_path = 'mnist_regul.pdparams'
    # Load the model parameters
    param_dict = paddle.load(params_file_path)
    model.load_dict(param_dict)

    model.eval()
    eval_loader = load_data('eval')

    acc_set = []
    avg_loss_set = []
    for batch_id, data in enumerate(eval_loader()):
        images, labels = data
        images = paddle.to_tensor(images)
        labels = paddle.to_tensor(labels)
        predicts, acc = model(images, labels)
        loss = F.cross_entropy(input=predicts, label=labels)
        avg_loss = paddle.mean(loss)
        acc_set.append(float(acc.numpy()))
        avg_loss_set.append(float(avg_loss.numpy()))
    # Average the loss and accuracy over all batches
    acc_val_mean = np.array(acc_set).mean()
    avg_loss_val_mean = np.array(avg_loss_set).mean()

    print('loss={}, acc={}'.format(avg_loss_val_mean, acc_val_mean))

model = MNIST()
evaluation(model)
start evaluation .......
loading mnist dataset from ./work/mnist.json.gz ......
loss=1.4758862841129303, acc=0.9866000074148178
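
As mentioned at the start of this section, Paddle can also regularize only one layer or one part of the network. A hedged sketch, assuming the regularizer argument of paddle.ParamAttr, that attaches L2 decay to the fully connected layer's weight only:

import paddle
from paddle.nn import Linear

# Attach L2 decay to a single parameter through its ParamAttr
fc = Linear(in_features=980, out_features=10,
            weight_attr=paddle.ParamAttr(
                regularizer=paddle.regularizer.L2Decay(coeff=1e-4)))
# Per the Paddle docs, a parameter that sets its own regularizer takes precedence
# over the optimizer-level weight_decay, so the remaining parameters can be left
# unregularized.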

Visual Analysis

When training a model, you often need to watch the evaluation metrics and analyze the optimization process to make sure training is effective. Two tools are available for this: the Matplotlib library and VisualDL.

  • Matplotlib: Matplotlib is the most widely used 2D plotting library in Python. Its plotting interface closely mirrors MATLAB's function style, and producing plots with this lightweight library is very simple.
  • VisualDL: for a more professional plotting tool, try VisualDL, PaddlePaddle's visualization tool. VisualDL effectively displays the computation graph, metric trends, and data information produced while PaddlePaddle runs.

Plotting the Training-Loss Curve with the Matplotlib Library

Use the training batch number as the X coordinate and that batch's training loss as the Y coordinate.

  1. Before training starts, declare two list variables to store the batch numbers (iters=[]) and the corresponding training losses (losses=[]).

iters = []
losses = []
for epoch_id in range(EPOCH_NUM):
    """start to training"""

  2. As training proceeds, fill the iters and losses lists.

import paddle.nn.functional as F

iters = []
losses = []
for epoch_id in range(EPOCH_NUM):
    for batch_id, data in enumerate(train_loader()):
        images, labels = data
        predicts, acc = model(images, labels)
        loss = F.cross_entropy(predicts, label=labels.astype('int64'))
        avg_loss = paddle.mean(loss)
        # Record the iteration number and the corresponding loss
        iters.append(batch_id + epoch_id * len(list(train_loader())))
        losses.append(avg_loss)

  3. After training ends, pass the two lists to PLT as the X and Y coordinates.

plt.xlabel("iter", fontsize=14)
plt.ylabel("loss", fontsize=14)

  4. Finally, call plt.plot() to draw the curve.

plt.plot(iters, losses, color='red', label='train loss')

The full code is as follows:

# Import the matplotlib library
import matplotlib.pyplot as plt

def train(model):
    model.train()

    opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

    EPOCH_NUM = 10
    iter = 0
    iters = []
    losses = []
    for epoch_id in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            # Prepare the data
            images, labels = data
            images = paddle.to_tensor(images)
            labels = paddle.to_tensor(labels)
            # Forward pass, returning both predictions and accuracy
            predicts, acc = model(images, labels)
            # Compute the loss; average it over the batch
            loss = F.cross_entropy(predicts, labels)
            avg_loss = paddle.mean(loss)
            # Every 100 batches: print the current loss and record a data point
            if batch_id % 100 == 0:
                print("epoch: {}, batch: {}, loss is: {}, acc is {}".format(epoch_id, batch_id, avg_loss.numpy(), acc.numpy()))
                iters.append(iter)
                losses.append(avg_loss.numpy())
                iter = iter + 100
            # Backward pass and parameter update
            avg_loss.backward()
            opt.step()
            opt.clear_grad()
    # Save the model parameters
    paddle.save(model.state_dict(), 'mnist.pdparams')
    return iters, losses

model = MNIST()
iters, losses = train(model)


epoch: 0, batch: 0, loss is: [2.4713075], acc is [0.23]
epoch: 0, batch: 100, loss is: [0.28999496], acc is [0.9]
epoch: 0, batch: 200, loss is: [0.1286069], acc is [0.95]
epoch: 0, batch: 300, loss is: [0.07993691], acc is [0.99]
epoch: 0, batch: 400, loss is: [0.0870489], acc is [0.97]
epoch: 1, batch: 0, loss is: [0.10289712], acc is [0.96]
epoch: 1, batch: 100, loss is: [0.13878293], acc is [0.96]
epoch: 1, batch: 200, loss is: [0.03113149], acc is [0.99]
epoch: 1, batch: 300, loss is: [0.03857609], acc is [0.99]
epoch: 1, batch: 400, loss is: [0.173149], acc is [0.95]
epoch: 2, batch: 0, loss is: [0.08971839], acc is [0.95]
epoch: 2, batch: 100, loss is: [0.13954298], acc is [0.97]
epoch: 2, batch: 200, loss is: [0.07455762], acc is [0.98]
epoch: 2, batch: 300, loss is: [0.01438686], acc is [1.]
epoch: 2, batch: 400, loss is: [0.10036014], acc is [0.98]
epoch: 3, batch: 0, loss is: [0.01797576], acc is [1.]
epoch: 3, batch: 100, loss is: [0.03165066], acc is [0.99]
epoch: 3, batch: 200, loss is: [0.01972163], acc is [1.]
epoch: 3, batch: 300, loss is: [0.01137184], acc is [1.]
epoch: 3, batch: 400, loss is: [0.01236508], acc is [1.]
epoch: 4, batch: 0, loss is: [0.04229391], acc is [0.99]
epoch: 4, batch: 100, loss is: [0.10908616], acc is [0.96]
epoch: 4, batch: 200, loss is: [0.04014423], acc is [0.99]
epoch: 4, batch: 300, loss is: [0.00870011], acc is [1.]
epoch: 4, batch: 400, loss is: [0.04189901], acc is [0.98]
epoch: 5, batch: 0, loss is: [0.02763639], acc is [0.99]
epoch: 5, batch: 100, loss is: [0.02963991], acc is [0.99]
epoch: 5, batch: 200, loss is: [0.01873996], acc is [0.99]
epoch: 5, batch: 300, loss is: [0.05402319], acc is [0.98]
epoch: 5, batch: 400, loss is: [0.01004678], acc is [1.]
epoch: 6, batch: 0, loss is: [0.04475012], acc is [0.98]
epoch: 6, batch: 100, loss is: [0.00611814], acc is [1.]
epoch: 6, batch: 200, loss is: [0.05699715], acc is [0.98]
epoch: 6, batch: 300, loss is: [0.06741343], acc is [0.99]
epoch: 6, batch: 400, loss is: [0.01350569], acc is [1.]
epoch: 7, batch: 0, loss is: [0.01066239], acc is [1.]
epoch: 7, batch: 100, loss is: [0.02057902], acc is [0.99]
epoch: 7, batch: 200, loss is: [0.01313765], acc is [0.99]
epoch: 7, batch: 300, loss is: [0.02999004], acc is [0.99]
epoch: 7, batch: 400, loss is: [0.0057228], acc is [1.]
epoch: 8, batch: 0, loss is: [0.0265837], acc is [0.99]
epoch: 8, batch: 100, loss is: [0.00410403], acc is [1.]
epoch: 8, batch: 200, loss is: [0.00223313], acc is [1.]
epoch: 8, batch: 300, loss is: [0.00909729], acc is [1.]
epoch: 8, batch: 400, loss is: [0.02154225], acc is [0.99]
epoch: 9, batch: 0, loss is: [0.05837032], acc is [0.99]
epoch: 9, batch: 100, loss is: [0.00985126], acc is [0.99]
epoch: 9, batch: 200, loss is: [0.01810642], acc is [0.99]
epoch: 9, batch: 300, loss is: [0.01097528], acc is [1.]
epoch: 9, batch: 400, loss is: [0.00586931], acc is [1.]

# Plot the loss curve over the course of training
plt.figure()
plt.title("train loss", fontsize=24)
plt.xlabel("iter", fontsize=14)
plt.ylabel("loss", fontsize=14)
plt.plot(iters, losses, color='red', label='train loss')
plt.grid()
plt.show()


Visual Analysis with VisualDL

VisualDL is PaddlePaddle's visualization tool. It presents training-parameter trends, model structure, data samples, high-dimensional data distributions, and more in rich charts, helping users clearly and intuitively understand the training process and the model structure, and thereby tune models efficiently. The code is given below.

  • Step 1: Import the VisualDL library and define the location where the plotting data is stored (it is used again in Step 3); in this example the path is "./log".

from visualdl import LogWriter
log_writer = LogWriter("./log")

  • Step 2: Insert logging statements into the training process. After every 100 batches, store the current accuracy and loss as new data points (mappings from iter to each value) in the location set in Step 1. The variable iter records the number of batches trained so far and serves as the X coordinate for plotting.

log_writer.add_scalar(tag='acc', step=iter, value=avg_acc.numpy())
log_writer.add_scalar(tag='loss', step=iter, value=avg_loss.numpy())
iter = iter + 100

# Install VisualDL
!pip install --upgrade --pre visualdl

Looking in indexes: https://mirror.baidu.com/pypi/simple/
Requirement already up-to-date: visualdl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (2.2.0)
...

# Import the VisualDL library and set the location where the plotting data is saved
from visualdl import LogWriter
log_writer = LogWriter(logdir="./log")

def train(model):
    model.train()

    opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

    EPOCH_NUM = 10
    iter = 0
    for epoch_id in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            # Prepare the data
            images, labels = data
            images = paddle.to_tensor(images)
            labels = paddle.to_tensor(labels)
            # Forward pass, returning both predictions and accuracy
            predicts, avg_acc = model(images, labels)
            # Compute the loss; average it over the batch
            loss = F.cross_entropy(predicts, labels)
            avg_loss = paddle.mean(loss)
            # Every 100 batches: print the current loss and log it to VisualDL
            if batch_id % 100 == 0:
                print("epoch: {}, batch: {}, loss is: {}, acc is {}".format(epoch_id, batch_id, avg_loss.numpy(), avg_acc.numpy()))
                log_writer.add_scalar(tag='acc', step=iter, value=avg_acc.numpy())
                log_writer.add_scalar(tag='loss', step=iter, value=avg_loss.numpy())
                iter = iter + 100
            # Backward pass and parameter update
            avg_loss.backward()
            opt.step()
            opt.clear_grad()
    # Save the model parameters
    paddle.save(model.state_dict(), 'mnist.pdparams')

model = MNIST()
train(model)
epoch: 0, batch: 0, loss is: [2.4428766], acc is [0.11]
epoch: 0, batch: 100, loss is: [0.14379957], acc is [0.98]
epoch: 0, batch: 200, loss is: [0.14659742], acc is [0.97]
epoch: 0, batch: 300, loss is: [0.15617883], acc is [0.95]
epoch: 0, batch: 400, loss is: [0.08011218], acc is [0.97]
epoch: 1, batch: 0, loss is: [0.02419444], acc is [1.]
epoch: 1, batch: 100, loss is: [0.05039296], acc is [0.98]
epoch: 1, batch: 200, loss is: [0.03004572], acc is [1.]
epoch: 1, batch: 300, loss is: [0.02300201], acc is [1.]
epoch: 1, batch: 400, loss is: [0.05973276], acc is [0.97]
epoch: 2, batch: 0, loss is: [0.04129345], acc is [0.98]
epoch: 2, batch: 100, loss is: [0.10553801], acc is [0.97]
epoch: 2, batch: 200, loss is: [0.03386007], acc is [0.99]
epoch: 2, batch: 300, loss is: [0.01647828], acc is [1.]
epoch: 2, batch: 400, loss is: [0.12856933], acc is [0.96]
epoch: 3, batch: 0, loss is: [0.02018252], acc is [0.99]
epoch: 3, batch: 100, loss is: [0.0415822], acc is [0.98]
epoch: 3, batch: 200, loss is: [0.01284742], acc is [0.99]
epoch: 3, batch: 300, loss is: [0.01533189], acc is [1.]
epoch: 3, batch: 400, loss is: [0.04745318], acc is [0.98]
epoch: 4, batch: 0, loss is: [0.04060025], acc is [0.97]
epoch: 4, batch: 100, loss is: [0.09810024], acc is [0.96]
epoch: 4, batch: 200, loss is: [0.01504112], acc is [1.]
epoch: 4, batch: 300, loss is: [0.06550632], acc is [0.98]
epoch: 4, batch: 400, loss is: [0.01058023], acc is [1.]
epoch: 5, batch: 0, loss is: [0.01061755], acc is [1.]
epoch: 5, batch: 100, loss is: [0.00329281], acc is [1.]
epoch: 5, batch: 200, loss is: [0.04983099], acc is [0.99]
epoch: 5, batch: 300, loss is: [0.00971762], acc is [1.]
epoch: 5, batch: 400, loss is: [0.04493643], acc is [0.99]
epoch: 6, batch: 0, loss is: [0.00563607], acc is [1.]
epoch: 6, batch: 100, loss is: [0.01756587], acc is [0.99]
epoch: 6, batch: 200, loss is: [0.00185569], acc is [1.]
epoch: 6, batch: 300, loss is: [0.0228942], acc is [0.99]
epoch: 6, batch: 400, loss is: [0.01174758], acc is [1.]
epoch: 7, batch: 0, loss is: [0.0134629], acc is [0.99]
epoch: 7, batch: 100, loss is: [0.00740626], acc is [1.]
epoch: 7, batch: 200, loss is: [0.00820032], acc is [1.]
epoch: 7, batch: 300, loss is: [0.03157279], acc is [0.98]
epoch: 7, batch: 400, loss is: [0.00324935], acc is [1.]
epoch: 8, batch: 0, loss is: [0.11782757], acc is [0.97]
epoch: 8, batch: 100, loss is: [0.01778471], acc is [0.99]
epoch: 8, batch: 200, loss is: [0.03268703], acc is [0.98]
epoch: 8, batch: 300, loss is: [0.0152339], acc is [0.99]
epoch: 8, batch: 400, loss is: [0.01905194], acc is [0.99]
epoch: 9, batch: 0, loss is: [0.01949414], acc is [0.99]
epoch: 9, batch: 100, loss is: [0.00271552], acc is [1.]
epoch: 9, batch: 200, loss is: [0.00252313], acc is [1.]
epoch: 9, batch: 300, loss is: [0.01054896], acc is [1.]
epoch: 9, batch: 400, loss is: [0.01838213], acc is [1.]

  • Step 3: Launch VisualDL from the command line.

Start VisualDL with the command visualdl --logdir [path to the folder containing the data files]. Once VisualDL is running, the command line prints the URL at which the plots can be viewed in a browser.

$ visualdl --logdir ./log --port 8080
  • Step 4: Open the browser and view the plots, as shown in Figure 6.

The URL to visit is printed after the launch command from Step 3 (for example, http://127.0.0.1:8080/). Entering it into the browser's address bar brings up the page shown below. Besides the plotted data points on the right, a control panel on the left adjusts many details of the plots.


Figure 6: Example VisualDL plots


Exercises 2-4

  • Print the output of every layer of the plain neural network model and inspect the contents.
  • Plot the classification accuracy metric with the PLT library.
  • Use classification accuracy to judge which loss function trains the better model.
  • Plot and compare the model's loss curves on the training and test sets as training proceeds.
  • Adjust the regularization weight, observe how the curves from item 4 change, and analyze the reason.