博客
关于我
PyTorch中的nn.Conv1d与nn.Conv2d
阅读量:196 次
发布时间:2019-02-28

本文共 2583 字,大约阅读时间需要 8 分钟。

一维卷积nn.Conv1d

一般来说,一维卷积nn.Conv1d用于文本数据,只对宽度进行卷积,对高度不卷积。通常,输入大小为word_embedding_dim * max_length,其中,word_embedding_dim为词向量的维度,max_length为句子的最大长度。卷积核窗口在句子长度的方向上滑动,进行卷积操作。

定义
class torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

主要参数说明:

  • in_channels:在文本应用中,即为词向量的维度
  • out_channels:卷积产生的通道数,相当于是将词向量的维度从in_channels变成了out_channels
  • kernel_size:卷积核的尺寸;卷积核的第二个维度由in_channels决定,所以实际上卷积核的大小为kernel_size * in_channels
  • padding:对输入的每一条边,补充0的层数

代码示例

输入:批大小为32,句子的最大长度为35,词向量维度为256
目标:句子分类,共2类

conv1 = nn.Conv1d(in_channels=256, out_channels=100, kernel_size=2)input = torch.randn(32, 35, 256)input = input.permute(0, 2, 1)   # (32, 35, 256) => (32, 256, 35)output = conv1(input)  # (32, 100, 34)

要使用permute是因为nn.Conv1d是对输入的最后一个维度卷积,所以要把句子长度所在的那个维度变换到最后。

上面的代码只使用了一个卷积核,如果要使用多个卷积核应该使用nn.ModuleList和for循环

import torchimport torch.nn as nnwindow_sizes = [2,3,4]convs = nn.ModuleList([            nn.Sequential(nn.Conv1d(in_channels=8,                                     out_channels=4,                                     kernel_size=h),                          nn.ReLU())                          for h in window_sizes                    ])embed = torch.randn(2, 16, 8)embed = embed.transpose(1,2)output = [conv(embed) for conv in convs]#print(output)for x in output:    print(x.size())'''输出torch.Size([2, 4, 15])torch.Size([2, 4, 14])torch.Size([2, 4, 13])'''

二维卷积nn.Conv2d

一般来说,二维卷积nn.Conv2d用于图像数据,对宽度和高度都进行卷积。

定义
class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

代码示例

假设现有大小为32 x 32的图片样本,输入样本的channels为1,该图片可能属于10个类中的某一类。CNN框架定义如下:

class CNN(nn.Module):    def __init__(self):        nn.Model.__init__(self)         self.conv1 = nn.Conv2d(1, 6, 5)  # 输入通道数为1,输出通道数为6        self.conv2 = nn.Conv2d(6, 16, 5)  # 输入通道数为6,输出通道数为16        self.fc1 = nn.Linear(5 * 5 * 16, 120)        self.fc2 = nn.Linear(120, 84)        self.fc3 = nn.Linear(84, 10)    def forward(self,x):        '''			总共有2个卷积层,每一层的结构都是卷积->relu->max_pool		'''		# 第一层        x = self.conv1(x)   # 32*32*1 => 28*28*6        x = F.relu(x)        x = F.max_pool2d(x, 2)  # 28*28*6 => 14*14*6        # 第二层        x = self.conv2(x)     # 14*14*6 => 10*10*16        x = F.relu(x)        x = F.max_pool2d(x, 2)   # 10*10*16 => 5*5*16        # view函数将张量x变形成一维向量形式,总特征数不变,为全连接层做准备        x = x.view(x.size()[0], -1)   # 5*5*16 => 400*1        x = F.relu(self.fc1(x))       # 400*1 => 120 * 1        x = F.relu(self.fc2(x))       # 120*1 => 84*1        x = self.fc3(x)               # 84*1 => 10*1        return x

在这里插入图片描述

参考:

转载地址:http://qwrn.baihongyu.com/

