交流群:462197261站长百科站长论坛热门标签收藏本站北冥有鱼 互联网前沿资源第一站 助力全行业互联网+
点击这里给我发消息
  • 当前位置:
  • keras 自定义loss层+接受输入实例

    北冥有鱼 教程大全 2020-06-28 ,,,

    loss函数如何接受输入值

    keras封装的比较厉害,官网给的例子写的云里雾里,

    在stackoverflow找到了答案

    You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

    def custom_loss_wrapper(input_tensor):
     def custom_loss(y_true, y_pred):
      return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
     return custom_loss
    input_tensor = Input(shape=(10,))
    hidden = Dense(100, activation='relu')(input_tensor)
    out = Dense(1, activation='sigmoid')(hidden)
    model = Model(input_tensor, out)
    model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

    You can verify that input_tensor and the loss value will change as different X is passed to the model.

    X = np.random.rand(1000, 10)
    y = np.random.randint(2, size=1000)
    model.test_on_batch(X, y) # => 1.1974642
    X *= 1000
    model.test_on_batch(X, y) # => 511.15466
    

    fit_generator

    fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

    Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

    ### generator
    yield [inputX_1,inputX_2],y
    ### model
    model = Model(inputs=[inputX_1,inputX_2],outputs=...)

    补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

    首先辨析一下概念:

    1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

    2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

    一、keras自定义损失函数

    在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

    # 方式一
    def vae_loss(x, x_decoded_mean):
     xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
     kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
     return xent_loss + kl_loss
     
    vae.compile(optimizer='rmsprop', loss=vae_loss)

    或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

    # 方式二
    # Custom loss layer
    class CustomVariationalLayer(Layer):
     
     def __init__(self, **kwargs):
      self.is_placeholder = True
      super(CustomVariationalLayer, self).__init__(**kwargs)
     def vae_loss(self, x, x_decoded_mean_squash):
     
      x = K.flatten(x)
      x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
      xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
      kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
      return K.mean(xent_loss + kl_loss)
     
     def call(self, inputs):
     
      x = inputs[0]
      x_decoded_mean_squash = inputs[1]
      loss = self.vae_loss(x, x_decoded_mean_squash)
      self.add_loss(loss, inputs=inputs)
      # We don't use this output.
      return x
     
    y = CustomVariationalLayer()([x, x_decoded_mean_squash])
    vae = Model(x, y)
    vae.compile(optimizer='rmsprop', loss=None)

    在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

    注意事项:

    1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

    2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

    有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

    discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

    二、keras中的样本权重

    # Import
    import numpy as np
    from sklearn.utils import class_weight
     
    # Example model
    model = Sequential()
    model.add(Dense(32, activation='relu', input_dim=100))
    model.add(Dense(1, activation='sigmoid'))
     
    # Use binary crossentropy loss
    model.compile(optimizer='rmsprop',
        loss='binary_crossentropy',
        metrics=['accuracy'])
     
    # Calculate the weights for each class so that we can balance the data
    weights = class_weight.compute_class_weight('balanced',
               np.unique(y_train),
               y_train)
     
    # Add the class weights to the training           
    model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

    Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

    以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持北冥有鱼。


    广而告之:
    热门推荐:
    Enter转换为Tab的小例子(兼容IE,Firefox)

    复制代码 代码如下:document.onkeydown=function(e){  var e=window.event||e;  var element=e.srcElement||e.target;  if(e.keyCode==13&&element.type!="submit"&&element.type!="button"&&element.type!="textarea"&&element.t···

    C3P0连接池+MySQL的配置及wait

     一、配置环境 spring4.2.4+mybatis3.2.8+c3p0-0.9.1.2+Mysql5.6.24 二、c3p0的配置详解及spring+c3p0配置 1.配置详解 官方文档 : http://www.mchange.com/projects/c3p0/index.html <c3p0-config> < default-config> <!--当连接池中的连接耗尽的时候c3p0···

    纯css下拉菜单 无需js

    再来个今天某人说过的例子:纯css下拉菜单: 效果图 这个的实现很简单,主要是:hover和过渡属性transition的使用。 代码: <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8"> <title>css</title> <style> *{···

    无论是个人主页制作还是企业主页制作,都应该满足搜索引擎优化原

    对于主页设计,我们经常面临着个人、企业和网站主页设计的各种情况,但实际上,无论是个人主页制作还是企业主页设计,从SEO的角度来看,在很大程度上,我们都需要满足一定的基本原则。    1、主页架构  一般的主页结构主要包括:网站导航(head···

    详解用webpack把我们的业务模块分开打包的方法

    webpack我自己还在摸索学习中,今天给大家分享个用webpack把我们的业务模块分开打包的方法,顺便留个笔记   如何用webpack打包这3个js? 只需修改webpack的配置文件webpack.config.js: // entry是入口文件,可以多个,代表要编译那些js entry:['./src/main.js',···

    添加和删除HTML节点的简单示例

    添加和删除HTML节点的简单示例 添加和删除HTML节点的简单示例 <input type="button" onclick="appendnode()" value="添加节点"><input type="button" onclick="removenode()" value=&qu···

    Vuex之理解Getters的用法实例

    1.什么是getters 在介绍state中我们了解到,在Store仓库里,state就是用来存放数据,若是对数据进行处理输出,比如数据要过滤,一般我们可以写到computed中。但是如果很多组件都使用这个过滤后的数据,比如饼状图组件和曲线图组件,我们是否可以把这个数据抽提出来共享?这就···

    js解析与序列化json数据(三)json的解析探讨

    这一节我们主要讨论json的解析。 JSON.parse()方法也可以接收另一个参数,该参数是一个函数,将早每个键值对上调用。为了区别JSON.stringify()接收的替换(过滤)函数(replacer),这个函数被称作还原函数(reviver),但实际上这两个函数的签名是相同的——它们都接收连···

    编写更好的JavaScript条件式和匹配条件的技巧(小结)

    介绍 如果你像我一样乐于见到整洁的代码,那么你会尽可能地减少代码中的条件语句。通常情况下,面向对象编程让我们得以避免条件式,并代之以继承和多态。我认为我们应当尽可能地遵循这些原则。 正如我在另一篇文章 JavaScript 整洁代码的最佳实践里提到的,你写的代码不单单是···

    2012年10月手机品牌网络广告投放费用排行榜Top10

    iResearch艾瑞咨询根据网络广告监测系统iAdTracker的最新数据研究发现,2012年10月,手机品牌网络广告总投放费用达2698万元。其中,三星电子投放费用环比增长143.7%达921万元,位居第一;宏达投放费用环比增长66.9%达856万元,位居第二;联想投放费用环比下降48.5%至160万元,···