TensorFlow Keras 的 Layer
是所有 Keras 层的核心基类,提供通用功能如参数管理、输入输出处理等。所有自定义层都应继承此类。
核心特性 ✅
- 输入形状定义:通过
input_shape
参数指定输入维度input_shape=(None, 64) # 支持变长序列
- 配置保存:使用
get_config()
实现层的序列化Keras_Layer - 训练/推理模式:自动区分
training=True
和training=False
场景
常用方法 📦
方法名 | 功能描述 |
---|---|
__init__ |
初始化层参数 |
build |
构建层的权重(可重写) |
call |
定义前向传播逻辑(必须实现) |
compute_output_shape |
计算输出形状(可重写) |
子类化示例 🚀
class MyCustomLayer(tf.keras.layers.Layer):
def __init__(self, units=32, **kwargs):
super().__init__(**kwargs)
self.units = units
def build(self, input_shape):
self.kernel = self.add_weight(
name="kernel", shape=[int(input_shape[-1]), self.units],
initializer="uniform", trainable=True
)
def call(self, inputs):
return tf.matmul(inputs, self.kernel)