2026-03-12 21:19:45 +08:00

89 lines
5.1 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# CycleGAN and pix2pix in PyTorch
该项目来源于网上开源代码, 详细情况可参考https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix, 你也许能在[training/test tips](docs/tips.md) and [frequently asked questions](docs/qa.md)中找到有用的信息。我们主要采用该项目中的CycleGAN模型进行SAR图像生成。
## 先决条件
- Linux or macOS
- Python 3+
- CPU or NVIDIA GPU + CUDA CuDNN
## 安装
1.进入[Anaconda](https://www.anaconda.com/download)官网下载Anaconda, 或者从清华大学[开源镜像网站](https://mirrors.tuna.tsinghua.edu.cn/)里安装, 安装过程中注意进行环境变量的配置, 将Anaconda加入环境变量中
2.下载该项目后, 在终端中切换到项目根目录
`cd pytorch-CycleGAN-and-pix2pix`
3.安装 [PyTorch](http://pytorch.org) 以及其他依赖 (例如, torchvision, [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate)).
- 对于Conda用户, 请执行以下操作
- 在终端中创建虚拟环境, 环境名为pytorch-img2img `conda env create -f environment.yml`
- 激活该虚拟环境 `conda activate pytorch-img2img`
若要使用显卡训练其一在终端中输入nvidia-smi检查显卡的cuda版本号安装对应版本的pytorch
其二注意检查安装的pytorch是否是带有cuda的版本比如 pytorch pytorch/linux-64::pytorch-1.12.1-py3.9_cuda11.3_cudnn8.3.2_0
如果执行上述操作后运行代码显示缺少XXX模块,或者XXX模块安装失败, 可在该虚拟环境中使用 `pip install XXX` 单独安装对应的模块。注意要先激活虚拟环境,在虚拟环境中进行安装
对于其他用户,可参考原始网站的安装说明
## 自定义数据集
部分光学和SAR数据集可[点此下载](https://pan.baidu.com/s/1Y6-BilxNH-k5ZQhdP0Wurw?pwd=jvmp), 也可在网上根据任务寻找其他数据集下载
预训练模型[下载地址](https://pan.baidu.com/s/1KwTuWVrqLLSuza_xCoJZLg?pwd=icqh), 主要是飞机和车辆SAR目标生成
在datasets/opt2sar目录下创建四个文件夹, trainA, trainB, TestA, TestB, 存放用来训练的光学图片, 用来训练的sar图片, 用来测试的光学图片和sar图片。由于是采用CycleGAN模型进行训练, 因此文件夹A和B中的图片不要求一一对应。数据集结构如下
```python
opt2sar
├── trainA
├── 0001.jpg
├── 0002.jpg
├── .....
├── trainB
├── 0001.jpg
├── 0002.jpg
├── .....
├── testA
├── 0001.jpg
├── 0002.jpg
├── .....
├── testB
├── 0001.jpg
├── 0002.jpg
├── .....
```
## CycleGAN 训练/测试
### train the model:
处理好数据集后, 先激活虚拟环境,在终端中输入以下命令行进行训练,注意路径首先要切换到项目根目录下
```bash
python train.py --dataroot ./datasets/opt2sar --name opt2sar_cyclegan --model cycle_gan --batch_size 4 --input_nc 1 --output_nc 1 --gpu_ids 0 --display_id -1
```
- dataroot 表示数据集的存放路径
- name 表示模型训练过程中存放权重的文件夹名称, 权重保存到checkpoints下文件夹下, 每训练5轮保存一次权重, 例如这行命令保存的权重就在 `./checkpoints/opt2sar_cyclegan`
- model 表示使用的模型是cycle_gan还是pix2pix
- batch_size 表示批处理大小,可根据显卡的显存大小自行调整
- input_nc output_nc 表示模型输入图片和输出图片的通道数, 设为1表示单通道, 设为3表示RGB多通道, 注意输入输出通道数要一致
- gpu_ids 表示显卡设备的ID
- display_id 该参数设为-1, 表示关闭web服务, 该服务需要联网会影响训练速度
### Test the model:
训练完成后, 执行以下命令行进行测试, 通过`--name opt2sar_cyclegan`参数设定要加载的权重所在的文件夹, 模型默认加载该文件夹下最新的权重, 注意input_nc和 output_nc的参数需要和训练命令行中保持一致。此外, 模型默认一次性测试50张图片, 如需更改,请在`./options/test_options.py`中修改 `--num_test`参数
```bash
python test.py --dataroot ./datasets/opt2sar --name opt2sar_cyclegan --model cycle_gan --input_nc 1 --output_nc 1
```
- 测试结果将会保存到 `./results/opt2sar_cyclegan`文件夹下.
### pre-trained the model:
如果需要加载训练好的权重进行微调或者在训练过程中因意外中断训练,想要继续训练模型,请执行以下命令行
```bash
python train.py --dataroot ./datasets/opt2sar --name opt2sar_cyclegan --model cycle_gan --batch_size 4 --input_nc 1 --output_nc 1 --gpu_ids 0 --continue_train --epoch_count 150 --display_id -1
```
- continue_train 表示加载 `./checkpoints/opt2sar_cyclegan`下的权重进行训练
- epoch_count 设定从第几个epoch开始训练, 代码默认训练200个epoch, 设为150表示从第150轮开始训练
如果想了解更多关于训练和测试的参数设置, 可浏览options文件下base_options.py, train_options.py, test_options.py这三个文件