首页 最新 热门 推荐

  • 首页
  • 最新
  • 热门
  • 推荐

【增强学习】Recurrent Visual Attention源码解读

  • 25-03-03 21:22
  • 3681
  • 14108
blog.csdn.net

Mnih, Volodymyr, Nicolas Heess, and Alex Graves. “Recurrent models of visual attention.” Advances in Neural Information Processing Systems. 2014.

戳这里下载训练代码,戳这里下载测试代码。

这篇文章处理的任务非常简单:MNIST手写数字分类。但使用了聚焦机制(Visual Attention),不是一次看一张大图进行估计,而是分多次观察小部分图像,根据每次查看结果移动观察位置,最后估计结果。

Yoshua Bengio的高徒,先后供职于LISA和Element Research的Nicolas Leonard用Torch实现了这篇文章的算法。Torch官方cheetsheet的demo中,就包含这篇源码,作者自己的讲解也刊登在Torch的博客中,足见其重要性。

通过这篇源码,我们可以

  • 理解聚焦机制中较简单的hard attention
  • 了解增强学习的基本流程
  • 复习Torch和扩展包dp的相关语法

本文解读训练源码,分三大部分:参数设置,网络构造,训练设置。以下逐次介绍其中重要的语句。

参数设置

除了Torch之外,还需要包含Nicholas Leonard自己编写的两个包。dp:能够简化DL流程,训练过程更“面向对象”;rnn:实现Recurrent网络。

require 'dp'
require 'rnn'
  • 1
  • 2

首先使用Torch的CmdLine类设定一系列参数,存储在opt中。这是Torch的标准写法。

cmd = torch.CmdLine()
cmd:option('--learningRate', 0.01, 'learning rate at t=0')    -- 参数名,参数值,说明
local opt = cmd:parse(arg or {
   })    --把cmd中的参数传入opt
  • 1
  • 2
  • 3
  • 4

把数据载入到数据集ds中,数据是dp包中已经下载好的:

ds = dp[opt.dataset]()
  • 1

网络构造

这篇源码中模型的写法遵循:由底到顶,先细节后整体。和CNN不同,Recurrent网络带有反馈,呈现较为复杂的多级嵌套结构。请着重关注每个模块的输入、输出和作用部分。

Glimpse网络

输入:图像 I I I和观察位置 l l l
输出:观察结果 x x x

蓝色输入,橙色输出,菱形表示串接:
这里写图片描述

首先用locationSensor(左半)提取位置信息 l l l中的特征:

locationSensor:add(nn.SelectTable(2))    --选择两个输入中的第二个,位置l
locationSensor:add(nn.Linear(2, opt.locatorHiddenSize))    --Torch中的Linear指全连层
locationSensor:add(nn[opt.transfer]())    --opt.transfer定义一种非线性运算,本文中是ReLU
  • 1
  • 2
  • 3

之后用glimpseSensor(右半)提取图像 I I I位置 l l l的特征。
其中SpacialGlimpse是dp中定义的层,提取尺寸为PatchSize的Depth层图像,相邻层比例为Scale。

glimpseSensor:add(nn.SpatialGlimpse(opt.glimpsePatchSize, opt.glimpseDepth, opt.glimpseScale):float())    --SpatialGlimpse提取小块金字塔
glimpseSensor:add(nn.Collapse(3))    --压缩第三维
glimpseSensor:add(nn.Linear(ds:imageSize('c')*(opt.glimpsePatchSize^2)*opt
  • 1
  • 2
文章知识点与官方知识档案匹配,可进一步学习相关知识
Python入门技能树首页概览416654 人正在系统学习中
注:本文转载自blog.csdn.net的shenxiaolu1984的文章"https://blog.csdn.net/shenxiaolu1984/article/details/51582185"。版权归原作者所有,此博客不拥有其著作权,亦不承担相应法律责任。如有侵权,请联系我们删除。
复制链接
复制链接
相关推荐
发表评论
登录后才能发表评论和回复 注册

/ 登录

评论记录:

未查询到任何数据!
回复评论:

分类栏目

后端 (14832) 前端 (14280) 移动开发 (3760) 编程语言 (3851) Java (3904) Python (3298) 人工智能 (10119) AIGC (2810) 大数据 (3499) 数据库 (3945) 数据结构与算法 (3757) 音视频 (2669) 云原生 (3145) 云平台 (2965) 前沿技术 (2993) 开源 (2160) 小程序 (2860) 运维 (2533) 服务器 (2698) 操作系统 (2325) 硬件开发 (2491) 嵌入式 (2955) 微软技术 (2769) 软件工程 (2056) 测试 (2865) 网络空间安全 (2948) 网络与通信 (2797) 用户体验设计 (2592) 学习和成长 (2593) 搜索 (2744) 开发工具 (7108) 游戏 (2829) HarmonyOS (2935) 区块链 (2782) 数学 (3112) 3C硬件 (2759) 资讯 (2909) Android (4709) iOS (1850) 代码人生 (3043) 阅读 (2841)

热门文章

101
推荐
关于我们 隐私政策 免责声明 联系我们
Copyright © 2020-2025 蚁人论坛 (iYenn.com) All Rights Reserved.
Scroll to Top