交流群:462197261站长百科站长论坛热门标签收藏本站北冥有鱼 互联网前沿资源第一站 助力全行业互联网+
点击这里给我发消息
  • 当前位置:
  • 解决Pytorch自定义层出现多Variable共享内存错误问题

    错误信息:

    RuntimeError: in-place operations can be only used on variables that don't share storage with any other variables, but detected that there are 4 objects sharing it

    自动求导是很方便, 但是想想, 如果两个Variable共享内存, 再对这个共享的内存的数据进行修改, 就会引起错误!

    一般是由于 inplace操作或是indexing或是转置. 这些都是共享内存的.

     @staticmethod
     def backward(ctx, grad_output):
      ind_lst = ctx.ind_lst
      flag = ctx.flag
    
      c = grad_output.size(1)
      grad_former_all = grad_output[:, 0:c//3, :, :]
      grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
      grad_swapped_all = grad_output[:, c*2//3:c, :, :]
    
      spatial_size = ctx.h * ctx.w
    
      W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
      for idx in range(ctx.bz):
       W_mat = W_mat_all.select(0,idx)
       for cnt in range(spatial_size):
        indS = ind_lst[idx][cnt] 
    
        if flag[cnt] == 1:
         # 这里W_mat是W_mat_all通过select出来的, 他们共享内存.
         W_mat[cnt, indS] = 1
    
       W_mat_t = W_mat.t()
    
       grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
       grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
       grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))
    

    由于 这里W_mat是W_mat_all通过select出来的, 他们共享内存. 所以当对这个共享的内存进行修改W_mat[cnt, indS] = 1, 就会出错. 此时我们可以通过clone()将W_mat和W_mat_all独立出来. 这样的话, 梯度也会通过 clone()操作将W_mat的梯度正确反传到W_mat_all中.

     @staticmethod
     def backward(ctx, grad_output):
      ind_lst = ctx.ind_lst
      flag = ctx.flag
    
      c = grad_output.size(1)
      grad_former_all = grad_output[:, 0:c//3, :, :]
      grad_latter_all = grad_output[:, c//3: c*2//3, :, :]
      grad_swapped_all = grad_output[:, c*2//3:c, :, :]
    
      spatial_size = ctx.h * ctx.w
    
      W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
      for idx in range(ctx.bz):
       # 这里使用clone了
       W_mat = W_mat_all.select(0,idx).clone()
       for cnt in range(spatial_size):
        indS = ind_lst[idx][cnt]
    
        if flag[cnt] == 1:
         W_mat[cnt, indS] = 1
    
       W_mat_t = W_mat.t()
    
       grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
       grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
    
       # 这句话删了不会出错, 加上就吹出错
       grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))
    

    但是现在却出现 4个objects共享内存. 如果将最后一句话删掉, 那么则不会出错.

    如果没有最后一句话, 我们看到

    grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())

    grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)

    grad_swapped_weighted 一个新的Variable, 因此并没有和其他Variable共享内存, 所以不会出错. 但是最后一句话,

    grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))

    你可能会说, 不对啊, 修改grad_latter_all[idx]又没有创建新的Variable, 怎么会出错. 这是因为grad_latter_all和grad_output是共享内存的. 因为 grad_latter_all = grad_output[:, c//3: c*2//3, :, :], 所以这里的解决方案是:

     @staticmethod
     def backward(ctx, grad_output):
      ind_lst = ctx.ind_lst
      flag = ctx.flag
    
      c = grad_output.size(1)
      grad_former_all = grad_output[:, 0:c//3, :, :]
      # 这两个后面修改值了, 所以也要加clone, 防止它们与grad_output共享内存
      grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
      grad_swapped_all = grad_output[:, c*2//3:c, :, :].clone()
    
      spatial_size = ctx.h * ctx.w
    
      W_mat_all = Variable(ctx.Tensor(ctx.bz, spatial_size, spatial_size).zero_())
      for idx in range(ctx.bz):
       W_mat = W_mat_all.select(0,idx).clone()
       for cnt in range(spatial_size):
        indS = ind_lst[idx][cnt]
    
        if flag[cnt] == 1:
         W_mat[cnt, indS] = 1
    
       W_mat_t = W_mat.t()
    
       grad_swapped_weighted = torch.mm(W_mat_t, grad_swapped_all[idx].view(c//3, -1).t())
    
       grad_swapped_weighted = grad_swapped_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
       grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_swapped_weighted.mul(ctx.triple_w))
    
      grad_input = torch.cat([grad_former_all, grad_latter_all], 1)
    
      return grad_input, None, None, None, None, None, None, None, None, None, None
    

    补充知识:Pytorch 中 expand, expand_as是共享内存的,只是原始数据的一个视图 view

    如下所示:

    mask = mask_miss.expand_as(sxing).clone() # type: torch.Tensor
    mask[:, :, -2, :, :] = 1 # except for person mask channel

    为了避免对expand后对某个channel操作会影响原始tensor的全部元素,需要使用clone()

    如果没有clone(),对mask_miss的某个通道赋值后,所有通道上的tensor都会变成1!

    # Notice! expand does not allocate more memory but just make the tensor look as if you expanded it.
    # You should call .clone() on the resulting tensor if you plan on modifying it
    # https://discuss.pytorch.org/t/very-strange-behavior-change-one-element-of-a-tensor-will-influence-all-elements/41190

    以上这篇解决Pytorch自定义层出现多Variable共享内存错误问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持北冥有鱼。


    广而告之:
    热门推荐:
    Access教程 如何在表添加现有字段

    Access供给了一种便利的增加字段的办法,因为各个表有重复的字段,所以咱们能够直接的将另外表中的字段增加到心的表中即可。 这样建立表的字段就很方便快捷了,详细的办法如下。 1、打开你想要添加字段的数据表,然后点击菜单栏上的 数据表。 2、在数据表 选项中,我么···

    js打字机效果代码

    Type Writea{text-decoration:none} [Ctrl+A 全选 注:如需引入外部Js需刷新才能执行] 您可能感兴趣的文章: Js 打字效果 逐一出现的文字 JS实现的自动打字效果示例 JS实现的打字机效果完整实例 javascript 打字效果的文字特效 JS模拟键盘打字效果的方法 JS实现简单···

    DedeCMS编辑器改成eWebEditor编辑器详解

    织梦DedeCMS编辑器改成eWebEditor编辑器详解 。 第一步: 首先下载eWebEditor最新免费版, 名字叫做eWebEditor V4.6精简版, 但是其实功能没什么缺陷,就跟收费版V4.6是一样的。 地址是:http://www.ewebeditor.net/download.asp。 第二步:   解压下载的eWebEdi···

    js实时获取并显示当前时间的方法

    本文实例讲述了js实时获取并显示当前时间的方法。分享给大家供大家参考。具体实现方法如下: js部分如下: <script type="text/javascript"> window.onload = function() { var show = document.getElementById("show"); setInterval(function() { var time = n···

    subsonic3.0插件更新字符串过长引发的异常修复方法

    最近公司客服提交了个BUG,说是更新产品详细信息时,有的可以有的更新不了,前段时间一直没空所以暂时放下,刚才又出现这个问题,所以马上处理了一下。 打开项目解决方案,进入DEBUG模式,拿到操作的数据提交后进行追踪,发现提交时产生了:System.Data.SqlClient.SqlExc···

    BootStrap 智能表单实战系列(五) 表单依赖插件处理

    什么是 Bootstrap? Bootstrap 是一个用于快速开发 Web 应用程序和网站的前端框架。Bootstrap 是基于 HTML、CSS、JAVASCRIPT 的。 历史 Bootstrap 是由 Twitter 的 Mark Otto 和 Jacob Thornton 开发的。Bootstrap 是 2011 年八月在 GitHub 上发布的开源产品。 Bootstrap 包的···

    Node.js事件驱动

    Node.js事件驱动实现概览 虽然在ECMAScript的标准里并没有(也没有必要)明确规定“事件”,但是在浏览器中,事件作为一个极为重要的机制,给予JavaScript响应用户操作与DOM变化的能力;在Node.js中,异步事件驱动模型则是其高并发能力的基础。 学习JavaScript也需要了解它的运行···

    javascript控制swfObject应用介绍

    复制代码 代码如下: <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> <html xmlns="http://www.w3.org/1999/xhtml"> <head> <meta http-equiv="Content-Type" conten···

    帝国cms调用大栏目下的小栏目投稿数与审核数百分比显示

    需要 [月费会员] 级别以上与扣除 60 点积分才能查看。 您还未登陆,登录点击这里进行登陆操作;注册请点击这里。

    利用javascript实现禁用网页上所有文本框,下拉菜单,多行文本域

    原理就是循环获取网页上的控件,然后设置disabled 属性为true. 代码如下:复制代码 代码如下:<script type="text/javascript">    var nodeList = document.getElementsByTagName("input");    for (var i = 0; i < nodeList.len···