?

PyTorch最佳实践,怎样才能写出一手风格优美的代码

虽然这是一个非官方的 PyTorch 指南,但本文总结了一年多使用 PyTorch 框架的经验,尤其是用它开发深度学习相关工作的最优解决方案。请注意,我们分享的经验大多是从研究和实践角度出发的。

这是一个开发的项目,欢迎其它读者改进该文档:

https://github.com/IgorSusmelj/pytorch-styleguide。

本文档主要由三个部分构成:首先,本文会简要清点 Python 中的最好装备。接着,本文会介绍一些使用 PyTorch 的技巧和建议。最后,我们分享了一些使用其它框架的见解和经验,这些框架通常帮助我们改进工作流。

一、清点 Python 装备

1. 建议使用 Python 3.6 以上版本

根据我们的经验,我们推荐使用 Python 3.6 以上的版本,因为它们具有以下特性,这些特性可以使我们很容易写出简洁的代码:

自 Python 3.6 以后支持「typing」模块

自 Python 3.6 以后支持格式化字符串(f string)

2. Python 风格指南

我们试图遵循 Google 的 Python 编程风格。请参阅 Google 提供的优秀的 python 编码风格指南:

地址:https://github.com/google/styleguide/blob/gh-pages/pyguide.md。

在这里,我们会给出一个最常用命名规范小结:

PyTorch最佳实践,怎样才能写出一手风格优美的代码

3. 集成开发环境

一般来说,我们建议使用 visual studio 或 PyCharm 这样的集成开发环境。而 VS Code 在相对轻量级的编辑器中提供语法高亮和自动补全功能,PyCharm 则拥有许多用于处理远程集群任务的高级特性。

4. Jupyter Notebooks VS Python 脚本

一般来说,我们建议使用 Jupyter Notebook 进行初步的探索,或尝试新的模型和代码。如果你想在更大的数据集上训练该模型,就应该使用 Python 脚本,因为在更大的数据集上,复现性更加重要。

我们推荐你采取下面的工作流程:

在开始的阶段,使用 Jupyter Notebook

对数据和模型进行探索

在 notebook 的单元中构建你的类/方法

将代码移植到 Python 脚本中

在服务器上训练/部署

PyTorch最佳实践,怎样才能写出一手风格优美的代码

5. 开发常备库

常用的程序库有:

PyTorch最佳实践,怎样才能写出一手风格优美的代码

6. 文件组织

不要将所有的层和模型放在同一个文件中。最好的做法是将最终的网络分离到独立的文件(networks.py)中,并将层、损失函数以及各种操作保存在各自的文件中(layers.py,losses.py,ops.py)。最终得到的模型(由一个或多个网络组成)应该用该模型的名称命名(例如,yolov3.py,DCGAN.py),且引用各个模块。

主程序、单独的训练和测试脚本应该只需要导入带有模型名字的 Python 文件。

二、PyTorch 开发风格与技巧

我们建议将网络分解为更小的可复用的片段。一个 nn.Module 网络包含各种操作或其它构建模块。损失函数也是包含在 nn.Module 内,因此它们可以被直接整合到网络中。

继承 nn.Module 的类必须拥有一个「forward」方法,它实现了各个层或操作的前向传导。

一个 nn.module 可以通过「self.net(input)」处理输入数据。在这里直接使用了对象的「call()」方法将输入数据传递给模块。

output = self.net(input) 

1. PyTorch 环境下的一个简单网络

使用下面的模式可以实现具有单个输入和输出的简单网络:

class ConvBlock(nn.Module): 

    def __init__(self): 

        super(ConvBlock, self).__init__() 

        block = [nn.Conv2d(...)] 

        block += [nn.ReLU()] 

        block += [nn.BatchNorm2d(...)] 

        self.block = nn.Sequential(*block) 

 

    def forward(self, x): 

        return self.block(x) 

 

class SimpleNetwork(nn.Module): 

    def __init__(self, num_resnet_blocks=6): 

        super(SimpleNetwork, self).__init__() 

        # here we add the individual layers 

        layers = [ConvBlock(...)] 

        for i in range(num_resnet_blocks): 

            layers += [ResBlock(...)] 

        self.net = nn.Sequential(*layers) 

 

    def forward(self, x): 

        return self.net(x) 

请注意以下几点:

我们复用了简单的循环构建模块(如卷积块 ConvBlocks),它们由相同的循环模式(卷积、激活函数、归一化)组成,并装入独立的 nn.Module 中。

我们构建了一个所需要层的列表,并最终使用「nn.Sequential()」将所有层级组合到了一个模型中。我们在 list 对象前使用「*」操作来展开它。

在前向传导过程中,我们直接使用输入数据运行模型。

2. PyTorch 环境下的简单残差网络

