class DownConvLayer(tf.keras.layers.Layer):
    """One encoder stage: 3x3 convolution followed by 2x2 max-pooling.

    The convolution uses ReLU, no bias and 'same' padding, so it keeps the
    spatial size and outputs `dim` channels. `MaxPool2D(2)` then halves the
    height and width (rounding down for odd sizes).

    Args:
        dim: Number of output channels of the convolution.
    """

    def __init__(self, dim):
        super(DownConvLayer, self).__init__()
        self.conv = tf.keras.layers.Conv2D(dim, 3, activation=tf.keras.layers.ReLU(), use_bias=False, padding='same')
        self.pool = tf.keras.layers.MaxPool2D(2)

    def call(self, x, training=False, **kwargs):
        """Apply conv then pooling to `x` and return the downsampled tensor."""
        # `training` is accepted for Keras API compatibility; neither sublayer
        # behaves differently between training and inference.
        x = self.conv(x)
        x = self.pool(x)
        return x
class UpConvLayer(tf.keras.layers.Layer):
    """One decoder stage: 3x3 convolution followed by 2x nearest upsampling.

    Mirrors `DownConvLayer`. The convolution uses ReLU, no bias and 'same'
    padding, and outputs `dim` channels. `UpSampling2D(2)` then doubles the
    height and width. The attribute keeps the name `pool` so existing
    checkpoints and weight names still match.

    Args:
        dim: Number of output channels of the convolution.
    """

    def __init__(self, dim):
        super(UpConvLayer, self).__init__()
        self.conv = tf.keras.layers.Conv2D(dim, 3, activation=tf.keras.layers.ReLU(), use_bias=False, padding='same')
        self.pool = tf.keras.layers.UpSampling2D(2)

    def call(self, x, training=False, **kwargs):
        """Apply conv then upsampling to `x` and return the upsampled tensor."""
        # `training` is accepted for Keras API compatibility only.
        x = self.conv(x)
        x = self.pool(x)
        return x
# The example builds both the encoder and the decoder from very simple
# convolution operations.
class Encoder(tf.keras.layers.Layer):
    """Stack of `layer_num` `DownConvLayer`s.

    Each stage halves height and width, so the output spatial size is the
    input size divided by 2**layer_num (rounded down at every stage). The
    output has `dim` channels.

    Args:
        dim: Channel count used by every stage.
        layer_num: Number of downsampling stages.
    """

    def __init__(self, dim, layer_num=3):
        super(Encoder, self).__init__()
        self.convs = [DownConvLayer(dim) for _ in range(layer_num)]

    def call(self, x, training=False, **kwargs):
        """Run `x` through every stage in order and return the latent tensor."""
        for stage in self.convs:
            x = stage(x, training=training)
        return x
class Decoder(tf.keras.layers.Layer):
    """Stack of `layer_num` `UpConvLayer`s plus a final 1-channel projection.

    Each stage doubles height and width. The final 3x3 convolution maps the
    `dim` channels down to a single output channel.

    Args:
        dim: Channel count used by every upsampling stage.
        layer_num: Number of upsampling stages.
    """

    def __init__(self, dim, layer_num=3):
        super(Decoder, self).__init__()
        self.convs = [UpConvLayer(dim) for _ in range(layer_num)]
        # padding='same' is required: with the default 'valid' padding a 3x3
        # kernel trims 1 px from every border. The reconstruction would then
        # never have the same height/width as the input image it is compared
        # against in the L2 loss.
        self.final_conv = tf.keras.layers.Conv2D(1, 3, strides=1, padding='same')

    def call(self, x, training=False, **kwargs):
        """Upsample the latent tensor `x` and return the 1-channel reconstruction."""
        for stage in self.convs:
            x = stage(x, training=training)
        reconstruct = self.final_conv(x)
        return reconstruct
class AutoEncoderModel(tf.keras.Model):
    """Convolutional autoencoder trained with a per-image L2 reconstruction loss.

    Encoder and decoder each use 3 stages with 64 channels. The decoder emits
    a single channel, so the loss only lines up with single-channel inputs
    (NOTE(review): assumes grayscale NHWC images with C == 1 — confirm).

    Because each max-pool rounds down, the reconstruction has the same
    height/width as the input only when both are divisible by 2**3 == 8.
    For example, 28x28 would come back as 24x24.
    """

    def __init__(self):
        super(AutoEncoderModel, self).__init__()
        self.encoder = Encoder(64, layer_num=3)
        self.decoder = Decoder(64, layer_num=3)

    def call(self, inputs, training=None, mask=None):
        """Encode then decode the image and return the reconstruction.

        `inputs` may be the image tensor itself or a tuple/list whose first
        element is the image. `train_step` passes it as `(img,)`.
        """
        image = inputs[0] if isinstance(inputs, (tuple, list)) else inputs
        latent = self.encoder(image, training=training)
        reconstruct_img = self.decoder(latent, training=training)
        return reconstruct_img

    def train_step(self, data):
        """Run one optimisation step and return {"l2_loss": loss}.

        `data` is whatever the input pipeline yields: either the image batch
        alone or a tuple whose first element is the image batch. Any further
        elements (e.g. labels) are ignored, since the target is the input.
        """
        img = data[0] if isinstance(data, (tuple, list)) else data
        with tf.GradientTape() as tape:
            reconstruct_img = self((img,), training=True)
            # The loss must be computed inside the tape; otherwise the tape
            # does not record it and tape.gradient returns None.
            # L2 loss measures how closely the reconstruction matches the
            # original image: sum squared error over each image's H, W, C
            # axes, then average over the batch.
            squared_error = (reconstruct_img - img) ** 2
            l2_loss = tf.reduce_mean(tf.reduce_sum(squared_error, axis=[1, 2, 3]))
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(l2_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {"l2_loss": l2_loss}