你可能感兴趣的文章
NIFI1.21.0_Mysql到Mysql增量CDC同步中_日期类型_以及null数据同步处理补充---大数据之Nifi工作笔记0057
查看>>
NIFI1.21.0_Mysql到Mysql增量CDC同步中_补充_插入时如果目标表中已存在该数据则自动改为更新数据_Postgresql_Hbase也适用---大数据之Nifi工作笔记0058
查看>>
NIFI1.21.0_Mysql到Mysql增量CDC同步中_补充_更新时如果目标表中不存在记录就改为插入数据_Postgresql_Hbase也适用---大数据之Nifi工作笔记0059
查看>>
NIFI1.21.0_NIFI和hadoop蹦了_200G集群磁盘又满了_Jps看不到进程了_Unable to write in /tmp. Aborting----大数据之Nifi工作笔记0052
查看>>
NIFI1.21.0_Postgresql和Mysql同时指定库_指定多表_全量同步到Mysql数据库以及Hbase数据库中---大数据之Nifi工作笔记0060
查看>>
NIFI1.21.0最新版本安装_连接phoenix_单机版_Https登录_什么都没改换了最新版本的NIFI可以连接了_气人_实现插入数据到Hbase_实际操作---大数据之Nifi工作笔记0050
查看>>
NIFI1.21.0最新版本安装_配置使用HTTP登录_默认是用HTTPS登录的_Https登录需要输入用户名密码_HTTP不需要---大数据之Nifi工作笔记0051
查看>>
NIFI1.21.0通过Postgresql11的CDC逻辑复制槽实现_指定表多表增量同步_增删改数据分发及删除数据实时同步_通过分页解决变更记录过大问题_02----大数据之Nifi工作笔记0054
查看>>
NIFI1.21.0通过Postgresql11的CDC逻辑复制槽实现_指定表多表增量同步_增加修改实时同步_使用JsonPath及自定义Python脚本_03---大数据之Nifi工作笔记0055
查看>>
NIFI1.21.0通过Postgresql11的CDC逻辑复制槽实现_指定表多表增量同步_插入修改删除增量数据实时同步_通过分页解决变更记录过大问题_01----大数据之Nifi工作笔记0053
查看>>
NIFI1.21.0通过Postgresql11的CDC逻辑复制槽实现_指定表或全表增量同步_实现指定整库同步_或指定数据表同步配置_04---大数据之Nifi工作笔记0056
查看>>
NIFI1.23.2_最新版_性能优化通用_技巧积累_使用NIFI表达式过滤表_随时更新---大数据之Nifi工作笔记0063
查看>>
NIFI从MySql中增量同步数据_通过Mysql的binlog功能_实时同步mysql数据_根据binlog实现update数据实时同步_实际操作05---大数据之Nifi工作笔记0044
查看>>
NIFI从MySql中增量同步数据_通过Mysql的binlog功能_实时同步mysql数据_根据binlog实现数据实时delete同步_实际操作04---大数据之Nifi工作笔记0043
查看>>
NIFI从MySql中增量同步数据_通过Mysql的binlog功能_实时同步mysql数据_配置binlog_使用处理器抓取binlog数据_实际操作01---大数据之Nifi工作笔记0040
查看>>
NIFI从MySql中增量同步数据_通过Mysql的binlog功能_实时同步mysql数据_配置数据路由_实现数据插入数据到目标数据库_实际操作03---大数据之Nifi工作笔记0042
查看>>
NIFI从MySql中增量同步数据_通过Mysql的binlog功能_实时同步mysql数据_配置数据路由_生成插入Sql语句_实际操作02---大数据之Nifi工作笔记0041
查看>>
NIFI从MySql中离线读取数据再导入到MySql中_03_来吧用NIFI实现_数据分页获取功能---大数据之Nifi工作笔记0038
查看>>
NIFI从MySql中离线读取数据再导入到MySql中_不带分页处理_01_QueryDatabaseTable获取数据_原0036---大数据之Nifi工作笔记0064
查看>>
NIFI从MySql中离线读取数据再导入到MySql中_无分页功能_02_转换数据_分割数据_提取JSON数据_替换拼接SQL_添加分页---大数据之Nifi工作笔记0037
查看>>