class ResnetBlock(nn.Module): 

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 

        super(ResnetBlock, self).__init__() 

        selfself.conv_block = self.build_conv_block(...) 

 

    def build_conv_block(self, ...): 

        conv_block = [] 

 

        conv_block += [nn.Conv2d(...), 

                       norm_layer(...), 

                       nn.ReLU()] 

        if use_dropout: 

            conv_block += [nn.Dropout(...)] 

 

        conv_block += [nn.Conv2d(...), 

                       norm_layer(...)] 

 

        return nn.Sequential(*conv_block) 

 

    def forward(self, x): 

        out = x + self.conv_block(x) 

        return ou 

在这里,ResNet 模块的跳跃连接直接在前向传导过程中实现了,PyTorch 允许在前向传导过程中进行动态操作。

3. PyTorch 环境下的带多个输出的网络

对于有多个输出的网络(例如使用一个预训练好的 VGG 网络构建感知损失),我们使用以下模式:

class Vgg19(torch.nn.Module): 

  def __init__(self, requires_grad=False): 

    super(Vgg19, self).__init__() 

    vgg_pretrained_features = models.vgg19(pretrained=True).features 

    self.slice1 = torch.nn.Sequential() 

    self.slice2 = torch.nn.Sequential() 

    self.slice3 = torch.nn.Sequential() 

 

    for x in range(7): 

        self.slice1.add_module(str(x), vgg_pretrained_features[x]) 

    for x in range(7, 21): 

        self.slice2.add_module(str(x), vgg_pretrained_features[x]) 

    for x in range(21, 30): 

        self.slice3.add_module(str(x), vgg_pretrained_features[x]) 

    if not requires_grad: 

        for param in self.parameters(): 

            param.requires_grad = False 

 

  def forward(self, x): 

    h_relu1 = self.slice1(x) 

    h_relu2 = self.slice2(h_relu1)         

    h_relu3 = self.slice3(h_relu2)         

    out = [h_relu1, h_relu2, h_relu3] 

    return out 

请注意以下几点:

我们使用由「torchvision」包提供的预训练模型

我们将一个网络切分成三个模块,每个模块由预训练模型中的层组成

我们通过设置「requires_grad = False」来固定网络权重

我们返回一个带有三个模块输出的 list

4. 自定义损失函数

相关推荐
新闻聚焦
猜你喜欢
热门推荐
  • 微软AI面试题有多难?这里有一份样卷

      究竟什么样的AI人才能被微软这样的巨头聘用呢?今天,文摘君就淘来了几道微软AI 面试题,同时给出了最基本的解答......

    06-25????来源:澎湃新闻网

    分享
  • 全球最聪明的大脑怎么看AI?他们预测了

      2017年AI领域取得了诸多成果。2018年AI又将何去何从?以下是来自世界顶级研究人员和行业领军人物对2018年AI领域发展作......

    02-20????来源:虎嗅网

    分享
  • 2017JavaScript框架战报 - React分战场

      我们来看看与React有关的软件包的生态系统。当Facebook构建React时,就有许多来自开源社区的第三方软件包。为提供完......

    02-27????来源:湖北新闻网

    分享
  • 小白学数据:教你用Python实现简单监督学

      监督学习作为运用最广泛的机器学习方法,一直以来都是从数据挖掘信息的重要手段。即便是在无监督学习兴起的近......

    03-05????来源:今日头条

    分享
  • 现代编程语言Swift、Kotlin等十大有趣功能

      最近学习了一些现代编程语言,比如Reason,Swift,Kotlin和Dart。这些编程语言提供了许多新功能,本文主要分享了我认......

    04-29????来源:祁东新闻网

    分享
  • 领域场景分析的6W模型

      组成场景的要素常常被称之为6W模型,即描写场景的过程必须包含Who,What,Why,Where,When与hoW这六个要素。......

    04-30????来源:砍柴网

    分享
  • 开源应用服务器WildFly 12发新季度交付模式

      WildFly 12 Final版本现在已经可以下载了,WildFly是一款灵活的开源应用服务器,支持开发人员构建轻量级应用程序。支持......

    05-10????来源:青岛新闻网

    分享
  • 基于Spring Cloud的微服务落地

      微服务架构模式的核心在于如何识别服务的边界,设计出合理的微服务。但如果要将微服务架构运用到生产项目上,......

    06-04????来源:广西新闻网

    分享
  • 为什么阿里工程师纷纷在内网晒代码?

      前阵子,在阿里一个小黑屋里,5名对代码有着极致追求的工程师参与阿里代码领域最高荣誉“多隆奖”的最终角逐。......

    06-08????来源:四川新闻网

    分享
  • 超级大汇总!200多个最好的机器学习、

      我把这篇文章分为了四个部分:机器学习,自然语言处理,python和数学。在每个部分中我都列举了一些主题,但是因......

    09-25????来源:洛阳新闻网

    分享
返回列表
Ctrl+D?将本页面保存为书签,全面了解最新资讯,方便快捷。