交流群:462197261站长百科站长论坛热门标签收藏本站北冥有鱼 互联网前沿资源第一站 助力全行业互联网+
点击这里给我发消息
  • 当前位置:
  • pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)

    北冥有鱼 教程大全 2020-06-26 ,,

    首先这是VGG的结构图,VGG11则是红色框里的结构,共分五个block,如红框中的VGG11第一个block就是一个conv3-64卷积层:

    一,写VGG代码时,首先定义一个 vgg_block(n,in,out)方法,用来构建VGG中每个block中的卷积核和池化层:

    n是这个block中卷积层的数目,in是输入的通道数,out是输出的通道数

    有了block以后,我们还需要一个方法把形成的block叠在一起,我们定义这个方法叫vgg_stack:

    def vgg_stack(num_convs, channels): # vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))
    
    
     net = []
     for n, c in zip(num_convs, channels):
      in_c = c[0]
      out_c = c[1]
      net.append(vgg_block(n, in_c, out_c))
     return nn.Sequential(*net)

    右边的注释

    vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))

    里,(1, 1, 2, 2, 2)表示五个block里,各自的卷积层数目,((3, 64), (64, 128), (128, 256), (256, 512), (512, 512))表示每个block中的卷积层的类型,如(3,64)表示这个卷积层输入通道数是3,输出通道数是64。vgg_stack方法返回的就是完整的vgg11模型了。

    接着定义一个vgg类,包含vgg_stack方法:

    #vgg类
    class vgg(nn.Module):
     def __init__(self):
      super(vgg, self).__init__()
      self.feature = vgg_net
      self.fc = nn.Sequential(
       nn.Linear(512, 100),
       nn.ReLU(True),
       nn.Linear(100, 10)
      )
     
     def forward(self, x):
      x = self.feature(x)
      x = x.view(x.shape[0], -1)
      x = self.fc(x)
      return x

    最后:

    net = vgg() #就能获取到vgg网络

    那么构建vgg网络完整的pytorch代码是:

    def vgg_block(num_convs, in_channels, out_channels):
     net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(True)]
     
     for i in range(num_convs - 1): # 定义后面的许多层
      net.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
      net.append(nn.ReLU(True))
     
     net.append(nn.MaxPool2d(2, 2)) # 定义池化层
     return nn.Sequential(*net)
     
    # 下面我们定义一个函数对这个 vgg block 进行堆叠
    def vgg_stack(num_convs, channels): # vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))
     net = []
     for n, c in zip(num_convs, channels):
      in_c = c[0]
      out_c = c[1]
      net.append(vgg_block(n, in_c, out_c))
     return nn.Sequential(*net)
     
    #确定vgg的类型,是vgg11 还是vgg16还是vgg19
    vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))
    #vgg类
    class vgg(nn.Module):
     def __init__(self):
      super(vgg, self).__init__()
      self.feature = vgg_net
      self.fc = nn.Sequential(
       nn.Linear(512, 100),
       nn.ReLU(True),
       nn.Linear(100, 10)
      )
     def forward(self, x):
      x = self.feature(x)
      x = x.view(x.shape[0], -1)
      x = self.fc(x)
      return x
     
    #获取vgg网络
    net = vgg() 
    

    基于VGG11的cifar10训练代码:

    import sys
    import numpy as np
    import torch
    from torch import nn
    from torch.autograd import Variable
    from torchvision.datasets import CIFAR10
    import torchvision.transforms as transforms
     
    def vgg_block(num_convs, in_channels, out_channels):
     net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(True)]
     
     for i in range(num_convs - 1): # 定义后面的许多层
      net.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
      net.append(nn.ReLU(True))
     
     net.append(nn.MaxPool2d(2, 2)) # 定义池化层
     return nn.Sequential(*net)
     
    # 下面我们定义一个函数对这个 vgg block 进行堆叠
    def vgg_stack(num_convs, channels): # vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))
     net = []
     for n, c in zip(num_convs, channels):
      in_c = c[0]
      out_c = c[1]
      net.append(vgg_block(n, in_c, out_c))
     return nn.Sequential(*net)
     
    #vgg类
    class vgg(nn.Module):
     def __init__(self):
      super(vgg, self).__init__()
      self.feature = vgg_net
      self.fc = nn.Sequential(
       nn.Linear(512, 100),
       nn.ReLU(True),
       nn.Linear(100, 10)
      )
     def forward(self, x):
      x = self.feature(x)
      x = x.view(x.shape[0], -1)
      x = self.fc(x)
      return x
     
    # 然后我们可以训练我们的模型看看在 cifar10 上的效果
    def data_tf(x):
     x = np.array(x, dtype='float32') / 255
     x = (x - 0.5) / 0.5
     x = x.transpose((2, 0, 1)) ## 将 channel 放到第一维,只是 pytorch 要求的输入方式
     x = torch.from_numpy(x)
     return x
     
    transform = transforms.Compose([transforms.ToTensor(),
             transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
             ])
    def get_acc(output, label):
     total = output.shape[0]
     _, pred_label = output.max(1)
     num_correct = (pred_label == label).sum().item()
     return num_correct / total
     
    def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
     if torch.cuda.is_available():
      net = net.cuda()
     for epoch in range(num_epochs):
      train_loss = 0
      train_acc = 0
      net = net.train()
      for im, label in train_data:
       if torch.cuda.is_available():
        im = Variable(im.cuda())
        label = Variable(label.cuda())
       else:
        im = Variable(im)
        label = Variable(label)
       # forward
       output = net(im)
       loss = criterion(output, label)
       # forward
       optimizer.zero_grad()
       loss.backward()
       optimizer.step()
     
       train_loss += loss.item()
       train_acc += get_acc(output, label)
     
      if valid_data is not None:
       valid_loss = 0
       valid_acc = 0
       net = net.eval()
       for im, label in valid_data:
        if torch.cuda.is_available():
         with torch.no_grad():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
         with torch.no_grad():
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
       epoch_str = (
         "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
         % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
      else:
       epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
           (epoch, train_loss / len(train_data),
           train_acc / len(train_data)))
     
      # prev_time = cur_time
      print(epoch_str)
     
    if __name__ == '__main__':
     # 作为实例,我们定义一个稍微简单一点的 vgg11 结构,其中有 8 个卷积层
     vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))
     print(vgg_net)
     
     train_set = CIFAR10('./data', train=True, transform=transform, download=True)
     train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
     test_set = CIFAR10('./data', train=False, transform=transform, download=True)
     test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
     
     net = vgg()
     optimizer = torch.optim.SGD(net.parameters(), lr=1e-1)
     criterion = nn.CrossEntropyLoss() #损失函数为交叉熵
     
     train(net, train_data, test_data, 50, optimizer, criterion)
     torch.save(net, 'vgg_model.pth')

    结束后,会出现一个模型文件vgg_model.pth

    二,然后网上找张图片,把图片缩成32x32,放到预测代码中,即可有预测结果出现,预测代码如下:

    import torch
    import cv2
    import torch.nn.functional as F
    from vgg2 import vgg ##重要,虽然显示灰色(即在次代码中没用到),但若没有引入这个模型代码,加载模型时会找不到模型
    from torch.autograd import Variable
    from torchvision import datasets, transforms
    import numpy as np
     
    classes = ('plane', 'car', 'bird', 'cat',
       'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    if __name__ == '__main__':
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     model = torch.load('vgg_model.pth') # 加载模型
     model = model.to(device)
     model.eval() # 把模型转为test模式
     
     img = cv2.imread("horse.jpg") # 读取要预测的图片
     trans = transforms.Compose(
      [
       transforms.ToTensor(),
       transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
      ])
     
     img = trans(img)
     img = img.to(device)
     img = img.unsqueeze(0) # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]
     # 扩展后,为[1,1,28,28]
     output = model(img)
     prob = F.softmax(output,dim=1) #prob是10个分类的概率
     print(prob)
     value, predicted = torch.max(output.data, 1)
     print(predicted.item())
     print(value)
     pred_class = classes[predicted.item()]
     print(pred_class)
     
     # prob = F.softmax(output, dim=1)
     # prob = Variable(prob)
     # prob = prob.cpu().numpy() # 用GPU的数据训练的模型保存的参数都是gpu形式的,要显示则先要转回cpu,再转回numpy模式
     # print(prob) # prob是10个分类的概率
     # pred = np.argmax(prob) # 选出概率最大的一个
     # # print(pred)
     # # print(pred.item())
     # pred_class = classes[pred]
     # print(pred_class)

    缩成32x32的图片:

    运行结果:

    以上这篇pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持北冥有鱼。


    广而告之:
    热门推荐:
    怎么能够做出攻克用户心理的网页交互设计

      由于百度的多次调整,让站长们越来越注重用户体验,因此网站的制作与设计也越来越趋于人性化,甚至有不少站长把研究用户行为当作网站中最为重要的工作。网站建设当中网页的设计显得尤为重要,尤其是对于网页互动功能的设计!   心理学越来越成为重要组成部分,或者说这···

    嵌套repeater示例分享

    前台代码部分: 复制代码 代码如下:<asp:repeater runat="server" id="repeater1" OnItemDataBound="function2"> <itemtemplate>   <asp:repeater runat="server" id="repeater2">   <itemtemplate>   </itemtemplate>   </asp···

    JSON字符串转换JSONObject和JSONArray的方法

    一.下载json 具体到http://www.json.org/上找Java-json下载,并把其放到项目源代码中,这样就可以引用其类对象了 二.具体转化过程 //JSONObject String jsonMessage = "{\"语文\":\"88\",\"数学\":\"78\",\"计算机\":\"99\"}"; String value1 = null; try { //将字符串转换成j···

    JavaScript实现短信倒计时60s

    废话不多说了,直接给大家贴代码了,具体代码如下所示; $(function(){ //获取验证码 var getCode = document.getElementById('getCode'); var wait = 60; function time(btn){ if (wait===0) { btn.removeAttribute("disabled"); btn.in···

    PHP HTTP 认证实例详解

    HP来实现HTTP的强制认证是十分简单的,只需简单的几行代码就可以实现,下面我们来看一个例子,然后结合这里例子我向大家详细介绍一下PHP实现HTTP认证。 <?php if(!isset($_SERVER['PHP_AUTH_USER'])) { header('WWW-Authenticate: Basic realm="系统名称"'); header(···

    jQuery表单对象属性过滤选择器实例详解

    本文实例讲述了jQuery表单对象属性过滤选择器。分享给大家供大家参考,具体如下: <html> <head> <meta http-equiv="Content-Type" content="text/html; charset=utf-8" /> <title>2-11</title> <script src="jquery-1.7.2.min.js" type="te···

    HTML5的Video标签有部分MP4无法播放的问题解析(多图)

    现在网页视频的应用极为广泛,在实际项目中发现有些MP4文件可以在H5网页中正常播放,但有些却不行,这是为什么呢? 1、首先我自己从网上下载的一个MP4文件,IE10和谷歌都能正常播放。然后自己用格式化工厂转换了一个RMVB文件为MP4,然后就不能播放。 如下图,我选择MP4格式转换···

    蓝帽SEO是属于优化技术吗?关于蓝帽优化的特点与seo结构标准

    在搜索引擎优化中,提高目标网站在搜索引擎中的自然排名。SEO作为一种策略比技术更重要。所以,跟随织梦58SEO编辑了解蓝帽SEO是否是SEO中的优化方法?  一、什么是蓝帽SEO  蓝帽SEO与白帽SEO和黑帽SEO在传统理解和鲜为人知的绿帽SEO定义上有很大不同。蓝···

    js使浏览器窗口最大化实现代码(适用于IE)

    这里使用的方法是IE的私有特性,只能在IE中有效。主要是window.moveTo和 window.resizeTo方法。 效果和点击最大化按钮差不多,有一点区别。点击最大化按钮后,浏览器的内容填充满显示器,浏览器窗口的边框被挤出显示器。而该js的最大化效果是浏览器的边框在显示器内显示···

    destoon切换城市后实现logo旁边显示地区名称的方法

    本文讲述了destoon切换城市后实现logo旁边显示地区名称的方法,针对不同地区建设分站的情况非常适用。 一般来说,当我们进入网站后默认的是总站,当我们开启城市分站的时候,点击选择分站后,在logo的旁边可以看到你选的城市分站名称,当选择全国的时候在logo旁边什么也不显示···