Example #1
# VAE is disabled for this run. Including it in the model would cause style transfer
# to/from just the ADE20K corpus AFAIK, as previous experiments with training
# this model from scratch demonstrated that the size of the Bob Ross corpus is not
# sufficient for the style transfer elements to really kick in.

import sys; sys.path.append('../lib/SPADE-master/')
from options.train_options import TrainOptions
from models.pix2pix_model import Pix2PixModel
from collections import OrderedDict
import data
from util.iter_counter import IterationCounter
from util.visualizer import Visualizer
from trainers.pix2pix_trainer import Pix2PixTrainer
import os

opt = TrainOptions()
opt.D_steps_per_G = 1
opt.aspect_ratio = 1.0
opt.batchSize = 8
opt.beta1 = 0.0                           
opt.beta2 = 0.9
opt.cache_filelist_read = False                         
opt.cache_filelist_write = False                         
opt.checkpoints_dir = '/spell/checkpoints/'
opt.contain_dontcare_label = True                          
opt.continue_train = False
opt.crop_size = 256                           
opt.dataroot = '/spell/adek20k'  # data mount point
opt.dataset_mode = "ade20k"
opt.debug = False                         
opt.display_freq = 100                           
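# Per the note at the top of this example, the VAE/style-encoder branch stays
# off. A minimal sketch of how that would be expressed with SPADE's options
# (the `use_vae` attribute is assumed from SPADE's option parser and is not
# part of the original, truncated snippet):
opt.use_vae = False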
Example #2
        python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
    Train a pix2pix model:
        python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA

See options/base_options.py and options/train_options.py for more training options.
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer

if __name__ == '__main__':
    opt = TrainOptions().parse()  # get training options
    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    model = create_model(
        opt)  # create a model given opt.model and other options
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(
        opt)  # create a visualizer that display/save images and plots
    total_iters = 0  # the total number of training iterations

    for epoch in range(
            opt.epoch_count, opt.niter + opt.niter_decay + 1
Example #3
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import time
from collections import OrderedDict
from options.train_options import TrainOptions
from data.data_loader import CreateFaceConDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
import os
import numpy as np
import torch
from torch.autograd import Variable
import tensorboardX
import random

opt = TrainOptions().parse()
opt_test = TrainOptions().parse()
opt_test.phase = 'val'
opt_test.nThreads = 1
opt_test.batchSize = 1
opt_test.serial_batches = False
opt_test.no_flip = True
data_loader_test = CreateFaceConDataLoader(opt_test)
dataset_test_ = data_loader_test.load_data()
dataset_test = dataset_test_.dataset
'''
for i, data in enumerate(dataset_test_):
    print(i)
    dataset_test.append(data)
    if (i > 10000):
        break
Example #4
import os
import numpy as np
import torch
from torch.autograd import Variable
from collections import OrderedDict
from subprocess import call
import math
def lcm(a, b): return abs(a * b) // math.gcd(a, b) if a and b else 0
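# For illustration with assumed default values: lcm(100, 8) is 800 // 4 == 200,
# so a print_freq of 100 and a batchSize of 8 round print_freq up to 200 below.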

from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer

opt = TrainOptions().parse()
iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
if opt.continue_train:
    try:
        start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
    except:
        start_epoch, epoch_iter = 1, 0
    print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))        
else:    
    start_epoch, epoch_iter = 1, 0
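# For reference (illustrative values): iter.txt is expected to hold a single
# comma-separated "epoch,iteration" pair such as "12,3400", which the
# np.loadtxt call above parses into the two integers.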

opt.print_freq = lcm(opt.print_freq, opt.batchSize)    
if opt.debug:
    opt.display_freq = 1
    opt.print_freq = 1
    opt.niter = 1
Example #5
import time
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
from util.laplotter import LossAccPlotter
import pdb

if __name__ == '__main__':
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    # validate
    opt.phase = 'val'
    validation_loader = CreateDataLoader(opt)
    validation_dataset = validation_loader.load_data()
    validation_dataset_size = len(validation_loader)
    print('validate images = %d' % validation_dataset_size)

    model = create_model(opt)  # has been initialized
    visualizer = Visualizer(opt)
    total_steps = 0

    plotter = LossAccPlotter(save_to_filepath='./checkpoints/nyud_fcrn/')

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
Example #6
def main():
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    # read pix2pix/PAN model
    if opt.model == 'pix2pix':
        assert (opt.dataset_mode == 'aligned')
        from models.pix2pix_model import Pix2PixModel
        model = Pix2PixModel()
        model.initialize(opt)
    elif opt.model == 'pan':
        from models.pan_model import PanModel
        model = PanModel()
        model.initialize(opt)

    total_steps = 0

    batch_size = opt.batchSize
    print_freq = opt.print_freq
    epoch_count = opt.epoch_count
    niter = opt.niter
    niter_decay = opt.niter_decay
    display_freq = opt.display_freq
    save_latest_freq = opt.save_latest_freq
    save_epoch_freq = opt.save_epoch_freq

    for epoch in range(epoch_count, niter + niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0

        for i, data in enumerate(dataset):
            # data --> (1, 3, 256, 256)
            iter_start_time = time.time()
            total_steps += batch_size
            epoch_iter += batch_size
            model.set_input(data)
            model.optimize_parameters()

            if total_steps % print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / batch_size

                message = '(epoch: %d, iters: %d, time: %.3f) ' % (
                    epoch, epoch_iter, t)
                for k, v in errors.items():
                    message += '%s: %.3f ' % (k, v)
                print(message)

            # save latest weights
            if total_steps % save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

        # save weights periodically
        if epoch % save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, niter + niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()
import time
from collections import OrderedDict
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.mapping_model import Pix2PixHDModel_Mapping
import util.util as util
from util.visualizer import Visualizer
import os
import numpy as np
import torch
import torchvision.utils as vutils
from torch.autograd import Variable
import datetime
import random

opt = TrainOptions().parse()
visualizer = Visualizer(opt)
iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
if opt.continue_train:
    try:
        start_epoch, epoch_iter = np.loadtxt(iter_path,
                                             delimiter=',',
                                             dtype=int)
    except:
        start_epoch, epoch_iter = 1, 0
    visualizer.print_save('Resuming from epoch %d at iteration %d' %
                          (start_epoch - 1, epoch_iter))
else:
    start_epoch, epoch_iter = 1, 0

if opt.which_epoch != "latest":
from models import create_model

import torch
import torchvision
import torchvision.transforms as transforms

from util import util
import numpy as np
import progressbar as pb
import shutil

import datetime as dt
import matplotlib.pyplot as plt

if __name__ == '__main__':
    opt = TrainOptions().parse()
    opt.load_model = True
    opt.num_threads = 1  # test code only supports num_threads = 1
    opt.batch_size = 1  # test code only supports batch_size = 1
    opt.display_id = -1  # no visdom display
    opt.phase = 'test'
    opt.dataroot = './dataset/ilsvrc2012/%s/' % opt.phase
    opt.loadSize = 256
    opt.how_many = 5
    opt.aspect_ratio = 1.0
    opt.sample_Ps = [6]
    opt.load_model = True

    # number of random points to assign
Example #9
import time
from options.train_options import TrainOptions
from dataset import image_data
from models import create_model

import shutil
import os
import torch
from torch.utils.data import random_split, DataLoader
import numpy as np
from tqdm import tqdm

if __name__ == '__main__':
    trainoptions = TrainOptions()
    opt = trainoptions.parse()

    # setup
    os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_ids)
    projroot = os.path.join(os.getcwd(), opt.study)
    modelroot = os.path.join(projroot, 'save_models')
    dataroot = os.path.join(projroot, 'data')

    modelsubfolder = '{}_{}'.format(opt.study, opt.model)
    modelfolder = os.path.join(modelroot, modelsubfolder)
    opt.name = opt.name + '_' + modelsubfolder

    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)

    opt.input_nc, opt.output_nc = 1, 1
Example #10
# By Samaneh Azadi
################################################################################

from torch import transpose
import torch.utils.data as data
from torch import index_select, LongTensor
from PIL import Image
import os
import os.path
import numpy as np
from scipy import misc
import random
from options.train_options import TrainOptions
import torch

opt = TrainOptions().parse()  # set CUDA_VISIBLE_DEVICES before import torch

IMG_EXTENSIONS = ['.png']


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
Example #11
            
        visuals = OrderedDict([('synthesized_image', trainer.get_latest_generated()),
                                   ('real_image', data_i['image'])])
        visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far)

        if rank == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if (epoch % opt.save_epoch_freq == 0 or epoch == iter_counter.total_epochs) and (rank == 0):
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save(epoch)
    
    print('Training was successfully finished.')
    
if __name__ == '__main__':
    global TrainOptions
    TrainOptions = TrainOptions()
    opt = TrainOptions.parse(save=True)
    opt.world_size = opt.num_gpu
    opt.mpdist = True

    mp.set_start_method('spawn', force=True)
    mp.spawn(main_worker, nprocs=opt.world_size, args=(opt.world_size, opt))
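    # torch.multiprocessing.spawn prepends the process index, so each worker
    # here is invoked as main_worker(rank, opt.world_size, opt).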
Example #12
python train.py --name_prefix demo --dataname RIMEScharH32W16 --capitalize --display_port 8192

See options/base_options.py and options/train_options.py for more training options.
"""
import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
from util.util import seed_rng
from util.util import prepare_z_y, get_curr_data
import torch
import os

if __name__ == '__main__':
    opt = TrainOptions().parse()  # get training options
    # Seed RNG
    seed_rng(opt.seed)
    torch.backends.cudnn.benchmark = True
    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    model = create_model(
        opt)  # create a model given opt.model and other options
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    if opt.single_writer:
        opt.G_init = 'N02'
        opt.D_init = 'N02'
Example #13
			print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
			model.save('latest')
			model.save(epoch)

		print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

		if epoch > opt.niter:
			model.update_learning_rate()


if __name__ == '__main__':
	freeze_support()

	# python train.py --dataroot /.path_to_your_data --learn_residual --resize_or_crop crop --fineSize CROP_SIZE (we used 256)

	opt = TrainOptions().parse()
	opt.dataroot = 'D:\Photos\TrainingData\BlurredSharp\combined'
	opt.learn_residual = True
	opt.resize_or_crop = "crop"
	opt.fineSize = 256
	opt.gan_type = "gan"
	# opt.which_model_netG = "unet_256"

	# default = 5000
	opt.save_latest_freq = 100

	# default = 100
	opt.print_freq = 20

	data_loader = CreateDataLoader(opt)
	model = create_model(opt)
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import os
import util.util as util
from torch.autograd import Variable
import torch.nn as nn

opt = TrainOptions().parse()
opt.nThreads = 1
opt.batchSize = 1 
opt.serial_batches = True 
opt.no_flip = True
opt.instance_feat = True

name = 'features'
save_path = os.path.join(opt.checkpoints_dir, opt.name)

############ Initialize #########
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
model = create_model(opt)
util.mkdirs(os.path.join(opt.dataroot, opt.phase + '_feat'))

######## Save precomputed feature maps for 1024p training #######
for i, data in enumerate(dataset):
	print('%d / %d images' % (i+1, dataset_size)) 
	feat_map = model.module.netE.forward(Variable(data['image'].cuda(), volatile=True), data['inst'].cuda())
Example #15
import os
import sys
from collections import OrderedDict
from options.train_options import TrainOptions
import data
from data.base_dataset import repair_data
from util.iter_counter import IterationCounter
from util.visualizer import Visualizer
from trainers.pix2pix_trainer import Pix2PixTrainer
from options.test_options import TestOptions
import tqdm
from util import html
from util.util import tensor2im, tensor2label

# parse options
opt = TrainOptions().parse()

# print options to help debugging
print(' '.join(sys.argv))

# load the dataset
dataloader = data.create_dataloader(opt)

# create trainer for our model
trainer = Pix2PixTrainer(opt)

# create tool for counting iterations
iter_counter = IterationCounter(opt, len(dataloader))

# create tool for visualization
visualizer = Visualizer(opt)
Example #16
def main():
    opt = TrainOptions().parse() 
    train_history = TrainHistory()
    checkpoint = Checkpoint(opt)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    """Architecture"""
    net = MakeLinearModel(1024,16)
    net = torch.nn.DataParallel(net).cuda()

    """Uploading Mean and SD"""
    path_to_data = '.../multi-view-pose-estimation/dataset/'

    #mean and sd of 2d poses in training dataset
    Mean_2D = np.loadtxt(path_to_data + 'Mean_2D.txt')
    Mean_2D = Mean_2D.astype('float32')
    Mean_2D = torch.from_numpy (Mean_2D)

    Mean_Delta = np.loadtxt(path_to_data + 'Mean_Delta.txt')
    Mean_Delta = Mean_Delta.astype('float32')
    Mean_Delta = torch.from_numpy (Mean_Delta)

    SD_2D = np.loadtxt(path_to_data + 'SD_2D.txt')
    SD_2D = SD_2D.astype('float32')
    SD_2D = torch.from_numpy (SD_2D)

    SD_Delta = np.loadtxt(path_to_data + 'SD_Delta.txt')
    SD_Delta = SD_Delta.astype('float32')
    SD_Delta = torch.from_numpy (SD_Delta)

    """Loading Data"""
    train_list = 'train_list.txt' 
    train_loader = torch.utils.data.DataLoader(
        data.PtsList(train_list, is_train=True),
        batch_size=opt.bs, shuffle=True,
        num_workers=opt.nThreads, pin_memory=True)

    val_list = 'valid_list.txt'
    val_loader = torch.utils.data.DataLoader(
        data.PtsList(val_list,is_train=False ),
        batch_size=opt.bs, shuffle=False,
        num_workers=opt.nThreads, pin_memory=True)


    """optimizer"""
    optimizer = torch.optim.Adam( net.parameters(), lr=opt.lr, betas=(0.9,0.999), weight_decay=0)

    """training and validation"""
    for epoch in range(0, opt.nEpochs):
        adjust_learning_rate(optimizer, epoch, opt.lr)

        # train for one epoch
        train_loss = train(train_loader, net, Mean_2D, Mean_Delta, SD_2D, SD_Delta, optimizer, epoch, opt)

        # evaluate on validation set
        val_loss,val_pckh = validate(val_loader, net, Mean_2D, Mean_Delta, SD_2D, SD_Delta, epoch, opt)

        # update training history
        e = OrderedDict( [('epoch', epoch)] )
        lr = OrderedDict( [('lr', opt.lr)] )
        loss = OrderedDict( [('train_loss', train_loss),('val_loss', val_loss)] )
        pckh = OrderedDict( [('val_pckh', val_pckh)] )
        train_history.update(e, lr, loss, pckh)
        checkpoint.save_checkpoint(net, train_history, 'best-single.pth.tar')
Example #17
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import numpy as np
import os

opt = TrainOptions().parse()
opt.nThreads = 1
opt.batchSize = 1
opt.serial_batches = True
opt.no_flip = True
opt.instance_feat = True

name = 'features'
save_path = os.path.join(opt.checkpoints_dir, opt.name)

############ Initialize #########
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
model = create_model(opt)

########### Encode features ###########
reencode = True
if reencode:
    features = {}
    for label in range(opt.label_nc):
        features[label] = np.zeros((0, opt.feat_num + 1))
    for i, data in enumerate(dataset):
"""
import time

from datasets.get_cityscapes import get_loaders_cityscapes
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.evaluation import evaluate
from util.transforms import joint_transform_train
from util.util import calculate_class_weights
from util.visualizer import Visualizer
from pathlib import Path
import torch

if __name__ == '__main__':
    opt = TrainOptions().parse()  # get training options
    # dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    if opt.run_colab:
        root_path = Path("/content/data/")
    else:
        root_path = Path(
            "/Users/laurenssamson/Documents/Projects/data/Cityscapes/leftImg8bit_trainvaltest"
        )
    training_loader, val_loader, train_eval_loader = get_loaders_cityscapes(
        root_path, opt)
    opt.class_weights = calculate_class_weights(val_loader, opt.output_nc)
    model = create_model(
        opt)  # create a model given opt.model and other options
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(
Example #19
def train_main(raw_args=None):
    # print(torch.backends.cudnn.benchmark)
    opt = TrainOptions().parse(raw_args)  # get training options
    if opt.debug_mode:
        import multiprocessing
        multiprocessing.set_start_method('spawn', True)
        opt.num_threads = 0

    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    existing_epochs = glob.glob(opt.checkpoints_dir + '/' + opt.name +
                                '/*[0-9]_net_G_A.pth')
    if opt.restart_training and len(existing_epochs) > 0:
        opt.epoch = int(
            os.path.splitext(os.path.basename(
                existing_epochs[-1]))[0].split('_')[0])
        opt.epoch_count = opt.epoch + 1
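        # e.g. if the last matching checkpoint is '35_net_G_A.pth', training
        # resumes with opt.epoch = 35 and opt.epoch_count = 36
        # (illustrative filename, not from the original run)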

    plot_losses_from_log_files(opt,
                               dataset_size,
                               domain=['A', 'B'],
                               specified=['G', 'D', 'cycle'])

    model = create_model(
        opt)  # create a model given opt.model and other options
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(
        opt)  # create a visualizer that display/save images and plots
    total_iters = 0  # the total number of training iterations

    for epoch in range(
            opt.epoch_count, opt.niter + opt.niter_decay + 1
    ):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch

        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time(
            )  # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(
                data)  # unpack data from dataset and apply preprocessing
            model.optimize_parameters(
            )  # calculate loss functions, get gradients, update network weights

            if total_iters % opt.display_freq == 0:  # display images on visdom and save images to a HTML file
                save_result = total_iters % opt.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(),
                                                   epoch, save_result)

            if total_iters % opt.print_freq == 0:  # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses,
                                                t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(
                        epoch,
                        float(epoch_iter) / dataset_size, losses)

            if total_iters % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' %
                      (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            iter_data_time = time.time()
        if epoch % opt.save_epoch_freq == 0:  # cache our model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))
        model.update_learning_rate(
        )  # update learning rates at the end of every epoch.
# Generates a ton of images and ranks them in order of
# discriminator loss


def plotTensor(im):
    im = util.tensor2im(im)
    plotim(im)


def plotim(im):
    plt.imshow(im)
    plt.show()


if __name__ == '__main__':
    opt = TrainOptions().parse()
    opt.no_flip = True
    opt.resize_or_crop = 'none'
    opt.dataset_mode = 'auto'

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)

    dirname = '12_nol1'

    model = create_model(opt)
    model.setup(opt)
    total_steps = 0

    chkpt_D = torch.load('checkpoints/sv_nlayers5_ranker/2_net_D.pth')
Example #21
def main():
    plt.ion()  # enable interactive mode so plots can be updated continuously
    opt = TrainOptions().parse()
    # device used for computation: CPU or GPU
    device = torch.device("cuda" if USE_CUDA else "cpu")

    # define the discriminator and generator networks
    #net_d = NLayerDiscriminator(opt.output_nc, opt.ndf, n_layers=3)#batchnorm
    #net_d = Discriminator(opt.output_nc)
    net_d_ct =networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)
    net_d_dr =networks.define_D(opt.input_nc, opt.ndf, 'ProjNet',
                                             opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)
    net_g_dr=networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)

    #net_g = CTGenerator(opt.input_nc, opt.output_nc, opt.ngf, n_blocks=6)
    net_g_ct = networks.define_G(1, 65, opt.ngf, 'CTnet', opt.norm,
                      not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)
    # init_weights(net_d_dr)
    # init_weights(net_d_ct)
    # init_weights(net_g_dr)
    # init_weights(net_g_ct)
    net_d_ct.to(device)
    net_d_dr.to(device)
    net_g_dr.to(device)
    net_g_ct.to(device)

    one = torch.FloatTensor([1])
    mone = one * -1
    one = one.to(device)
    mone= mone.to(device)
    #summary(net_g_dr, (2,65, 65,65))
    if load_net:
        # save_filename = 'net_d%s.pth' % epoch_start
        # save_path = os.path.join('./check/', save_filename)
        # load_network(net_d, save_path)
        save_filename = 'net_g%s.pth' % epoch_start
        save_path = os.path.join('./check/', save_filename)
        load_network(net_g_ct, save_path)
    # loss functions
    #criterion = nn.BCELoss().to(device)
    criterion = nn.MSELoss().to(device)
    criterion1 = nn.L1Loss().to(device)

    # optimizers
    optimizer_d = torch.optim.Adam(itertools.chain(net_d_ct.parameters(),net_d_dr.parameters()), lr=0.0001,betas=[0.5,0.9])
    optimizer_g = torch.optim.Adam(itertools.chain(net_g_ct.parameters(),net_g_dr.parameters()), lr=0.0001,betas=[0.5,0.9])

    #optimizer_d = torch.optim.AdamW(net_d.parameters(), lr=0.0001)
    #optimizer_g = torch.optim.AdamW(net_g.parameters(), lr=0.0001)
    #one = torch.FloatTensor([1]).cuda()
    #mone = one * -1
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    def gensample():
        for image in enumerate(dataset):
            yield image
    gen=gensample()
    ii=0

    for epoch in range(MAX_EPOCH):
        # add noise to the real data
        for it,data in enumerate(dataset):


            # load the data
            dr_real = autograd.Variable(data['A'].cuda())
            dr_real=dr_real.squeeze(0)
            ct_real = autograd.Variable(data['B'].cuda())
            ct_real=ct_real.squeeze(0)
            # training
            # inner loop
            freeze_params(net_g_ct)
            freeze_params(net_g_dr)
            unfreeze_params(net_d_ct)
            unfreeze_params(net_d_dr)

            ct_fake = autograd.Variable(net_g_ct(dr_real).data)
            dr_fake = autograd.Variable(net_g_dr(ct_real).data)

            optimizer_d.zero_grad()
            loss_dsc_realct = net_d_ct(ct_real).mean()
            #loss_dsc_realct.backward()
            loss_dsc_fakect = net_d_ct(ct_fake.detach()).mean()
            #loss_dsc_fakect.backward()
            gradient_penalty_ct = calc_gradient_penalty(net_d_ct, ct_real, ct_fake)
            #gradient_penalty_ct.backward()
            loss_d_ct=loss_dsc_fakect - loss_dsc_realct+gradient_penalty_ct
            loss_d_ct.backward()
            Wd_ct=loss_dsc_realct-loss_dsc_fakect

            loss_dsc_realdr = net_d_dr(dr_real).mean()
            #loss_dsc_realdr.backward()
            loss_dsc_fakedr = net_d_dr(dr_fake.detach()).mean()
            #loss_dsc_fakedr.backward()
            gradient_penalty_dr = calc_gradient_penalty(net_d_dr, dr_real, dr_fake)
            #gradient_penalty_dr.backward()
            loss_d_dr = loss_dsc_fakedr - loss_dsc_realdr + gradient_penalty_dr
            loss_d_dr.backward()
            Wd_dr = loss_dsc_realdr - loss_dsc_fakedr
            optimizer_d.step()
            if it%CRITIC_ITERS==0:
            #if True:
                unfreeze_params(net_g_ct)
                freeze_params(net_d_ct)
                unfreeze_params(net_g_dr)
                freeze_params(net_d_dr)


                ct_fake_g=net_g_ct(dr_real)
                dr_fake_g=net_g_dr(ct_real)

                # outer loop: ct_dr
                # optimizer_g.zero_grad()
                loss_out_dr=criterion1(net_g_ct(dr_fake_g),ct_real)
                # loss_out_dr.backward()
                # optimizer_g.step()
                #net_g_ct.load_state_dict(dict_g_ct)

                # outer loop: dr_ct
                # optimizer_g.zero_grad()
                loss_out_ct=criterion1(net_g_dr(ct_fake_g),dr_real)
                # loss_out_ct.backward()
                # optimizer_g.step()

                # inner loop: gan
                loss_g_ct = - net_d_ct(ct_fake_g).mean()
                #loss_g_ct.backward()
                loss_g_dr = - net_d_dr(dr_fake_g).mean()
                #loss_g_dr.backward()
                loss_gan = loss_out_dr + loss_out_ct
                #loss_gan=loss_out_dr+loss_out_ct+loss_g_ct+loss_g_dr
                #loss_gan = criterion(net_g_ct(dr_real), ct_real) + criterion(net_g_dr(ct_real), dr_real)
                optimizer_g.zero_grad()
                loss_gan.backward()
                optimizer_g.step()

            if it%1==0:
                fk_im=toimage(torch.irfft(torch.roll(torch.roll(ct_fake,-32,2),-32,3).permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))

                #fk_im=toimage(ct_fake[0,32,:,:].unsqueeze(0))
                 #img_test.append(fk_im)
                save_filenamet = 'fakect%s.bmp' % int(epoch/dataset_size)
                img_path = os.path.join('./check/img/', save_filenamet)
                save_image(fk_im, img_path)

                rel_im=toimage(torch.irfft(torch.roll(torch.roll(ct_real,-32,2),-32,3).permute(1, 2, 3, 0), 2, onesided=False)[32, :, :].unsqueeze(0))
                #rel_im = toimage(ct_real[0,32, :, :].unsqueeze(0))
                # img_test.append(rel_im)
                save_image(rel_im, os.path.join('./check/img/', 'Realct%s.bmp' % int(epoch)))

                fake_im = toimage(torch.irfft(torch.roll(torch.roll(dr_fake, -128, 2), -128, 3).permute(1, 2, 3, 0), 2, onesided=False))
                #fake_im =toimage(dr_fake.squeeze(0))
                save_image(fake_im, os.path.join('./check/img/', 'fakedr%s.bmp' % int(epoch)))
                ceshi(net_g_ct)
                message = '(epoch: %d, iters: %d, D_ct: %.3f;[real:%.3f;fake:%.3f], G_ct: %.3f, D_dr: %.3f, G_dr: %.3f) ' % (int(epoch), ii,loss_d_ct,loss_dsc_realct,loss_dsc_fakect,loss_g_ct,loss_d_dr,loss_g_dr)
                print(message)

        save_filename = 'net_g%s.pth' % epoch
        save_path = os.path.join('./check/', save_filename)
        torch.save(net_g_ct.cpu().state_dict(), save_path)
        net_g_ct.cuda(0)
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
from util import util
from math import log10
from util.visualizer import save_sr_result
import time
import os
import pytorch_ssim

if __name__ == '__main__':
    # load the options
    opt = TrainOptions().parse()
    opt.upscale_factor = 4

    # set up where validation results are stored for display
    web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'val')
    image_dir = os.path.join(web_dir, 'images')
    util.mkdirs([web_dir, image_dir])

    # load the training dataset
    dataset_train = create_dataset(opt)
    dataset_train_size = len(dataset_train)
    print('The number of training images = %d' % dataset_train_size)

    # load the validation dataset
    opt1 = TrainOptions().parse()
    opt1.upscale_factor = 4
    opt1.phase = "val"
    opt1.batch_size = 1
Example #23
from data.data_loader import CreateDataLoader
import util.util as util
from util.visualizer import Visualizer
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import scipy.io as sio

# Set random seed
torch.manual_seed(0)
np.random.seed(0)

# Extract the options
opt = TrainOptions().parse()

# For testing  the neural networks, manually edit/add options below
opt.gan_mode = 'none'  # 'wgangp', 'lsgan', 'vanilla', 'none'

# Set the input dataset
opt.dataset_mode = 'CIFAR10'  # Current dataset:  CIFAR10, CelebA

if opt.dataset_mode in ['CIFAR10', 'CIFAR100']:
    opt.n_layers_D = 3
    opt.label_smooth = 1  # Label smoothing factor (for lsgan and vanilla gan only)
    opt.n_downsample = 2  # Downsample times
    opt.n_blocks = 2  # Number of residual blocks
    opt.first_kernel = 5  # The filter size of the first convolutional layer in encoder
    opt.batchsize = 128
    opt.n_epochs = 200  # # of epochs without lr decay
Example #24
def main():
    cfg = TrainOptions().parse()  # get training options
    cfg.NUM_GPUS = torch.cuda.device_count()
    cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
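    # e.g. an overall batch_size of 64 on a 4-GPU machine becomes 16 per GPU
    # (illustrative numbers)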
    cfg.phase = 'train'
    launch_job(cfg=cfg, init_method=cfg.init_method, func=train)
Example #25
    def __len__(self):
        return len(self.path_to_images)

    def __getitem__(self, idx):
        idx = idx.tolist() if torch.is_tensor(idx) else idx

        img_name = self.path_to_images[idx]
        image = Image.open(img_name)  # assumes `from PIL import Image`

        sample = image
        if self.transform:
            sample = self.transform(sample)
        return sample


if __name__ == '__main__':
    opt = TrainOptions().parse()

    opt.dataroot = './dataset/ilsvrc2012/%s/' % opt.phase
    colorization_dataset = Colorization_Dataset(opt.dataroot,
                                               transform=transforms.Compose([
                                                   transforms.RandomChoice([transforms.Resize(opt.loadSize, interpolation=1),
                                                                            transforms.Resize(opt.loadSize, interpolation=2),
                                                                            transforms.Resize(opt.loadSize, interpolation=3),
                                                                            transforms.Resize((opt.loadSize, opt.loadSize), interpolation=1),
                                                                            transforms.Resize((opt.loadSize, opt.loadSize), interpolation=2),
                                                                            transforms.Resize((opt.loadSize, opt.loadSize), interpolation=3)]),
                                                   transforms.RandomChoice([transforms.RandomResizedCrop(opt.fineSize, interpolation=1),
                                                                            transforms.RandomResizedCrop(opt.fineSize, interpolation=2),
                                                                            transforms.RandomResizedCrop(opt.fineSize, interpolation=3)]),
                                                   transforms.RandomHorizontalFlip(),
                                                   transforms.ToTensor()]))
Example #26
# ========================================================
# Compositional GAN
# Train different components of the paired/unpaired models
# By Samaneh Azadi
# ========================================================

import time
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer

opt = TrainOptions().parse()
opt.phase = 'train'
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)

print('#training images = %d' % dataset_size)
model = create_model(opt)
visualizer = Visualizer(opt)

print('Train the STN models')
# Only for the unpaired case
if opt.dataset_mode == 'comp_decomp_unaligned' and opt.niterSTN:
    opt.isPretrain = False
    visualizer = Visualizer(opt)
    total_steps = 0
    for epoch in range(opt.epoch_count, opt.niterSTN + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
Example #27
See options/base_options.py and options/train_options.py for more training options.
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import time
from options.train_options import TrainOptions
from options.val_options import ValOptions
from data import create_dataset
from models import create_model
from util.visualizer_sssd import Visualizer, save_images, cal_scores
from copy import deepcopy
import os
from util import html

if __name__ == '__main__':
    opt = TrainOptions().parse()  # get training options
    # opt.sample_nums = 100
    # opt.dataset_mode = "alignedm2md"
    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    model = create_model(
        opt)  # create a model given opt.model and other options
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(
        opt)  # create a visualizer that display/save images and plots
    total_iters = 0  # the total number of training iterations
Example #28
import time
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model


opt = TrainOptions().parse()
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)

model = create_model(opt)
total_steps = 0

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    epoch_iter = 0

    for i, data in enumerate(dataset):
        iter_start_time = time.time()

        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
        model.set_input(data)       
        model.optimize_parameters()

        
        if total_steps % opt.print_freq == 0:            
            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
Example #29
def main():

    opt = TrainOptions()
    args = opt.initialize()

    _t = {'iter time': Timer()}

    model_name = args.source + '_to_' + args.target
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
        os.makedirs(os.path.join(args.snapshot_dir, 'logs'))
    opt.print_options(args)

    sourceloader, targetloader = CreateSrcDataLoader(
        args), CreateTrgDataLoader(args)
    targetloader_iter, sourceloader_iter = iter(targetloader), iter(
        sourceloader)

    model, optimizer = CreateModel(args)
    model_D, optimizer_D = CreateDiscriminator(args)

    start_iter = 0
    if args.restore_from is not None:
        start_iter = int(args.restore_from.rsplit('/', 1)[1].rsplit('_')[1])

    train_writer = tensorboardX.SummaryWriter(
        os.path.join(args.snapshot_dir, "logs", model_name))

    bce_loss = torch.nn.BCEWithLogitsLoss()

    cudnn.enabled = True
    cudnn.benchmark = True
    model.train()
    model.cuda()
    model_D.train()
    model_D.cuda()
    loss = [
        'loss_seg_src', 'loss_seg_trg', 'loss_D_trg_fake', 'loss_D_src_real',
        'loss_D_trg_real'
    ]
    _t['iter time'].tic()
    for i in range(start_iter, args.num_steps):

        model.adjust_learning_rate(args, optimizer, i)
        model_D.adjust_learning_rate(args, optimizer_D, i)

        optimizer.zero_grad()
        optimizer_D.zero_grad()
        for param in model_D.parameters():
            param.requires_grad = False

        src_img, src_lbl, _, _ = sourceloader_iter.next()
        src_img, src_lbl = Variable(src_img).cuda(), Variable(
            src_lbl.long()).cuda()
        src_seg_score = model(src_img, lbl=src_lbl)
        loss_seg_src = model.loss
        loss_seg_src.backward()

        if args.data_label_folder_target is not None:
            trg_img, trg_lbl, _, _ = targetloader_iter.next()
            trg_img, trg_lbl = Variable(trg_img).cuda(), Variable(
                trg_lbl.long()).cuda()
            trg_seg_score = model(trg_img, lbl=trg_lbl)
            loss_seg_trg = model.loss
        else:
            trg_img, _, name = targetloader_iter.next()
            trg_img = Variable(trg_img).cuda()
            trg_seg_score = model(trg_img)
            loss_seg_trg = 0

        outD_trg = model_D(F.softmax(trg_seg_score), 0)
        loss_D_trg_fake = model_D.loss

        loss_trg = args.lambda_adv_target * loss_D_trg_fake + loss_seg_trg
        loss_trg.backward()

        for param in model_D.parameters():
            param.requires_grad = True

        src_seg_score, trg_seg_score = src_seg_score.detach(
        ), trg_seg_score.detach()

        outD_src = model_D(F.softmax(src_seg_score), 0)
        loss_D_src_real = model_D.loss / 2
        loss_D_src_real.backward()

        outD_trg = model_D(F.softmax(trg_seg_score), 1)
        loss_D_trg_real = model_D.loss / 2
        loss_D_trg_real.backward()

        optimizer.step()
        optimizer_D.step()

        for m in loss:
            train_writer.add_scalar(m, eval(m), i + 1)

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.snapshot_dir,
                             '%s_' % (args.source) + str(i + 1) + '.pth'))

        if (i + 1) % args.print_freq == 0:
            _t['iter time'].toc(average=False)
            print('[it %d][src seg loss %.4f][lr %.4f][%.2fs]' % \
                    (i + 1, loss_seg_src.data, optimizer.param_groups[0]['lr']*10000, _t['iter time'].diff))
            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            _t['iter time'].tic()
Example #30
import time
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import ntpath
import os
from util import util
import shutil
import numpy as np
#from util.visualizer import Visualizer

opt = TrainOptions().parse()


# Load dataset_train
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)

# Load dataset_test
opt.batchSize = 1
opt.phase = 'test'
data_loader_test = CreateDataLoader(opt)
dataset_test = data_loader_test.load_data()
dataset_size = len(data_loader_test)
print('#testing images = %d' % dataset_size)

# Create or clear dir for saving generated samples
if os.path.exists(opt.testing_path):
    shutil.rmtree(opt.testing_path)
Example #31
def main():
    opt = TrainOptions().parse()

    # Determine validation step options that might differ from training
    if opt.data == 'KTH':
        val_pick_mode = 'Slide'
        val_gpu_ids = [opt.gpu_ids[0]]
        val_batch_size = 1
    elif opt.data in ['UCF', 'HMDB51', 'S1M']:
        val_pick_mode = 'First'
        val_gpu_ids = opt.gpu_ids
        val_batch_size = opt.batch_size // 2
    else:
        raise ValueError('Dataset [%s] not recognized.' % opt.data)

    expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
    makedir(expr_dir)
    tb_dir = os.path.join(opt.tensorboard_dir, opt.name)
    makedir(tb_dir)

    file_name = os.path.join(expr_dir, 'train_opt.txt')
    with open(file_name, 'wt') as opt_file:
        listopt(opt, opt_file)

    log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
    print('after reading options')
    include_following = (opt.model_type != 'mcnet')
    data_loader = CustomDataLoader(opt.data, opt.c_dim, opt.dataroot,
                                   opt.textroot, opt.video_list, opt.K, opt.T,
                                   opt.backwards, opt.flip, opt.pick_mode,
                                   opt.image_size, include_following, opt.skip,
                                   opt.F, opt.batch_size, opt.serial_batches,
                                   opt.nThreads)
    print(data_loader.name())
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('# training videos = %d' % dataset_size)

    env = create_environment(
        opt.model_type, opt.gf_dim, opt.c_dim, opt.gpu_ids, True,
        opt.checkpoints_dir, opt.name, opt.K, opt.T, opt.F, opt.image_size,
        opt.batch_size, opt.which_update, opt.comb_type, opt.shallow, opt.ks,
        opt.num_block, opt.layers, opt.kf_dim, opt.enable_res, opt.rc_loc,
        opt.no_adversarial, opt.alpha, opt.beta, opt.D_G_switch, opt.margin,
        opt.lr, opt.beta1, opt.sn, opt.df_dim, opt.Ip, opt.continue_train,
        opt.comb_loss)

    total_updates = env.start_update
    writer = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    while True:
        for data in dataset:
            iter_start_time = time.time()

            # Enable losses on intermediate and final predictions partway through training
            if total_updates >= opt.inter_sup_update:
                env.enable_inter_loss()
            if total_updates >= opt.final_sup_update:
                env.enable_final_loss()

            # Update model
            total_updates += 1
            env.set_inputs(data)
            env.optimize_parameters()

            if total_updates % opt.print_freq == 0:
                errors = env.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batch_size
                writer.add_scalar('iter_time', t, total_updates)
                for key in errors.keys():
                    writer.add_scalar('loss/%s' % (key), errors[key],
                                      total_updates)
                print_current_errors(log_name, total_updates, errors, t)

            if total_updates % opt.display_freq == 0:
                visuals = env.get_current_visuals()
                grid = visual_grid(visuals, opt.K, opt.T)
                writer.add_image('current_batch', grid, total_updates)

            if total_updates % opt.save_latest_freq == 0:
                print('saving the latest model (update %d)' % total_updates)
                env.save('latest', total_updates)
                env.save(total_updates, total_updates)

            if total_updates % opt.validate_freq == 0:
                psnr_plot, ssim_plot, grid = val(
                    opt.c_dim, opt.data, opt.T * 2, opt.dataroot, opt.textroot,
                    'val_data_list.txt', opt.K, opt.backwards, opt.flip,
                    val_pick_mode, opt.image_size, val_gpu_ids, opt.model_type,
                    opt.skip, opt.F, val_batch_size, True, opt.nThreads,
                    opt.gf_dim, False, opt.checkpoints_dir, opt.name,
                    opt.no_adversarial, opt.alpha, opt.beta, opt.D_G_switch,
                    opt.margin, opt.lr, opt.beta1, opt.sn, opt.df_dim, opt.Ip,
                    opt.comb_type, opt.comb_loss, opt.shallow, opt.ks,
                    opt.num_block, opt.layers, opt.kf_dim, opt.enable_res,
                    opt.rc_loc, opt.continue_train, 'latest')
                writer.add_image('psnr', psnr_plot, total_updates)
                writer.add_image('ssim', ssim_plot, total_updates)
                writer.add_image('samples', grid, total_updates)

            if total_updates >= opt.max_iter:
                env.save('latest', total_updates)
                break

        if total_updates >= opt.max_iter:
            break
Example #32
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import time
from collections import OrderedDict
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
import os
import numpy as np
import torch
from torch.autograd import Variable

opt = TrainOptions().parse()
iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
if opt.continue_train:
    try:
        start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
    except:
        start_epoch, epoch_iter = 1, 0
    print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))        
else:    
    start_epoch, epoch_iter = 1, 0

if opt.debug:
    opt.display_freq = 1
    opt.print_freq = 1
    opt.niter = 1
    opt.niter_decay = 0
    opt.max_dataset_size = 10
def main():
    opt = TrainOptions().parse()
    if opt.sr_dir == '':
        print('sr directory is null.')
        exit()
    sr_pretrain_dir = os.path.join(
        opt.exp_dir, opt.exp_id, opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)
    train_history = ASNTrainHistory()
    # print(train_history.lr)
    # exit()
    checkpoint_agent = Checkpoint()
    visualizer = Visualizer(opt)
    visualizer.log_path = sr_pretrain_dir + '/' + 'log.txt'
    train_scale_path = sr_pretrain_dir + '/' + 'train_scales.txt'
    train_rotation_path = sr_pretrain_dir + '/' + 'train_rotations.txt'
    val_scale_path = sr_pretrain_dir + '/' + 'val_scales.txt'
    val_rotation_path = sr_pretrain_dir + '/' + 'val_rotations.txt'

    # with open(visualizer.log_path, 'a+') as log_file:
    #     log_file.write(opt.resume_prefix_pose + '.pth.tar\n')
    # lost_joint_count_path = os.path.join(opt.exp_dir, opt.exp_id, opt.astn_dir, 'joint-count.txt')
    # print("=> log saved to path '{}'".format(visualizer.log_path))
    # if opt.dataset == 'mpii':
    #     num_classes = 16
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    print('collecting training scale and rotation distributions ...\n')
    train_scale_distri = read_grnd_distri_from_txt(train_scale_path)
    train_rotation_distri = read_grnd_distri_from_txt(train_rotation_path)
    dataset = MPII('dataset/mpii-hr-lsp-normalizer.json',
                   '/bigdata1/zt53/data',
                   is_train=True,
                   grnd_scale_distri=train_scale_distri,
                   grnd_rotation_distri=train_rotation_distri)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opt.bs,
                                               shuffle=True,
                                               num_workers=opt.nThreads,
                                               pin_memory=True)

    print('collecting validation scale and rotation distributions ...\n')
    val_scale_distri = read_grnd_distri_from_txt(val_scale_path)
    val_rotation_distri = read_grnd_distri_from_txt(val_rotation_path)
    dataset = MPII('dataset/mpii-hr-lsp-normalizer.json',
                   '/bigdata1/zt53/data',
                   is_train=False,
                   grnd_scale_distri=val_scale_distri,
                   grnd_rotation_distri=val_rotation_distri)
    val_loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.bs,
                                             shuffle=False,
                                             num_workers=opt.nThreads,
                                             pin_memory=True)

    agent = model.create_asn(chan_in=256,
                             chan_out=256,
                             scale_num=len(dataset.scale_means),
                             rotation_num=len(dataset.rotation_means),
                             is_aug=True)
    agent = torch.nn.DataParallel(agent).cuda()
    optimizer = torch.optim.RMSprop(agent.parameters(),
                                    lr=opt.lr,
                                    alpha=0.99,
                                    eps=1e-8,
                                    momentum=0,
                                    weight_decay=0)
    # optimizer = torch.optim.Adam(agent.parameters(), lr=opt.agent_lr)
    if opt.load_prefix_sr == '':
        checkpoint_agent.save_prefix = sr_pretrain_dir + '/'
    else:
        checkpoint_agent.save_prefix = sr_pretrain_dir + '/' + opt.load_prefix_sr
        checkpoint_agent.load_prefix = checkpoint_agent.save_prefix[0:-1]
        checkpoint_agent.load_checkpoint(agent, optimizer, train_history)
        # adjust_lr(optimizer, opt.lr)
        # lost_joint_count_path = os.path.join(opt.exp_dir, opt.exp_id, opt.asdn_dir, 'joint-count-finetune.txt')
    print('agent: ', type(optimizer), optimizer.param_groups[0]['lr'])

    if opt.dataset == 'mpii':
        num_classes = 16
    hg = model.create_hg(num_stacks=2,
                         num_modules=1,
                         num_classes=num_classes,
                         chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        exit()
    checkpoint_hg = Checkpoint()
    # checkpoint_hg.save_prefix = os.path.join(opt.exp_dir, opt.exp_id, opt.resume_prefix_pose)
    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    logger = Logger(sr_pretrain_dir + '/' + 'training-summary.txt',
                    title='training-summary')
    logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss'])
    """training and validation"""
    start_epoch = 0
    if opt.load_prefix_sr != '':
        start_epoch = train_history.epoch[-1]['epoch'] + 1
    for epoch in range(start_epoch, opt.nEpochs):
        # train for one epoch
        train_loss = train(train_loader, hg, agent, optimizer, epoch,
                           visualizer, opt)
        val_loss = validate(val_loader, hg, agent, epoch, visualizer, opt)
        # update training history
        e = OrderedDict([('epoch', epoch)])
        lr = OrderedDict([('lr', optimizer.param_groups[0]['lr'])])
        loss = OrderedDict([('train_loss', train_loss),
                            ('val_loss', val_loss)])
        # pckh = OrderedDict( [('val_pckh', val_pckh)] )
        train_history.update(e, lr, loss)
        # print(train_history.lr[-1]['lr'])
        checkpoint_agent.save_checkpoint(agent,
                                         optimizer,
                                         train_history,
                                         is_asn=True)
        visualizer.plot_train_history(train_history, 'sr')
        logger.append(
            [epoch, optimizer.param_groups[0]['lr'], train_loss, val_loss])
    logger.close()