Chapter 8: Mastering Advanced Deep Learning Topics - Secondary Development Based on PaddlePaddle

Secondary Development Based on PaddlePaddle

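PaddlePaddle offers a rich collection of models. Developers can use the existing PaddlePaddle models directly, or carry out secondary development on top of the open-source models. A team in China used PaddlePaddle's PaddleDetection to win first place in an international AI competition, and a number of companies have used PaddlePaddle's official models to bring multiple business applications into production.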

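When developing cutting-edge models or models for specialized domains with PaddlePaddle, users may, in rare cases, find that a particular operator is missing: either the required computation cannot be composed from existing operators, or the composition of existing operators does not meet the performance requirements. In such cases, PaddlePaddle's custom operator mechanism can be used to write a new operator.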

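This section covers how to build academic operators by composing existing PaddlePaddle operators, how to implement custom operators, and the workflow for contributing code to PaddlePaddle on GitHub.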

1. Implementing the Involution Academic Operator by Composing Existing PaddlePaddle Operators

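Involution is designed as the inverse of convolution: the kernel is shared across the channel dimension, while spatially specific kernels are used in the spatial dimension for more flexible modeling. The involution kernel has size $H \times W \times K \times K \times G$, where $G \ll C$, meaning that all channels share $G$ kernels. Involution is computed as: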

$$Y_{i,j,k}=\sum_{(u,v) \in \Delta_K} H_{i,j,\,u+\lfloor K/2 \rfloor,\,v+\lfloor K/2 \rfloor,\,\lceil kG/C \rceil}\,X_{i+u,\,j+v,\,k}$$

where $H \in \mathbb{R}^{H \times W \times K \times K \times G}$ is the involution kernel and $\Delta_K$ is the set of offsets $(u, v)$ in the $K \times K$ neighborhood centered on $(i, j)$.
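
To make the index bookkeeping concrete, here is a direct, loop-based evaluation of the formula. It is a readability sketch rather than an efficient implementation, and it assumes channel-last nested lists, 0-based indices (so the group index $\lceil kG/C \rceil$ becomes `k // (C // G)`), stride 1, and zero padding outside the feature map. The function name `involution_naive` is hypothetical.

```python
def involution_naive(x, kernel, K, G):
    """Evaluate Y[i][j][k] = sum over (u, v) of H[i][j][u+K//2][v+K//2][g(k)] * X[i+u][j+v][k].

    x:      nested list of shape height x width x C (channel-last)
    kernel: nested list of shape height x width x K x K x G
    Assumptions: 0-based indices (channel k belongs to group k // (C // G)),
    stride 1, and positions outside the feature map count as zero.
    """
    height, width, channels = len(x), len(x[0]), len(x[0][0])
    group_channels = channels // G
    r = K // 2
    y = [[[0.0] * channels for _ in range(width)] for _ in range(height)]
    for i in range(height):
        for j in range(width):
            for k in range(channels):
                g = k // group_channels  # every channel in a group shares one kernel
                acc = 0.0
                for u in range(-r, r + 1):
                    for v in range(-r, r + 1):
                        ii, jj = i + u, j + v
                        if 0 <= ii < height and 0 <= jj < width:
                            acc += kernel[i][j][u + r][v + r][g] * x[ii][jj][k]
                y[i][j][k] = acc
    return y
```

With a kernel that is 1 at the center offset and 0 elsewhere, every output equals the input, which is a quick sanity check. The Paddle version later in this section replaces these loops with `unfold` and broadcasting.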

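Unlike convolution, involution does not use a fixed weight matrix as its learnable parameters. Instead, it generates the involution kernel from the input feature map, which ensures that the kernel size and the input feature size are automatically aligned in the spatial dimension. With fixed spatial weights, for example, weights trained on ImageNet with fixed $224 \times 224$ input images could not be transferred to downstream tasks with larger input images (such as detection and segmentation). The involution kernel is computed as: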

$$H_{i,j}=\phi(X_{\Psi_{i,j}})$$

where $\Psi_{i,j}$ is a set of indices in the neighborhood of coordinate $(i,j)$, so $X_{\Psi_{i,j}}$ denotes a patch of the feature map that contains $X_{i,j}$.

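There are many possible designs for the kernel generation function $\phi$, and they merit further exploration. Starting from a simple and effective design, this section experiments with a SENet-like bottleneck structure: $\Psi_{i,j}$ is taken to be the singleton set $\{(i,j)\}$, i.e., $X_{\Psi_{i,j}}$ is the single pixel at coordinate $(i,j)$ of the feature map. This gives one instantiation of involution kernel generation: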

$$H_{i,j}=\phi(X_{i,j})=W_{1}\,\sigma(W_{0}X_{i,j})$$

where $W_{0}\in \mathbb{R}^{\frac{C}{r}\times C}$ and $W_{1}\in \mathbb{R}^{(K \times K \times G)\times \frac{C}{r}}$ are linear transformation matrices, $r$ is the channel reduction ratio, and $\sigma$ denotes the BN and ReLU in between.
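
As a worked example of these shapes, the sketch below counts the parameters of $\phi$ when $W_0$ and $W_1$ are realized as 1×1 convolutions, following the `Involution` class later in this section (16 channels per group, $r = 4$, no bias on $W_0$, a bias on $W_1$). The helper name is hypothetical, and the BN count uses 4 values per channel (scale, shift, running mean, running variance), which is how the `paddle.summary` output in this section appears to count them.

```python
def kernel_generator_params(channels, kernel_size, group_channels=16, reduction_ratio=4):
    """Parameter counts of phi(X_ij) = W1 * sigma(W0 * X_ij) built from two 1x1 convolutions."""
    groups = channels // group_channels        # G
    reduced = channels // reduction_ratio      # C / r
    kernel_values = kernel_size ** 2 * groups  # K * K * G outputs per pixel
    return {
        "groups": groups,
        "W0": reduced * channels,                       # (C/r) x C, no bias
        "BN": 4 * reduced,                              # scale, shift, running mean, running var
        "W1": kernel_values * reduced + kernel_values,  # (K*K*G) x (C/r), plus bias
    }


# Stem involution (C = 32, K = 3) and a stage-1 block involution (C = 64, K = 7)
print(kernel_generator_params(32, 3))  # -> {'groups': 2, 'W0': 256, 'BN': 32, 'W1': 162}
print(kernel_generator_params(64, 7))  # -> {'groups': 4, 'W0': 1024, 'BN': 64, 'W1': 3332}
```

These counts agree with the corresponding Conv2D and BatchNorm2D rows of the model summary shown later, which is a convenient check that the code matches the formula.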

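Note that different kernel generation functions yield different instantiations of involution. More elaborate designs could be explored to further tap the potential of involution, and with a particular instantiation it can even be specialized into a form of self-attention. With the simple kernel instantiation above, the complete involution is illustrated in Figure 1.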

Figure 1: Schematic diagram of involution
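  • For the feature vector at one coordinate of the input feature map, $\phi$ (FC-BN-ReLU-FC) followed by a reshape (channel-to-space) expands it into the shape of a kernel, giving the involution kernel for that coordinate;
  • this kernel is then combined by multiply-add with the feature vectors in the neighborhood of that coordinate of the input feature map, producing the output feature map.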
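
The two steps above can be traced for a single pixel in plain Python. This sketch makes two simplifying assumptions: the BN inside $\sigma$ is omitted (so $\sigma$ is just ReLU) and $W_1$ has no bias. The channel-to-space step maps index `g*K*K + u*K + v` to `kernels[g][u][v]`, which is assumed to match the row-major layout produced by the `reshape` in the `Involution` code of this section. The function name `generate_kernel` is hypothetical.

```python
def generate_kernel(x_ij, w0, w1, K, G):
    """phi for one pixel: ReLU(W0 @ x_ij), then W1 @ hidden, then a channel-to-space reshape.

    x_ij: C values; w0: (C/r) rows of length C; w1: (K*K*G) rows of length C/r.
    Simplification: the BN inside sigma is omitted and W1 has no bias.
    """
    hidden = [max(0.0, sum(w * x for w, x in zip(row, x_ij))) for row in w0]
    flat = [sum(w * h for w, h in zip(row, hidden)) for row in w1]
    if len(flat) != K * K * G:
        raise ValueError("W1 must have K*K*G rows")
    # Channel-to-space: flat[g*K*K + u*K + v] becomes kernels[g][u][v]
    return [[[flat[g * K * K + u * K + v] for v in range(K)] for u in range(K)]
            for g in range(G)]


# C = 2, r = 2 (so C/r = 1), K = 3, G = 1
w0 = [[1.0, -1.0]]
w1 = [[float(n)] for n in range(9)]
kernels = generate_kernel([3.0, 1.0], w0, w1, K=3, G=1)
print(kernels[0])  # -> [[0.0, 2.0, 4.0], [6.0, 8.0, 10.0], [12.0, 14.0, 16.0]]
```

Because the kernel is computed from the pixel itself, a different input vector yields a different kernel at each position, which is what makes involution spatially specific.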

The detailed operation flow and the tensor shape changes are shown in Figure 2, where $\Omega_{i,j}$ is the $K \times K$ neighborhood around coordinate $(i,j)$.

Figure 2: Operation flow and tensor shape changes

The code below implements the Involution academic operator by composing existing PaddlePaddle operators.

1) Load third-party libraries

!pip install wget


import os

import paddle
import paddle.nn as nn
import paddle.vision.transforms as T
import wget
from paddle.vision.models import resnet
from paddle.vision.datasets import Flowers

2) Implementing the Involution operator

class Involution(nn.Layer):
    def __init__(self,
                 channels,
                 kernel_size,
                 stride):
        super(Involution, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.channels = channels
        reduction_ratio = 4
        self.group_channels = 16  # channels per group, so G = C // 16
        self.groups = self.channels // self.group_channels
        # First half of phi: W0 as a 1x1 conv (C -> C/r), then BN + ReLU (sigma)
        self.conv1 = nn.Sequential(
            ('conv', nn.Conv2D(
                in_channels=channels,
                out_channels=channels // reduction_ratio,
                kernel_size=1,
                bias_attr=False)),
            ('bn', nn.BatchNorm2D(channels // reduction_ratio)),
            ('activate', nn.ReLU())
        )
        # Second half of phi: W1 as a 1x1 conv (C/r -> K*K*G), one kernel per pixel and group
        self.conv2 = nn.Sequential(
            ('conv', nn.Conv2D(
                in_channels=channels // reduction_ratio,
                out_channels=kernel_size**2 * self.groups,
                kernel_size=1,
                stride=1))
        )
        if stride > 1:
            # Downsample before generating kernels so they match the strided output
            self.avgpool = nn.AvgPool2D(stride, stride)

    def forward(self, x):
        # Generate the kernels: (b, K*K*G, h, w)
        weight = self.conv2(self.conv1(
            x if self.stride == 1 else self.avgpool(x)))
        b, c, h, w = weight.shape
        # Channel-to-space: (b, G, 1, K*K, h, w); the size-1 axis broadcasts over a group's channels
        weight = weight.reshape((
            b, self.groups, self.kernel_size**2, h, w)).unsqueeze(2)

        # Gather the K x K neighborhood of every output position: (b, C*K*K, h*w)
        out = nn.functional.unfold(
            x, self.kernel_size, strides=self.stride, paddings=(self.kernel_size-1)//2, dilations=1)
        out = out.reshape(
            (b, self.groups, self.group_channels, self.kernel_size**2, h, w))
        # Multiply-add over the K*K positions, then merge the groups back into C channels
        out = (weight * out).sum(axis=3).reshape((b, self.channels, h, w))
        return out
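
One detail of `forward` is worth checking: the unfolded patches are reshaped with the `h, w` of the kernel branch, so both branches must produce the same spatial size. The helpers below are hypothetical and use the standard floor-rounded output-size formula, which we assume matches Paddle's `unfold` and `AvgPool2D` in this configuration (no pooling padding, `ceil_mode=False`). Under that assumption, a stride-2 Involution needs an even input height and width:

```python
def unfold_out_size(size, kernel_size, stride):
    """Windows along one axis for unfold with padding (K - 1) // 2 and dilation 1."""
    padding = (kernel_size - 1) // 2
    return (size + 2 * padding - kernel_size) // stride + 1


def kernel_branch_size(size, stride):
    """Kernel-branch size: unchanged for stride 1, else AvgPool2D(stride, stride) with floor rounding."""
    return size if stride == 1 else (size - stride) // stride + 1


def sizes_match(size, kernel_size, stride):
    return unfold_out_size(size, kernel_size, stride) == kernel_branch_size(size, stride)


# Inputs seen by RedNet's stride-2 blocks: 56, 28, 14
print([sizes_match(s, 7, 2) for s in (56, 28, 14)])  # -> [True, True, True]
print(sizes_match(7, 7, 2))  # -> False: unfold gives 4 windows, the pooled branch gives 3
```

RedNet's stride-2 blocks only see even sizes, so the model in this section is unaffected; for arbitrary input sizes, one option would be to pad the input to an even size first.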

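As shown above, most newly proposed computations can be expressed with basic mathematical operators, and the process is essentially no different from ordinary network construction. Next, let's see how to use the newly implemented Involution operator to build a network that classifies the Flowers dataset.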

3) Example: Flowers dataset classification with RedNet26

# Dataset preprocessing
# Backend: cv2
transforms = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.Normalize(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True,
        data_format='HWC'),
    T.ToTensor(),
])
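
For reference, this `Normalize` configuration is assumed to first reverse the channel order (OpenCV loads images as BGR, and `to_rgb=True` converts them to RGB) and then compute `(x - mean) / std` per channel on the 0-255 scale; the mean and std are the common ImageNet statistics multiplied by 255. A plain-Python sketch of that arithmetic for one pixel (the function name is hypothetical):

```python
MEAN = (123.675, 116.28, 103.53)  # RGB order: 0.485, 0.456, 0.406 times 255
STD = (58.395, 57.12, 57.375)     # RGB order: 0.229, 0.224, 0.225 times 255


def normalize_bgr_pixel(bgr, mean=MEAN, std=STD):
    """Convert one BGR pixel to RGB, then standardize each channel."""
    rgb = tuple(reversed(bgr))
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


# A BGR pixel equal to the (reversed) mean maps to zeros
print(normalize_bgr_pixel((103.53, 116.28, 123.675)))  # -> (0.0, 0.0, 0.0)
```
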
# Build the BottleneckBlock that RedNet needs, using Involution
class BottleneckBlock(resnet.BottleneckBlock):
    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BottleneckBlock, self).__init__(inplanes, planes, stride,
                                              downsample, groups, base_width, dilation, norm_layer)
        width = int(planes * (base_width / 64.)) * groups
        # Replace the 3x3 convolution of the standard ResNet bottleneck with a 7x7 involution
        self.conv2 = Involution(width, 7, stride)
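
The `width` formula decides how many channels, and therefore how many groups, each involution sees. With the defaults `base_width=64` and `groups=1` it reduces to `planes`, and since `Involution` uses 16 channels per group, its kernel branch outputs $7^2 G$ channels. A small check of that arithmetic for the four stages (hypothetical helper, plain Python):

```python
def involution_channels(planes, base_width=64, groups=1, group_channels=16, kernel_size=7):
    """Return (width, G, kernel-branch output channels) for one BottleneckBlock."""
    width = int(planes * (base_width / 64.)) * groups  # same formula as BottleneckBlock
    inv_groups = width // group_channels               # G used inside Involution
    return width, inv_groups, kernel_size ** 2 * inv_groups


for planes in (64, 128, 256, 512):
    print(planes, involution_channels(planes))
# 64 (64, 4, 196)
# 128 (128, 8, 392)
# 256 (256, 16, 784)
# 512 (512, 32, 1568)
```

The last column (196, 392, 784, 1,568) matches the output channels of the kernel-generating Conv2D layers in the model summary.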

# Build the RedNet network
class RedNet(resnet.ResNet):
    def __init__(self, block, depth, class_dim=1000, with_pool=True):
        # The parent is built with depth=50 only to create the shared pieces
        # (maxpool, avgpool, fc); its stem and stages are replaced below
        super(RedNet, self).__init__(block=block, depth=50,
                                     num_classes=class_dim, with_pool=with_pool)
        layer_cfg = {
            26: [1, 2, 4, 1],
            38: [2, 3, 5, 2],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3]
        }
        layers = layer_cfg[depth]

        # Drop the standard ResNet stem and reset the channel counter used by _make_layer
        self.conv1 = None
        self.bn1 = None
        self.relu = None
        self.inplanes = 64
        self.class_dim = class_dim

        self.stem = nn.Sequential(
            nn.Sequential(
                ('conv', nn.Conv2D(
                    in_channels=3,
                    out_channels=self.inplanes // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias_attr=False)),
                ('bn', nn.BatchNorm2D(self.inplanes // 2)),
                ('activate', nn.ReLU())
            ),
            Involution(self.inplanes // 2, 3, 1),
            nn.BatchNorm2D(self.inplanes // 2),
            nn.ReLU(),
            nn.Sequential(
                ('conv', nn.Conv2D(
                    in_channels=self.inplanes // 2,
                    out_channels=self.inplanes,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias_attr=False)),
                ('bn', nn.BatchNorm2D(self.inplanes)),
                ('activate', nn.ReLU())
            )
        )

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

    def forward(self, x):
        x = self.stem(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if self.with_pool:
            x = self.avgpool(x)
        if self.class_dim > 0:
            x = paddle.flatten(x, 1)
            x = self.fc(x)
        return x
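
The keys of `layer_cfg` follow the usual ResNet naming convention, which we assume RedNet shares: each bottleneck block contributes three layers on its main path (1×1, 7×7 involution, 1×1), plus one for the stem and one for the final `fc`. A quick arithmetic check (plain Python; `counted_depth` is a hypothetical helper):

```python
LAYER_CFG = {26: [1, 2, 4, 1], 38: [2, 3, 5, 2], 50: [3, 4, 6, 3],
             101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}


def counted_depth(blocks_per_stage):
    # 3 layers per bottleneck block, +1 for the stem, +1 for the classifier
    return 3 * sum(blocks_per_stage) + 2


for depth, blocks in LAYER_CFG.items():
    print(depth, counted_depth(blocks))  # each line prints the key twice, e.g. "26 26"
```

For RedNet-26 this gives 8 bottleneck blocks, which matches BottleneckBlock-17 through BottleneckBlock-24 in the model summary.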

# Download the pretrained weights (if not already cached) and load them into the model
def load_model(model, url):
    file_name = url.split(r'filename%3D')[-1]
    model_path = os.path.join('pretrained_models', file_name)
    if not os.path.isfile(model_path):
        if not os.path.exists('pretrained_models'):
            os.mkdir('pretrained_models')
        wget.download(url, out=model_path)
    params = paddle.load(model_path)
    model.set_dict(params)
    return model
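
`load_model` derives the local file name by splitting the URL on the percent-encoded `filename=` marker. A standard-library alternative (a sketch, not part of the original code) decodes the query string first and reads the file name from the decoded `Content-Disposition` value:

```python
from urllib.parse import parse_qs, urlparse

URL = ('https://bj.bcebos.com/v1/ai-studio-online/'
       '14091d6c21774c5fb48d74723db7eaf22e1c5ff621154a588534cb92918c04e2'
       '?responseContentDisposition=attachment%3B%20filename%3Drednet26.pdparams')


def file_name_by_split(url):
    # The approach used in load_model
    return url.split('filename%3D')[-1]


def file_name_by_query(url):
    # After decoding, the value reads 'attachment; filename=rednet26.pdparams'
    disposition = parse_qs(urlparse(url).query)['responseContentDisposition'][0]
    return disposition.split('filename=')[-1]


print(file_name_by_split(URL), file_name_by_query(URL))  # -> rednet26.pdparams rednet26.pdparams
```
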

urls = {
    'rednet_26': r'https://bj.bcebos.com/v1/ai-studio-online/14091d6c21774c5fb48d74723db7eaf22e1c5ff621154a588534cb92918c04e2?responseContentDisposition=attachment%3B%20filename%3Drednet26.pdparams'
}

pretrained = True
model = RedNet(BottleneckBlock, 26)
if pretrained:
    model = load_model(model, urls['rednet_26'])


# Model summary
paddle.summary(model, input_size=(1, 3, 224, 224))
-------------------------------------------------------------------------------
   Layer (type)         Input Shape          Output Shape         Param #    
===============================================================================
     Conv2D-86       [[1, 3, 224, 224]]   [1, 32, 112, 112]         864      
  BatchNorm2D-70    [[1, 32, 112, 112]]   [1, 32, 112, 112]         128      
      ReLU-34       [[1, 32, 112, 112]]   [1, 32, 112, 112]          0       
     Conv2D-87      [[1, 32, 112, 112]]    [1, 8, 112, 112]         256      
  BatchNorm2D-71     [[1, 8, 112, 112]]    [1, 8, 112, 112]         32       
      ReLU-35        [[1, 8, 112, 112]]    [1, 8, 112, 112]          0       
     Conv2D-88       [[1, 8, 112, 112]]   [1, 18, 112, 112]         162      
   Involution-17    [[1, 32, 112, 112]]   [1, 32, 112, 112]          0       
  BatchNorm2D-72    [[1, 32, 112, 112]]   [1, 32, 112, 112]         128      
      ReLU-36       [[1, 32, 112, 112]]   [1, 32, 112, 112]          0       
     Conv2D-89      [[1, 32, 112, 112]]   [1, 64, 112, 112]       18,432     
  BatchNorm2D-73    [[1, 64, 112, 112]]   [1, 64, 112, 112]         256      
      ReLU-37       [[1, 64, 112, 112]]   [1, 64, 112, 112]          0       
    MaxPool2D-1     [[1, 64, 112, 112]]    [1, 64, 56, 56]           0       
     Conv2D-91       [[1, 64, 56, 56]]     [1, 64, 56, 56]         4,096     
  BatchNorm2D-75     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
      ReLU-38        [[1, 256, 56, 56]]    [1, 256, 56, 56]          0       
     Conv2D-94       [[1, 64, 56, 56]]     [1, 16, 56, 56]         1,024     
  BatchNorm2D-78     [[1, 16, 56, 56]]     [1, 16, 56, 56]          64       
      ReLU-39        [[1, 16, 56, 56]]     [1, 16, 56, 56]           0       
     Conv2D-95       [[1, 16, 56, 56]]     [1, 196, 56, 56]        3,332     
   Involution-18     [[1, 64, 56, 56]]     [1, 64, 56, 56]           0       
  BatchNorm2D-76     [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
     Conv2D-93       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
  BatchNorm2D-77     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
     Conv2D-90       [[1, 64, 56, 56]]     [1, 256, 56, 56]       16,384     
  BatchNorm2D-74     [[1, 256, 56, 56]]    [1, 256, 56, 56]        1,024     
BottleneckBlock-17   [[1, 64, 56, 56]]     [1, 256, 56, 56]          0       
     Conv2D-97       [[1, 256, 56, 56]]    [1, 128, 56, 56]       32,768     
  BatchNorm2D-80     [[1, 128, 56, 56]]    [1, 128, 56, 56]         512      
      ReLU-40        [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
    AvgPool2D-4      [[1, 128, 56, 56]]    [1, 128, 28, 28]          0       
    Conv2D-100       [[1, 128, 28, 28]]    [1, 32, 28, 28]         4,096     
  BatchNorm2D-83     [[1, 32, 28, 28]]     [1, 32, 28, 28]          128      
      ReLU-41        [[1, 32, 28, 28]]     [1, 32, 28, 28]           0       
    Conv2D-101       [[1, 32, 28, 28]]     [1, 392, 28, 28]       12,936     
   Involution-19     [[1, 128, 56, 56]]    [1, 128, 28, 28]          0       
  BatchNorm2D-81     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
     Conv2D-99       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-82     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
     Conv2D-96       [[1, 256, 56, 56]]    [1, 512, 28, 28]       131,072    
  BatchNorm2D-79     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
BottleneckBlock-18   [[1, 256, 56, 56]]    [1, 512, 28, 28]          0       
    Conv2D-102       [[1, 512, 28, 28]]    [1, 128, 28, 28]       65,536     
  BatchNorm2D-84     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
      ReLU-42        [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
    Conv2D-105       [[1, 128, 28, 28]]    [1, 32, 28, 28]         4,096     
  BatchNorm2D-87     [[1, 32, 28, 28]]     [1, 32, 28, 28]          128      
      ReLU-43        [[1, 32, 28, 28]]     [1, 32, 28, 28]           0       
    Conv2D-106       [[1, 32, 28, 28]]     [1, 392, 28, 28]       12,936     
   Involution-20     [[1, 128, 28, 28]]    [1, 128, 28, 28]          0       
  BatchNorm2D-85     [[1, 128, 28, 28]]    [1, 128, 28, 28]         512      
    Conv2D-104       [[1, 128, 28, 28]]    [1, 512, 28, 28]       65,536     
  BatchNorm2D-86     [[1, 512, 28, 28]]    [1, 512, 28, 28]        2,048     
BottleneckBlock-19   [[1, 512, 28, 28]]    [1, 512, 28, 28]          0       
    Conv2D-108       [[1, 512, 28, 28]]    [1, 256, 28, 28]       131,072    
  BatchNorm2D-89     [[1, 256, 28, 28]]    [1, 256, 28, 28]        1,024     
      ReLU-44       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    AvgPool2D-5      [[1, 256, 28, 28]]    [1, 256, 14, 14]          0       
    Conv2D-111       [[1, 256, 14, 14]]    [1, 64, 14, 14]        16,384     
  BatchNorm2D-92     [[1, 64, 14, 14]]     [1, 64, 14, 14]          256      
      ReLU-45        [[1, 64, 14, 14]]     [1, 64, 14, 14]           0       
    Conv2D-112       [[1, 64, 14, 14]]     [1, 784, 14, 14]       50,960     
   Involution-21     [[1, 256, 28, 28]]    [1, 256, 14, 14]          0       
  BatchNorm2D-90     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
    Conv2D-110       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-91    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
    Conv2D-107       [[1, 512, 28, 28]]   [1, 1024, 14, 14]       524,288    
  BatchNorm2D-88    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-20   [[1, 512, 28, 28]]   [1, 1024, 14, 14]          0       
    Conv2D-113      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-93     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-46       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-116       [[1, 256, 14, 14]]    [1, 64, 14, 14]        16,384     
  BatchNorm2D-96     [[1, 64, 14, 14]]     [1, 64, 14, 14]          256      
      ReLU-47        [[1, 64, 14, 14]]     [1, 64, 14, 14]           0       
    Conv2D-117       [[1, 64, 14, 14]]     [1, 784, 14, 14]       50,960     
   Involution-22     [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
  BatchNorm2D-94     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
    Conv2D-115       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-95    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-21  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-118      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-97     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-48       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-121       [[1, 256, 14, 14]]    [1, 64, 14, 14]        16,384     
  BatchNorm2D-100    [[1, 64, 14, 14]]     [1, 64, 14, 14]          256      
      ReLU-49        [[1, 64, 14, 14]]     [1, 64, 14, 14]           0       
    Conv2D-122       [[1, 64, 14, 14]]     [1, 784, 14, 14]       50,960     
   Involution-23     [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
  BatchNorm2D-98     [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
    Conv2D-120       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-99    [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-22  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-123      [[1, 1024, 14, 14]]    [1, 256, 14, 14]       262,144    
  BatchNorm2D-101    [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
      ReLU-50       [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-126       [[1, 256, 14, 14]]    [1, 64, 14, 14]        16,384     
  BatchNorm2D-104    [[1, 64, 14, 14]]     [1, 64, 14, 14]          256      
      ReLU-51        [[1, 64, 14, 14]]     [1, 64, 14, 14]           0       
    Conv2D-127       [[1, 64, 14, 14]]     [1, 784, 14, 14]       50,960     
   Involution-24     [[1, 256, 14, 14]]    [1, 256, 14, 14]          0       
  BatchNorm2D-102    [[1, 256, 14, 14]]    [1, 256, 14, 14]        1,024     
    Conv2D-125       [[1, 256, 14, 14]]   [1, 1024, 14, 14]       262,144    
  BatchNorm2D-103   [[1, 1024, 14, 14]]   [1, 1024, 14, 14]        4,096     
BottleneckBlock-23  [[1, 1024, 14, 14]]   [1, 1024, 14, 14]          0       
    Conv2D-129      [[1, 1024, 14, 14]]    [1, 512, 14, 14]       524,288    
  BatchNorm2D-106    [[1, 512, 14, 14]]    [1, 512, 14, 14]        2,048     
      ReLU-52        [[1, 2048, 7, 7]]     [1, 2048, 7, 7]           0       
    AvgPool2D-6      [[1, 512, 14, 14]]     [1, 512, 7, 7]           0       
    Conv2D-132        [[1, 512, 7, 7]]      [1, 128, 7, 7]        65,536     
  BatchNorm2D-109     [[1, 128, 7, 7]]      [1, 128, 7, 7]          512      
      ReLU-53         [[1, 128, 7, 7]]      [1, 128, 7, 7]           0       
    Conv2D-133        [[1, 128, 7, 7]]     [1, 1568, 7, 7]        202,272    
   Involution-25     [[1, 512, 14, 14]]     [1, 512, 7, 7]           0       
  BatchNorm2D-107     [[1, 512, 7, 7]]      [1, 512, 7, 7]         2,048     
    Conv2D-131        [[1, 512, 7, 7]]     [1, 2048, 7, 7]       1,048,576   
  BatchNorm2D-108    [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
    Conv2D-128      [[1, 1024, 14, 14]]    [1, 2048, 7, 7]       2,097,152   
  BatchNorm2D-105    [[1, 2048, 7, 7]]     [1, 2048, 7, 7]         8,192     
BottleneckBlock-24  [[1, 1024, 14, 14]]    [1, 2048, 7, 7]           0       
AdaptiveAvgPool2D-1  [[1, 2048, 7, 7]]     [1, 2048, 1, 1]           0       
     Linear-1           [[1, 2048]]           [1, 1000]          2,049,000   
===============================================================================
Total params: 9,264,318
Trainable params: 9,202,014
Non-trainable params: 62,304
-------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 188.62
Params size (MB): 35.34
Estimated Total Size (MB): 224.53
-------------------------------------------------------------------------------

{'total_params': 9264318, 'trainable_params': 9202014}

# Set up the optimizer
opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# Set up the loss function
loss = nn.CrossEntropyLoss()
# Set up the accuracy metric (top-1 and top-5)
metric = paddle.metric.Accuracy(topk=(1, 5))
# Wrap the network with PaddlePaddle's high-level API
model = paddle.Model(model)
# Prepare the model for training
model.prepare(optimizer=opt, loss=loss, metrics=metric)
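
`paddle.metric.Accuracy(topk=(1, 5))` reports top-1 and top-5 accuracy: a sample counts as correct if its true label is among the k highest-scoring classes. A plain-Python sketch of that definition (not Paddle's implementation; ties are broken by class index here, and the function name is hypothetical):

```python
def topk_accuracy(scores, labels, k):
    """Fraction of samples whose label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices sorted by descending score; sorted() is stable,
        # so tied scores keep ascending class order
        ranked = sorted(range(len(row)), key=lambda c: -row[c])
        hits += label in ranked[:k]
    return hits / len(labels)


scores = [[0.1, 0.7, 0.2],   # predicted class 1
          [0.5, 0.3, 0.2]]   # predicted class 0
labels = [1, 2]
print(topk_accuracy(scores, labels, 1), topk_accuracy(scores, labels, 3))  # -> 0.5 1.0
```
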

# Load the Flowers dataset
train_dataset = Flowers(mode='train', transform=transforms, backend='cv2')
val_dataset = Flowers(mode='test', transform=transforms, backend='cv2')
Cache file /home/aistudio/.cache/paddle/dataset/flowers/102flowers.tgz not found, downloading http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz
Begin to download



item   395/84195 [..............................] - ETA: 1:19 - 952us/item

item   396/84195 [..............................] - ETA: 1:19 - 952us/item

item   397/84195 [..............................] - ETA: 1:19 - 951us/item

item   398/84195 [..............................] - ETA: 1:19 - 952us/item

item   399/84195 [..............................] - ETA: 1:19 - 951us/item

item   400/84195 [..............................] - ETA: 1:19 - 951us/item

item   401/84195 [..............................] - ETA: 1:19 - 951us/item

item   402/84195 [..............................] - ETA: 1:19 - 951us/item

item   403/84195 [..............................] - ETA: 1:19 - 950us/item

item   404/84195 [..............................] - ETA: 1:19 - 950us/item

item   405/84195 [..............................] - ETA: 1:19 - 950us/item

item   406/84195 [..............................] - ETA: 1:19 - 949us/item

item   407/84195 [..............................] - ETA: 1:19 - 950us/item

item   408/84195 [..............................] - ETA: 1:19 - 953us/item

item   409/84195 [..............................] - ETA: 1:19 - 952us/item

item   410/84195 [..............................] - ETA: 1:19 - 953us/item

item   411/84195 [..............................] - ETA: 1:19 - 952us/item

item   412/84195 [..............................] - ETA: 1:19 - 952us/item

item   413/84195 [..............................] - ETA: 1:19 - 952us/item

item   414/84195 [..............................] - ETA: 1:19 - 952us/item

item   415/84195 [..............................] - ETA: 1:19 - 952us/item

item   416/84195 [..............................] - ETA: 1:19 - 952us/item

item   417/84195 [..............................] - ETA: 1:19 - 952us/item

item   418/84195 [..............................] - ETA: 1:19 - 952us/item

item   419/84195 [..............................] - ETA: 1:19 - 952us/item

item   420/84195 [..............................] - ETA: 1:19 - 953us/item

item   421/84195 [..............................] - ETA: 1:19 - 952us/item

item   422/84195 [..............................] - ETA: 1:19 - 951us/item

item   423/84195 [..............................] - ETA: 1:19 - 951us/item

item   424/84195 [..............................] - ETA: 1:19 - 951us/item

item   425/84195 [..............................] - ETA: 1:19 - 951us/item

item   426/84195 [..............................] - ETA: 1:19 - 953us/item

item   427/84195 [..............................] - ETA: 1:19 - 953us/item

item   428/84195 [..............................] - ETA: 1:19 - 954us/item

item   429/84195 [..............................] - ETA: 1:19 - 953us/item

item   430/84195 [..............................] - ETA: 1:19 - 953us/item

item   431/84195 [..............................] - ETA: 1:19 - 953us/item

item   432/84195 [..............................] - ETA: 1:19 - 953us/item

item   433/84195 [..............................] - ETA: 1:19 - 953us/item

item   434/84195 [..............................] - ETA: 1:19 - 953us/item

item   435/84195 [..............................] - ETA: 1:19 - 953us/item

item   436/84195 [..............................] - ETA: 1:19 - 953us/item

item   437/84195 [..............................] - ETA: 1:19 - 952us/item

item   438/84195 [..............................] - ETA: 1:19 - 952us/item

item   439/84195 [..............................] - ETA: 1:19 - 952us/item

item   440/84195 [..............................] - ETA: 1:19 - 953us/item

item   441/84195 [..............................] - ETA: 1:19 - 953us/item

item   442/84195 [..............................] - ETA: 1:19 - 953us/item

item   443/84195 [..............................] - ETA: 1:19 - 953us/item

item   444/84195 [..............................] - ETA: 1:19 - 953us/item

item   445/84195 [..............................] - ETA: 1:19 - 953us/item

item   446/84195 [..............................] - ETA: 1:19 - 952us/item

item   447/84195 [..............................] - ETA: 1:19 - 952us/item

item   448/84195 [..............................] - ETA: 1:19 - 953us/item

item   449/84195 [..............................] - ETA: 1:19 - 953us/item

item   450/84195 [..............................] - ETA: 1:19 - 953us/item

item   451/84195 [..............................] - ETA: 1:19 - 953us/item

item   452/84195 [..............................] - ETA: 1:19 - 953us/item
item   819/84195 [..............................] - ETA: 1:19 - 956us/
item  2749/84195 [..............................] - ETA: 1:13 - 899us/it
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item  3953/84195 [>.............................] - ETA: 1:06 - 834us/
item  6849/84195 [=>............................] - ETA: 57s - 748us/it
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item  8111/84195 [=>............................] - ETA: 55s - 733us/it
item 10976/84195 [==>...........................] - ETA: 51s - 710us/it
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 12229/84195 [===>..........................] - ETA: 50s - 702us/it
item 15119/84195 [====>.........................] - ETA: 47s - 689us/ite
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 16687/84195 [====>.........................] - ETA: 46s - 684us/it
item 19235/84195 [=====>........................] - ETA: 44s - 681us/it
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 20764/84195 [======>.......................] - ETA: 43s - 679us/i
item 23337/84195 [=======>......................] - ETA: 41s - 676us/ite
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 24716/84195 [=======>......................] - ETA: 40s - 675us/i
item 27465/84195 [========>.....................] - ETA: 38s - 676us/ite
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 30124/84195 [=========>....................] - ETA: 36s - 674us/i
item 31582/84195 [==========>...................] - ETA: 35s - 674us/it
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 32707/84195 [==========>...................] - ETA: 34s - 673us/it
item 35488/84195 [===========>..................] - ETA: 32s - 673us/it
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 37488/84195 [============>.................] - ETA: 31s - 671us/i
item 39674/84195 [=============>................] - ETA: 29s - 670us/ite
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 41012/84195 [=============>................] - ETA: 28s - 670us/i
item 43792/84195 [==============>...............] - ETA: 27s - 669us/it
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 44971/84195 [===============>..............] - ETA: 26s - 670us/it
item 47681/84195 [===============>..............] - ETA: 24s - 670us/ite
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 49553/84195 [================>.............] - ETA: 23s - 671us/it
item 51516/84195 [=================>............] - ETA: 21s - 672us/it
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 52717/84195 [=================>............] - ETA: 21s - 673us/i
item 55441/84195 [==================>...........] - ETA: 19s - 674us/it
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 56512/84195 [===================>..........] - ETA: 18s - 673us/it
item 61259/84195 [====================>.........] - ETA: 15s - 672us/item ETA: 15s
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 62897/84195 [=====================>........] - ETA: 14s - 672us/it
item 65179/84195 [======================>.......] - ETA: 12s - 671us/ite
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 69318/84195 [=======================>......] - ETA: 9s - 671us/
item 71799/84195 [========================>.....] - ETA: 8s - 671us/it
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 73978/84195 [=========================>....] - ETA: 6s - 671us/
item 78213/84195 [==========================>...] - ETA: 4s - 670us/it
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

item 82696/84195 [============================>.] - ETA: 1s - 669us/it
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

# Finetune the model
model.fit(
    train_data=train_dataset,
    eval_data=val_dataset,
    batch_size=16,
    epochs=2,
    eval_freq=1,
    log_freq=1,
    save_dir='save_models',
    save_freq=1,
    verbose=1,
    drop_last=False,
    shuffle=True,
    num_workers=0)
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/2

step   1/385 [..............................] - loss: 7.5023 - acc_top1: 0.0000e+00 - acc_top5: 0.0000e+00 - ETA: 4:11 - 655ms/step

step 385/385 [==============================] - loss: 1.9796 - acc_top1: 0.6012 - acc_top5: 0.8195 - 152ms/step        
save checkpoint at /home/aistudio/save_models/0
Eval begin...
step 64/64 [==============================] - loss: 1.7739 - acc_top1: 0.7235 - acc_top5: 0.9324 - 115ms/step         
Eval samples: 1020
Epoch 2/2
step 385/385 [==============================] - loss: 2.0446 - acc_top1: 0.8278 - acc_top5: 0.9689 - 150ms/step        
save checkpoint at /home/aistudio/save_models/1
Eval begin...
step 64/64 [==============================] - loss: 2.8384 - acc_top1: 0.7441 - acc_top5: 0.9412 - 115ms/step         
Eval samples: 1020
save checkpoint at /home/aistudio/save_models/final

II. Custom Operators

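A custom operator is a broad concept: the user defines the forward and backward logic of a computation and wraps it for use in building models, where it plays the same role as an operator built into the PaddlePaddle framework. Unlike the framework's built-in operators, however, a custom operator focuses on the computation itself. You only write the necessary compute functions, without dealing with the framework's internal concepts or recompiling PaddlePaddle, so it can be used in a plug-and-play fashion.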

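PaddlePaddle currently supports writing custom operators in both Python and C++. They are typically used in the following scenarios:

  • The required computation cannot be composed from existing PaddlePaddle operators (write the custom operator in Python or C++).
  • A computation composed from existing PaddlePaddle operators does not meet your performance requirements (write the custom operator in C++).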

Implementing a custom operator in Python

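If the missing operator cannot be built from the basic computational functions that Paddle provides, you need to implement its computation logic yourself. Custom operators can be implemented either in Python or as native C++ operators. The C++ approach works the same way as the operators shipped with the framework and can deliver higher performance, but it is somewhat more involved. This section covers only the Python approach. When peak performance is not critical, implementing the missing operator logic in Python is a workable option.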

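Developing a custom Python operator in PaddlePaddle's dynamic graph mode takes two steps:

1) Create a PyLayer subclass: define the forward and backward functions.
2) Call the operator to build the network: use the apply method.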

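1) Create a PyLayer subclass.

Let's start with a concrete program. The code below implements the Tanh function as a custom operator. The core of a custom operator is its forward computation logic and its backward (gradient) logic, that is, the forward and backward functions in the code. After the forward computation, any result that the gradient computation needs must be recorded and passed to the backward function so that it can be used there directly. This recording and passing is done through the ctx object.

Let's first skim the implementation and then examine it in more depth.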

import paddle
from paddle.autograd import PyLayer

# Implement a dynamic-graph Python op by subclassing `PyLayer`
class cus_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        # ctx is a PyLayerContext object; it passes y from forward to backward.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    # forward has only one output, so apart from ctx, backward has only one input.
    def backward(ctx, dy):
        # ctx is a PyLayerContext object; saved_tensor retrieves the y stashed in forward.
        y, = ctx.saved_tensor()
        # Use Paddle APIs to define the custom backward computation.
        grad = dy * (1 - paddle.square(y))
        # forward has only one Tensor input, so backward has only one output.
        return grad
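The backward rule above relies on the identity d tanh(x)/dx = 1 - tanh²(x), which is why forward saves y = tanh(x) instead of x. One quick way to sanity-check a hand-written backward formula is to compare it with a central finite difference. The sketch below does this in plain Python. The helper names tanh_backward and central_diff are illustrative only and are not Paddle APIs.

```python
import math


def tanh_backward(dy: float, y: float) -> float:
    # Same rule as cus_tanh.backward: grad = dy * (1 - y**2), where y = tanh(x).
    return dy * (1.0 - y * y)


def central_diff(f, x: float, h: float = 1e-6) -> float:
    # Numerical derivative: (f(x + h) - f(x - h)) / (2h).
    return (f(x + h) - f(x - h)) / (2.0 * h)


checks = []
for x in (-2.0, -0.5, 0.0, 0.7, 1.5):
    analytic = tanh_backward(1.0, math.tanh(x))  # upstream gradient dy = 1
    numeric = central_diff(math.tanh, x)
    checks.append((analytic, numeric))
    print(f"x={x:+.1f}  analytic={analytic:.8f}  numeric={numeric:.8f}")
```

The two columns should agree to roughly eight decimal places. A mismatch usually points to a sign error or to the wrong Tensor being saved in forward.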

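Both the forward and backward functions are written in Python, so you can conveniently implement a custom op with Paddle APIs. The following rules apply:

  • forward and backward are both static methods, and their first argument is a PyLayerContext object;
  • Apart from the first argument, the arguments of backward are the gradients of forward's output Tensors. Therefore, the number of Tensors backward takes must equal the number of Tensors forward outputs. If you need a Tensor from forward inside backward, pass it with the save_for_backward and saved_tensor methods;
  • backward returns either a Tensor or a list/tuple(Tensor); these Tensors are the gradients of forward's input Tensors. Therefore, the number of Tensors backward returns equals the number of forward's input Tensors. If a return value (gradient) of backward corresponds to a forward input Tensor that requires a gradient, that return value must be a Tensor.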

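With this first impression in place, let's go deeper. PaddlePaddle supports Python-side custom ops in dynamic graph mode through the PyLayer and PyLayerContext interfaces. The PyLayer interface is described as follows:

class PyLayer:
    @staticmethod
    def forward(ctx, *args, **kwargs):
        pass

    @staticmethod
    def backward(ctx, *args, **kwargs):
        pass

    @classmethod
    def apply(cls, *args, **kwargs):
        pass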

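Here:

  • forward is the forward function of the custom op and must be overridden by subclasses. Its first argument is a PyLayerContext object; the remaining input arguments may be of any type and number;
  • backward is the backward function of the custom op and must be overridden by subclasses. Its first argument is a PyLayerContext object, and its other input arguments are the gradients of forward's output Tensors. Its output Tensors are the gradients of forward's input Tensors;
  • apply is the execution method of the custom op: once the op is defined, run it through apply.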

The PyLayerContext interface is described as follows:

class PyLayerContext:
    def save_for_backward(self, *tensors):
        pass

    def saved_tensor(self):
        pass

Here:

  • save_for_backward stashes the Tensors that backward needs. It can be called only once, and only inside forward;
  • saved_tensor retrieves the Tensors stashed by save_for_backward.

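2) Build the network with the apply method.

Once the computation logic of the cus_tanh operator is implemented (both the forward computation and the backward gradient), we can invoke the operator through the cus_tanh.apply API. The inputs of apply are the inputs of forward other than the first argument (ctx), and the output of apply is the output of forward. In the simple example below, we construct a 2×3 tensor, run the forward computation with cus_tanh.apply(), and call backward() to compute the gradients.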

data = paddle.randn([2, 3], dtype="float32")
data.stop_gradient = False
# Run this Python operator through apply
z = cus_tanh.apply(data)
z.mean().backward()

print(data.grad)

forward can also take non-Tensor arguments such as functions, and non-Tensor objects can be stored on ctx as attributes for use in backward. The variant below passes paddle.tanh as func1, keeps func2 (paddle.square by default) on ctx, and uses float64 data:

import paddle
from paddle.autograd import PyLayer

# Inherit from PyLayer
class cus_tanh(PyLayer):
    @staticmethod
    def forward(ctx, x, func1, func2=paddle.square):
        # ctx is a context object that stores some objects for backward.
        ctx.func = func2
        y = func1(x)
        # Pass tensors to backward.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    # forward has only one output, so there is only one gradient in the input of backward.
    def backward(ctx, dy):
        # Get the tensors passed by forward.
        y, = ctx.saved_tensor()
        grad = dy * (1 - ctx.func(y))
        # forward has only one Tensor input, so only one gradient tensor is returned.
        return grad

data = paddle.randn([2, 3], dtype="float64")
data.stop_gradient = False
z = cus_tanh.apply(data, func1=paddle.tanh)
z.mean().backward()

print(data.grad.numpy())
[[0.09197889 0.15834022 0.11797523]
 [0.09277528 0.09983508 0.03058447]]
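As a sanity check on the printed values, note that z.mean() averages 2 × 3 = 6 elements. The upstream gradient dy is therefore 1/6 for every element, and with the default func2=paddle.square the backward computes:

```latex
\frac{\partial z}{\partial x_i} = \frac{1}{6}\left(1 - \tanh^2(x_i)\right),
\qquad 0 < \frac{\partial z}{\partial x_i} \le \frac{1}{6} \approx 0.1667
```

Every entry in the output above is indeed positive and below 1/6.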

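Contributing model code on GitHub

If you have implemented an excellent model with PaddlePaddle, Baidu PaddlePaddle warmly welcomes your contributions to its open-source ecosystem on GitHub. This applies whether you are a deep learning practitioner already working in industry, an enthusiast, or a student. Share your project's successful applications and your creative ideas with us. Contributions can be algorithm models, framework operators, new framework features, suggestions for optimizing the PaddlePaddle platform, or model tutorials. Once PaddlePaddle accepts your contribution, it can benefit many more deep learning users. In addition, to promote the rapid development and application of deep learning, PaddlePaddle regularly organizes showcases and awards for outstanding code. Check the PaddlePaddle website at any time for details.

For the process of contributing model code to PaddlePaddle on GitHub, see: official website → Documentation → GitHub code contribution. We hope everyone brings a spirit of "all for one, one for all" to the PaddlePaddle ecosystem.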
