Image Super-Resolution with Deep Learning in Python: SRCNN, SRGAN, and ESRGAN (TensorFlow)

The DIV2K dataset is used: 800 training images, 100 validation images, and 100 test images.

Training details:

Dataset:

  1. The DIV2K dataset is used, with 800 training images, 100 validation images, and 100 test images. Each image is cropped into 256*256 HR patches, which serve as the ground truth.
  2. Scaling factor: x4, so the low-resolution (LR) images are 64*64.
  3. The LR images are upsampled with the MATLAB bicubic kernel (bicubic interpolation) to obtain 256*256 interpolated images, which serve both as the input to SRCNN and as the traditional super-resolution baseline to compare against the neural-network results (see the sketch below).
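A minimal sketch of this degradation pipeline, assuming OpenCV is used for the resizing; note that cv2.INTER_CUBIC only approximates MATLAB's bicubic kernel (MATLAB's imresize additionally applies antialiasing when shrinking).

import cv2

SCALE = 4  # scaling factor x4


def make_lr_and_bicubic(hr_patch):
    # hr_patch: a 256x256 HR crop (the ground truth).
    # Downsample by 4x to the 64x64 LR image, then upsample back to 256x256
    # to produce the bicubic-interpolated input/baseline image.
    h, w = hr_patch.shape[:2]
    lr = cv2.resize(hr_patch, (w // SCALE, h // SCALE), interpolation=cv2.INTER_CUBIC)
    bicubic = cv2.resize(lr, (w, h), interpolation=cv2.INTER_CUBIC)
    return lr, bicubic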

Training parameters:

  1. SRCNN uses MSE (mean squared error) as its loss function, producing a PSNR-oriented super-resolution model;
  2. SRGAN's loss has two parts: the discriminator loss is the cross-entropy for classifying an input image as real or fake, and the generator loss combines the negated (adversarial) cross-entropy term with the MSE between the VGG19 features of the real and generated images, with weights 0.001 and 0.006 respectively (see the loss sketch after this list);
  3. ESRGAN makes the following improvements over SRGAN:
    1. Batch Normalization layers are removed, since BN tends to reduce the sharpness of the results and lose high-frequency information;
    2. The generator's basic building block changes from the plain residual block to the Residual-in-Residual Dense Block (RRDB, sketched after this list);
    3. The GAN is changed to RaGAN: the discriminator no longer simply judges whether an image is real or fake, but outputs a relative value, judging how much more real or more fake an image looks compared with a real or generated reference;
  4. The initial learning rate is 1e-4 and is multiplied by a decay rate of 0.98 every 100 iterations;
  5. All models are trained with the Adam optimizer;
  6. All three models are trained directly on the three RGB channels.
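As a rough illustration of items 2-4, the sketches below show how the RRDB generator unit and the GAN losses could be written with tf.keras. These are minimal sketches under stated assumptions, not the project's actual implementation.

First, the Residual-in-Residual Dense Block from item 3.2: densely connected convolutions wrapped in residual connections, with no Batch Normalization (item 3.1). The growth rate of 32 and the residual scaling factor of 0.2 are assumptions taken from the ESRGAN paper.

import tensorflow as tf
from tensorflow import keras


def dense_block(x, growth=32, beta=0.2):
    # Five 3x3 convolutions; each one sees the concatenation of all previous
    # feature maps. LeakyReLU activations, no Batch Normalization.
    features = x
    for _ in range(4):
        out = keras.layers.Conv2D(growth, 3, padding="same")(features)
        out = keras.layers.LeakyReLU(0.2)(out)
        features = keras.layers.Concatenate()([features, out])
    out = keras.layers.Conv2D(x.shape[-1], 3, padding="same")(features)
    return x + beta * out  # residual scaling


def rrdb(x, beta=0.2):
    # Residual-in-Residual Dense Block: three dense blocks inside an outer residual.
    out = dense_block(dense_block(dense_block(x)))
    return x + beta * out


# Usage: one RRDB operating on 64-channel feature maps.
inputs = keras.Input(shape=(64, 64, 64))
block = keras.Model(inputs, rrdb(inputs))

Next, the loss terms from items 2 and 3.3 and the learning-rate schedule from item 4. The VGG19 layer used for the feature loss (block5_conv4), the from_logits convention, and the assumption that pixel values lie in [0, 1] are choices made for this sketch.

bce = keras.losses.BinaryCrossentropy(from_logits=True)

# VGG19 feature extractor for the perceptual term of the generator loss.
vgg = keras.applications.VGG19(include_top=False, weights="imagenet")
feature_extractor = keras.Model(vgg.input, vgg.get_layer("block5_conv4").output)


def vgg_feature_mse(hr, sr):
    # MSE between the VGG19 feature maps of the real (HR) and generated (SR) images.
    hr_feat = feature_extractor(keras.applications.vgg19.preprocess_input(hr * 255.0))
    sr_feat = feature_extractor(keras.applications.vgg19.preprocess_input(sr * 255.0))
    return tf.reduce_mean(tf.square(hr_feat - sr_feat))


def srgan_discriminator_loss(real_logits, fake_logits):
    # Cross-entropy for judging each input image as real (1) or fake (0).
    return bce(tf.ones_like(real_logits), real_logits) + bce(tf.zeros_like(fake_logits), fake_logits)


def srgan_generator_loss(fake_logits, hr, sr):
    # Negated discriminator objective plus the VGG19 feature MSE,
    # weighted 0.001 and 0.006 as described in item 2.
    adversarial = bce(tf.ones_like(fake_logits), fake_logits)
    return 0.001 * adversarial + 0.006 * vgg_feature_mse(hr, sr)


def ragan_discriminator_loss(real_logits, fake_logits):
    # Relativistic average GAN (item 3.3): each side is judged relative to
    # the mean score of the opposite side rather than in absolute terms.
    real_rel = real_logits - tf.reduce_mean(fake_logits)
    fake_rel = fake_logits - tf.reduce_mean(real_logits)
    return bce(tf.ones_like(real_rel), real_rel) + bce(tf.zeros_like(fake_rel), fake_rel)


def ragan_generator_loss(real_logits, fake_logits):
    real_rel = real_logits - tf.reduce_mean(fake_logits)
    fake_rel = fake_logits - tf.reduce_mean(real_logits)
    return bce(tf.zeros_like(real_rel), real_rel) + bce(tf.ones_like(fake_rel), fake_rel)


# Item 4: start at 1e-4 and multiply by 0.98 every 100 iterations, trained with Adam (item 5).
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-4, decay_steps=100, decay_rate=0.98, staircase=True
)
optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)

The SRCNN training script is listed below.
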
import math
import json
import os

# Silence TensorFlow C++ logging below the ERROR level (valid values are 0-3);
# the variable must be set before TensorFlow is imported to take effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import cv2
import tensorflow as tf
from tensorflow import keras

import model
from utils import *
from config import *


class TrainProcess:
    def __init__(self):
        self._build_data_set()
        self._build_model()
        self._build_callbacks()
        self._mses = dict()
        self._psnrs = dict()
        self._ssims = dict()
        self._log_idx = 0

    def _build_data_set(self):
        self._data_set = MyDataSet(
            TRAIN_HR_PATH,
            TRAIN_BIC_PATH,
            VALID_HR_PATH,
            VALID_BIC_PATH,
            IMAGE_SUFFIX,
            shuffle=True
        )

    def _build_model(self):
        print("building SRCNN Model.......................................")
        self._model, self._output_size_offset = model.get_model()
        self._model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_1=0.8),
            loss=keras.losses.mean_squared_error,
            run_eagerly=True
        )

    def _build_callbacks(self):
        tensor_board = keras.callbacks.TensorBoard(
            log_dir=LOG_DIR,
            histogram_freq=1,
            write_graph=True,
            write_images=True
        )
        reduce_lr = keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.96,
            patience=3,
            verbose=1,
            mode="auto",
            min_delta=1e-4,  # "epsilon" was renamed to "min_delta" in recent Keras versions
            cooldown=0,
            min_lr=0
        )
        sample_callback = keras.callbacks.LambdaCallback(
            on_epoch_end=self.save_samples,
            on_batch_end=self.log_psnr_and_ssim
        )
        self._callbacks = [tensor_board, reduce_lr, sample_callback]


    def save_samples(self, epoch, logs):
        model_save_filename = os.path.join(MODEL_SAVE_DIR, "srcnn_model_%03d.h5" % epoch)
        self._model.save(model_save_filename)
        x, y = self._data_set.get_batch(batch_size=BATCH_SIZE, size_offset=self._output_size_offset, is_train=False)
        input_ = x[0]
        output_ = self._model.predict(x)[0]
        gt_ = y[0]
        cv2.imwrite(
            os.path.join(SAMPLE_SAVE_DIR, "sample_batch%d.png" % epoch),
            concat_three_images(
                normalized_array_to_image(input_),
                normalized_array_to_image(output_),
                normalized_array_to_image(gt_))
        )
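
    def log_psnr_and_ssim(self, batch, logs):
        # Per-batch metric hook referenced in _build_callbacks. The original
        # implementation is not included in this listing, so this is a sketch:
        # it evaluates a fresh validation batch on each call (an assumption)
        # and assumes images are normalized to [0, 1] (max_val=1.0).
        x, y = self._data_set.get_batch(batch_size=BATCH_SIZE, size_offset=self._output_size_offset, is_train=False)
        y_true = tf.convert_to_tensor(y, dtype=tf.float32)
        y_pred = tf.convert_to_tensor(self._model.predict(x), dtype=tf.float32)
        self._mses[self._log_idx] = float(tf.reduce_mean(tf.square(y_true - y_pred)))
        self._psnrs[self._log_idx] = float(tf.reduce_mean(tf.image.psnr(y_true, y_pred, max_val=1.0)))
        self._ssims[self._log_idx] = float(tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=1.0)))
        self._log_idx += 1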

    def data_generator(self, batch_size, size_offset, is_train):
        while True:
            x, y = self._data_set.get_batch(batch_size=batch_size, size_offset=size_offset, is_train=is_train)
            yield {"inputs": x}, {"outputs": y}

    def train(self):
        total_size = self._data_set.get_training_set_size()
        steps_per_epoch = math.ceil(total_size / BATCH_SIZE)
        # Pass the callbacks built in _build_callbacks and an explicit number
        # of steps per epoch. A validation generator is supplied so that the
        # monitored "val_loss" metric exists; the epoch count and
        # validation_steps are placeholders, as neither is given above.
        result = self._model.fit(
            self.data_generator(BATCH_SIZE, self._output_size_offset, True),
            steps_per_epoch=steps_per_epoch,
            epochs=100,
            callbacks=self._callbacks,
            validation_data=self.data_generator(BATCH_SIZE, self._output_size_offset, False),
            validation_steps=10,
        )
        print(result.history)

        with open(os.path.join(RES_DIR, "psnr.json"), "w") as f:
            js = json.dumps(self._psnrs)
            f.write(js)
        with open(os.path.join(RES_DIR, "ssim.json"), "w") as f:
            js = json.dumps(self._ssims)
            f.write(js)
        with open(os.path.join(RES_DIR, "mse.json"), "w") as f:
            js = json.dumps(self._mses)
            f.write(js)


def main():
    # Create the output directories if they do not already exist.
    for path in (SAMPLE_SAVE_DIR, MODEL_SAVE_DIR, LOG_DIR, RES_DIR):
        os.makedirs(path, exist_ok=True)

    my_process = TrainProcess()
    my_process.train()


if __name__ == "__main__":
    main()