TensorFlow Image Style Transfer


发布于

|

分类

,

学框架总要有个练手项目的。这次学习 TensorFlow,练手项目是 Style Transfer。

首先放一下项目地址: Github,除了逻辑清晰一点之外,没有别的任何优点。

然后再说原理和实现。

原理

一张 Content Image(提供 “内容” 的图片)经过 Transfer Net 卷卷卷之后,连同 Style Image(提供风格的图片)一同送进 VGG16,抽取特定层的输出,计算 loss,然后反向传播回来。

Transfer Net 是一个全卷积网络,先卷积,然后是几个非线性变换,最后是反卷积。

VGG16 只是前向计算,不参与梯度反向传播。

计算三种 Loss: Content Loss、Style Loss 和 Total Variation Loss。

实现

数据读取

代码: reader.py

这里使用了 TensorFlow 的 Dataset API。下一个版本中这个接口会发生一些小变化,所以之后再改。

首先用各种方法,获取文件夹里面的所有图片,之后将它们送入 Dataset。取出的时候,设置一个映射函数,将文件名映射到真实图片,然后做各种预处理。

搭建 Transfer Net

代码: model.py

先 Padding 一下以减少边缘效应,之后使用 Slim 进行各种卷积和非线性变换,然后再各种卷积回来,最后把 Padding 去掉。

这里使用一个比较特殊的 resize_conv2d 进行反卷积。好像各种卷积层不能进行 batch norm 操作。最后一层卷积使用 tanh 作为激活函数,可以直接把输出范围限定在 $[0,1]$ 。

获取 Loss 计算网络

代码: model.py

直接获取某个常用网络,然后加载预训练权重。需要注意的是 fc 层权重不需要加载。

这里一直担心会不会现在使用的图和前面的图不一样,好像这种担心是多余的。

计算 Loss

代码: loss.py

按照论文的说明计算 Loss 即可。

工具

代码: config.py

设置项太多,借鉴 何大佬 的思路,搞成 yaml,比较优雅。

组装

代码: fast_train.py

将上面几个部分串联起来,然后搞一个 Optimizer 优化 Loss。

在这里我使用了 tf.estimator.Estimator 这个 API:写 input_fnmodel_fn,然后构建 Estimator,最后 train

有人说为啥不直接用 slim 的 train?因为数据集制作那部分看不懂…… 失败了好久,最终决定使用 tf.estimator.Estimator

运行

代码: fast_evaluate.py

只需要加载 Transfer Net 这部分网络即可。获取 Content Image,送入网络,得到结果。

比较坑的是,之前的读图片的函数好像不会自动推断图片大小,然后网络一直报错。最后换成 Skimage 读取,Placeholder 直接设置图片大小,就 OK 了。虽然没有那么优雅。

最后想生成个单独的 Transfer Net 的 ckpt 文件,一直提示说不在同一个图中。作罢。

坑们

不遇坑是不可能的,特别是对于我这种新手。

总结一句话: 简直神坑。

比如,fast_evaluate.py里面,将数据送入 TransferNet,第一个卷积层一直报错。江哥帮忙 Debug 到 11 点,最终发现是网络需要 float32 的数据,但是输入图片给的是 uint8 的数据。

还有,如果不学 tfdbg 的话,各种地方都只能脑补 debug。静态图结构还好说,运行起来就完全黑箱了。

最后 checkpoint 我到底不知道里面有什么东西,以及怎么生成单独 Transfer Net 的 ckpt 文件。搞了好久都没有成功,只好放弃了。

文档不全,很多东西需要脑补。另外更新太快,Google 出来的前天的东西,点进去都是 404(models 尤甚,最近大改版)。亦或者,搜出来的是 r1.2 的 API 和 master 的 api,但我现在是 r1.3——其他版本的API都不能在此版本中使用。

后记

  1. 已经从1.3.0迁移到了1.4.0
  2. 已经用上了slim.learning.train

参考资料和感谢


评论

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注