Example #1
def main():
    opt = TestOptions().parse(save=False)
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    # test
    if not opt.engine and not opt.onnx:
        model = create_model(opt)
        if opt.data_type == 16:
            model.half()
        elif opt.data_type == 8:
            model.type(torch.uint8)

        if opt.verbose:
            print(model)
    else:
        from run_engine import run_trt_engine, run_onnx

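    # iterate over the test set, stopping after opt.how_many images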
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        if opt.data_type == 16:
            data['label'] = data['label'].half()
            data['inst'] = data['inst'].half()
        elif opt.data_type == 8:
            data['label'] = data['label'].byte()  # cast to uint8
            data['inst'] = data['inst'].byte()
        if opt.export_onnx:
            print("Exporting to ONNX: ", opt.export_onnx)
            assert opt.export_onnx.endswith(
                "onnx"), "Export model file should end with .onnx"
            torch.onnx.export(model, [data['label'], data['inst']],
                              opt.export_onnx,
                              verbose=True)
            exit(0)
        minibatch = 1
        if opt.engine:
            generated = run_trt_engine(opt.engine, minibatch,
                                       [data['label'], data['inst']])
        elif opt.onnx:
            generated = run_onnx(opt.onnx, opt.data_type, minibatch,
                                 [data['label'], data['inst']])
        else:
            generated = model.inference(data['label'], data['inst'],
                                        data['image'])

        visuals = OrderedDict([
            ('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
            ('synthesized_image', util.tensor2im(generated.data[0]))
        ])
        img_path = data['path']
        print('process image... %s' % img_path)
        visualizer.save_images(webpage, visuals, img_path)

    webpage.save()
Example #2
num_batch = 5

# load options
opt = TestEncoderDecoderOptions_V2().parse()
batch_size = opt.batch_size  # assumed: batch_size is used below but never defined in this snippet
opt.batch_size = batch_size
train_opt = io.load_json(os.path.join('checkpoints', opt.id, 'train_opt.json'))
preserved_opt = {'gpu_ids', 'batch_size', 'is_train'}
for k, v in train_opt.items():
    if k in opt and (k not in preserved_opt):
        setattr(opt, k, v)

# create model
model = EncoderDecoderFramework_DFN()
model.initialize(opt)
# create data loader
val_loader_iter = iter(CreateDataLoader(opt, split='test'))
# create visualizer
visualizer = GANVisualizer_V3(opt)

for i in range(num_batch):
    print('[%s] test edge transfer: %d / %d' % (opt.id, i+1, num_batch))
    data = next(val_loader_iter)
    imgs_title = data['img'].expand(batch_size, 3, opt.fine_size, opt.fine_size).cpu()

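    # tile each tensor into a (batch_size x batch_size) grid of pairings:
    # new_2 transposes the grid so every sample is paired with every other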
    for name in ['img', 'seg_map', 'seg_mask', 'pose_map', 'edge_map', 'color_map']:
        # print('%s: %s' % (name, data[name].size()))
        new_1 = torch.cat([data[name]] * batch_size, dim=0).contiguous()
        new_2 = new_1.view(batch_size, batch_size, -1).transpose(0, 1).contiguous().view(new_1.size())
        new_3 = torch.cat([data[name + '_def']] * batch_size, dim=0).contiguous()
        new_3 = new_3.view(batch_size, batch_size, -1).transpose(0, 1).contiguous().view(new_1.size())
        data[name] = new_2
        data[name + '_def'] = new_3  # assumed: the original computed new_3 but never stored it
Example #3
def main():
	opt = TrainOptions().parse()  # assumed: the original snippet uses opt without defining it
	opt.input_image_root = './data/%s_%s_%s' % (opt.dataset_name, opt.model_name, opt.which_epoch)
	opt.real_image_root = '../data/%s/train/face_img' % opt.dataset_name
	opt.label_root = '../data/%s/train/face_label' % opt.dataset_name


	iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
	data_loader = CreateDataLoader(opt)
	dataset = data_loader.load_data()
	dataset_size = len(data_loader)
	print('#training images = %d' % dataset_size)

	start_epoch, epoch_iter = 1, 0
	total_steps = (start_epoch - 1) * dataset_size + epoch_iter
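	# align the periodic display/print/save triggers with the resume offset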
	display_delta = total_steps % opt.display_freq
	print_delta = total_steps % opt.print_freq
	save_delta = total_steps % opt.save_latest_freq

	model = create_model(opt)
	# model = model.cuda()
	visualizer = Visualizer(opt)

	for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
		epoch_start_time = time.time()
		if epoch != start_epoch:
			epoch_iter = epoch_iter % dataset_size
		for i, data in enumerate(dataset, start=epoch_iter):
			iter_start_time = time.time()
			total_steps += opt.batchSize
			epoch_iter += opt.batchSize

			# whether to collect output images
			save_fake = total_steps % opt.display_freq == display_delta

			############## Forward Pass ######################
			losses, generated = model(Variable(data['label']), Variable(data['inst']),
									  Variable(data['image']), Variable(data['feat']), infer=save_fake)

			# sum per device losses
			losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
			loss_dict = dict(zip(model.loss_names, losses))

			# calculate final loss scalar
			loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
			loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

			############### Backward Pass ####################
			# update generator weights
			model.optimizer_G.zero_grad()
			loss_G.backward()
			model.optimizer_G.step()

			# update discriminator weights
			model.optimizer_D.zero_grad()
			loss_D.backward()
			model.optimizer_D.step()


			############## Display results and errors ##########
			### print out errors
			if total_steps % opt.print_freq == print_delta:
				errors = {k: v.data if not isinstance(v, int) else v for k, v in loss_dict.items()}
				t = (time.time() - iter_start_time) / opt.batchSize
				visualizer.print_current_errors(epoch, epoch_iter, errors, t)
				visualizer.plot_current_errors(errors, total_steps)

			### display output images
			if save_fake:
				input_image = data['label'].detach()[0][10:,:,:]
				input_label = data['label'].detach()[0][:10,:,:]
				generated_image = generated.detach()[0]
				visuals = OrderedDict([('input_label', util.tensor2label(input_label, opt.label_nc)),
									   ('input_image', util.tensor2im(input_image)),
									   ('synthesized_image', util.tensor2im(generated_image)),
									   ('real_image', util.tensor2im(data['image'][0]))])
				visualizer.display_current_results(visuals, epoch, total_steps)

			### save latest model
			if total_steps % opt.save_latest_freq == save_delta:
				print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
				model.save('latest')
				np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

			if epoch_iter >= dataset_size:
				break

		# end of epoch
		print('End of epoch %d / %d \t Time Taken: %d sec' %
			  (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

		### save model for this epoch
		if epoch % opt.save_epoch_freq == 0:
			print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
			model.save('latest')
			model.save(epoch)
			np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

		### instead of only training the local enhancer, train the entire network after certain iterations
		if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
			model.update_fixed_params()

		### linearly decay learning rate after certain iterations
		if epoch > opt.niter:
			model.update_learning_rate()

	torch.cuda.empty_cache()
Example #4
def create_descriptor_vunet():
    '''
    use variational unet feature as descriptor
    '''
    import torch
    from models.vunet_pose_transfer_model import VUnetPoseTransferModel
    from data.data_loader import CreateDataLoader
    from options.pose_transfer_options import TestPoseTransferOptions
    # load image info
    image_info = io.load_json('temp/patch_matching/label/image_info.json')
    # load model
    # opt = TestPoseTransferOptions().parse(ord_str = '--which_model_T vunet --id 7.5 --gpu_ids 0,1,2,3 --batch_size 16', save_to_file = False, display = False, set_gpu = True)
    opt = TestPoseTransferOptions().parse(
        ord_str=
        '--which_model_T vunet --id 7.9 --gpu_ids 0,1,2,3 --batch_size 16',
        save_to_file=False,
        display=False,
        set_gpu=True)
    train_opt = io.load_json(
        os.path.join('checkpoints', opt.id, 'train_opt.json'))
    for k, v in train_opt.items():
        if k in opt and (k not in {'gpu_ids', 'is_train'}):
            setattr(opt, k, v)
    model = VUnetPoseTransferModel()
    model.initialize(opt)
    # create data set
    val_loader = CreateDataLoader(opt, split='test')
    id_list = [[sid_1, sid_2]
               for sid_1, sid_2 in zip(image_info['id_1'], image_info['id_2'])]
    val_loader.dataset.id_list = id_list

    # extract descriptor
    print('extracting descriptors ...')
    desc = {
        'img_1': [],
        'img_2': [],
        'feat1_1': [],
        'feat1_2': [],
        'feat2_1': [],
        'feat2_2': [],
    }
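    # two forward passes per pair: reconstruct the reference pose, then
    # transfer appearance to the target pose, keeping images and two feature maps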
    for i, data in enumerate(tqdm.tqdm(val_loader)):
        model.set_input(data)
        appr_1 = model.get_appearance(model.opt.appearance_type, index='1')
        appr_2 = model.get_appearance(model.opt.appearance_type, index='2')
        pose_1 = model.get_pose(model.opt.pose_type, index='1')
        pose_2 = model.get_pose(model.opt.pose_type, index='2')
        # desc_1 (ref)
        output, ps, qs, ds = model.netT(appr_1,
                                        pose_1,
                                        pose_1,
                                        mode='transfer',
                                        output_feature=True)
        output = model.parse_output(output, model.opt.output_type)
        img = torch.tanh(output['image'])
        desc['img_1'].append(img.detach().cpu())
        desc['feat1_1'].append(ds[-1].detach().cpu())
        desc['feat2_1'].append(ds[-2].detach().cpu())
        # desc_2 (tar)
        output, ps, qs, ds = model.netT(appr_1,
                                        pose_1,
                                        pose_2,
                                        mode='transfer',
                                        output_feature=True)
        output = model.parse_output(output, model.opt.output_type)
        img = torch.tanh(output['image'])
        desc['img_2'].append(img.detach().cpu())
        desc['feat1_2'].append(ds[-1].detach().cpu())
        desc['feat2_2'].append(ds[-2].detach().cpu())

    for k, v in desc.items():
        desc[k] = torch.cat(v).numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC

    # save descriptor
    data_dict_img = {
        'desc_1': desc['img_1'],
        'desc_2': desc['img_2'],
        'name': 'gen_vunet_img'
    }
    data_dict_feat1 = {
        'desc_1': desc['feat1_1'],
        'desc_2': desc['feat1_2'],
        'name': 'gen_vunet_feat1'
    }
    data_dict_feat2 = {
        'desc_1': desc['feat2_1'],
        'desc_2': desc['feat2_2'],
        'name': 'gen_vunet_feat2'
    }

    scipy.io.matlab.savemat(
        'temp/patch_matching/descriptor/desc_gen_vunet-7.9_img.mat',
        data_dict_img)
    scipy.io.matlab.savemat(
        'temp/patch_matching/descriptor/desc_gen_vunet-7.9_feat1.mat',
        data_dict_feat1)
    scipy.io.matlab.savemat(
        'temp/patch_matching/descriptor/desc_gen_vunet-7.9_feat2.mat',
        data_dict_feat2)

    # visualize desc_img
    vis_dir = 'temp/patch_matching/descriptor/vis_gen_vunet_img/'
    io.mkdir_if_missing(vis_dir)
    imgs_1 = (desc['img_1'] * 127.5 + 127.5).astype(np.uint8)
    imgs_2 = (desc['img_2'] * 127.5 + 127.5).astype(np.uint8)
    for i in range(imgs_1.shape[0]):
        imageio.imwrite(vis_dir + '%d_1.jpg' % i, imgs_1[i])
        imageio.imwrite(vis_dir + '%d_2.jpg' % i, imgs_2[i])
Example #5
import sys

from options.train_options import TrainOptions
opt = TrainOptions().parse()  # set CUDA_VISIBLE_DEVICES before import torch
from data.data_loader import CreateDataLoader
from models.models import create_model

dataset_root = "/phoenix/S6/zl548/"
test_list_dir_l = '/phoenix/S6/zl548/MegaDpeth_code/test_list/landscape/'
input_height = 240
input_width = 320
is_flipped = False
shuffle = False

test_data_loader_l = CreateDataLoader(dataset_root, test_list_dir_l,
                                      input_height, input_width, is_flipped,
                                      shuffle)
test_dataset_l = test_data_loader_l.load_data()
test_dataset_size_l = len(test_data_loader_l)
print('========================= test images = %d' % test_dataset_size_l)
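# portrait split: same loader with input height/width swapped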
test_list_dir_p = '/phoenix/S6/zl548/MegaDpeth_code/test_list/portrait/'
input_height = 320
input_width = 240
test_data_loader_p = CreateDataLoader(dataset_root, test_list_dir_p,
                                      input_height, input_width, is_flipped,
                                      shuffle)
test_dataset_p = test_data_loader_p.load_data()
test_dataset_size_p = len(test_data_loader_p)
print('========================= test images = %d' % test_dataset_size_p)

model = create_model(opt)
Example #6
import time
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
import copy
import numpy as np
import os

opt = TrainOptions().parse()
data_loader = CreateDataLoader(opt)

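# clone the options to build a second, paired (aligned) data loader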
opt_copy = copy.deepcopy(opt)
opt_copy.dataset_mode = 'aligned'
opt_copy.dataroot = opt.dataroot_aligned
opt_copy.resize_or_crop = 'resize_and_crop'
paired_data_loader = CreateDataLoader(opt_copy)
aligned_dataset = paired_data_loader.load_data()

dataset = data_loader.load_data()
dataset_size = len(data_loader) + len(paired_data_loader)
print('#training images = %d' % dataset_size)

#model = create_model(opt)
visualizer = Visualizer(opt)
total_steps = 0

for epoch in range(1):  #opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    epoch_iter = 0
Example #7
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from pdb import set_trace as st
from util import html
from data.process_videos import avi2pngs

opt = TrainOptions().parse()
opt.nThreads = 1   # test code only supports nThreads = 1


opt.continue_train = True
opt.dataroot = opt.trainroot
opt.file_list = opt.train_file_list
avi2pngs(opt)
opt.dataroot = opt.dataroot + '/split'
training_data_loader = CreateDataLoader(opt)
train_dataset = training_data_loader.load_data()
train_dataset_size = len(training_data_loader)
print('#training samples = %d' % train_dataset_size)


opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip
opt.dataroot = opt.testroot
opt.file_list = opt.test_file_list
avi2pngs(opt)
opt.dataroot = opt.dataroot + '/split'
testing_data_loader = CreateDataLoader(opt)
test_dataset = testing_data_loader.load_data()
test_dataset_size = len(testing_data_loader)
Example #8
def main():
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    # read pix2pix/PAN model
    if opt.model == 'pix2pix':
        assert (opt.dataset_mode == 'aligned')
        from models.pix2pix_model import Pix2PixModel
        model = Pix2PixModel()
        model.initialize(opt)
    elif opt.model == 'pan':
        from models.pan_model import PanModel
        model = PanModel()
        model.initialize(opt)

    total_steps = 0

    batch_size = opt.batchSize
    print_freq = opt.print_freq
    epoch_count = opt.epoch_count
    niter = opt.niter
    niter_decay = opt.niter_decay
    display_freq = opt.display_freq
    save_latest_freq = opt.save_latest_freq
    save_epoch_freq = opt.save_epoch_freq

    for epoch in range(epoch_count, niter + niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0

        for i, data in enumerate(dataset):
            # data --> (1, 3, 256, 256)
            iter_start_time = time.time()
            total_steps += batch_size
            epoch_iter += batch_size
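            # one training step: set inputs, then forward/backward and optimizer updates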
            model.set_input(data)
            model.optimize_parameters()

            if total_steps % print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / batch_size

                message = '(epoch: %d, iters: %d, time: %.3f) ' % (
                    epoch, epoch_iter, t)
                for k, v in errors.items():
                    message += '%s: %.3f ' % (k, v)
                print(message)

            # save latest weights
            if total_steps % save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

        # save weights periodically
        if epoch % save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, niter + niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()
Example #9
import os
import sys
import time
from collections import OrderedDict

import tqdm
import tensorboardX

# parse and save options
parser = TrainPoseTransferOptions()
opt = parser.parse()
parser.save()
# create model
model = PoseTransferModel()
model.initialize(opt)
# save terminal line
io.save_str_list([' '.join(sys.argv)],
                 os.path.join(model.save_dir, 'order_line.txt'))
# create data loader
train_loader = CreateDataLoader(opt, split='train')
val_loader = CreateDataLoader(
    opt, split='test' if not opt.small_val_set else 'test_small')
# create visualizer
visualizer = Visualizer(opt)

logdir = os.path.join('logs', opt.id)
if not os.path.exists(logdir):
    os.makedirs(logdir)
writer = tensorboardX.SummaryWriter(logdir)

# set "saving best"
best_info = {'meas': 'SSIM', 'type': 'max', 'best_value': 0, 'best_epoch': -1}

# set continue training
if not opt.resume_train:
Example #10
import sys
import time

import numpy as np
import torch.backends.cudnn as cudnn
from tqdm import tqdm

from data.data_loader import CreateDataLoader
from models.models import create_model
from options.train_options import TrainOptions

cudnn.benchmark = True

opt = TrainOptions().parse()
ROOT = '/data1'
testDataLoader = CreateDataLoader(opt, {'mode':'Test', 'labelFn':'in5008.txt', 'rootPath':ROOT, 'subDir':'eval'})
#testDataLoader = CreateDataLoader(opt, {'mode':'Train', 'labelFn':'in5008.txt', 'rootPath':ROOT, 'subDir':'train'})

testDataset = testDataLoader.load_data()
print('#testing images = %d' % len(testDataLoader))

model = create_model(opt)

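# accumulate per-sample word and character error rates over the test set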
totalWer = list()
totalCer = list()
for i, data in tqdm(enumerate(testDataset), total=len(testDataset)):
    model.set_input(data)
    results = model.test()
    totalWer.append(results['wer'])
    totalCer.append(results['cer'])
Example #11
def train_pose2vid(target_dir, run_name, temporal_smoothing=False):
    import src.config.train_opt as opt

    opt = update_opt(opt, target_dir, run_name, temporal_smoothing)

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.json')
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

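    # resume bookkeeping: iter.json stores the epoch and iteration to restart from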
    if opt.load_pretrain != '':
        with open(iter_path, 'r') as f:
            iter_json = json.load(f)
    else:
        iter_json = {'start_epoch': 1, 'epoch_iter': 0}

    start_epoch = iter_json['start_epoch']
    epoch_iter = iter_json['epoch_iter']
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    model = create_model(opt)
    model = model.to(device)
    visualizer = Visualizer(opt)

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            if temporal_smoothing:
                losses, generated = model(Variable(data['label']),
                                          Variable(data['inst']),
                                          Variable(data['image']),
                                          Variable(data['feat']),
                                          Variable(data['previous_label']),
                                          Variable(data['previous_image']),
                                          infer=save_fake)
            else:
                losses, generated = model(Variable(data['label']),
                                          Variable(data['inst']),
                                          Variable(data['image']),
                                          Variable(data['feat']),
                                          infer=save_fake)

            # sum per device losses
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get(
                'G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()

            ############## Display results and errors ##########

            print(f"Epoch {epoch} batch {i}:")
            print(f"loss_D: {loss_D}, loss_G: {loss_G}")
            print(
                f"loss_D_fake: {loss_dict['D_fake']}, loss_D_real: {loss_dict['D_real']}"
            )
            print(
                f"loss_G_GAN {loss_dict['G_GAN']}, loss_G_GAN_Feat: {loss_dict.get('G_GAN_Feat', 0)}, loss_G_VGG: {loss_dict.get('G_VGG', 0)}\n"
            )

            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                # errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([
                    ('input_label',
                     util.tensor2label(data['label'][0], opt.label_nc)),
                    ('synthesized_image', util.tensor2im(generated.data[0])),
                    ('real_image', util.tensor2im(data['image'][0]))
                ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')
                iter_json['start_epoch'] = epoch
                iter_json['epoch_iter'] = epoch_iter
                with open(iter_path, 'w') as f:
                    json.dump(iter_json, f)

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            iter_json['start_epoch'] = epoch + 1
            iter_json['epoch_iter'] = 0
            with open(iter_path, 'w') as f:
                json.dump(iter_json, f)

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

    torch.cuda.empty_cache()
Example #12
def train():
    opt = TrainOptions().parse()

    if opt.distributed:
        init_dist()
        print('batch size per GPU: %d' % opt.batchSize)
    torch.backends.cudnn.benchmark = True

    ### setup dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    pose = 'pose' in opt.dataset_mode

    ### setup trainer
    trainer = Trainer(opt, data_loader)

    ### setup models
    model, flowNet, [optimizer_G, optimizer_D] = create_model(opt, trainer.start_epoch)
    flow_gt = conf_gt = [None] * 2

    for epoch in range(trainer.start_epoch, opt.niter + opt.niter_decay + 1):
        if opt.distributed:
            dataset.sampler.set_epoch(epoch)
        trainer.start_of_epoch(epoch, model, data_loader)
        n_frames_total, n_frames_load = data_loader.dataset.n_frames_total, opt.n_frames_per_gpu
        for idx, data in enumerate(dataset, start=trainer.epoch_iter):
            data = trainer.start_of_iter(data)

            if not opt.no_flow_gt:
                data_list = [
                    data['tgt_label'], data['ref_label']
                ] if pose else [data['tgt_image'], data['ref_image']]
                flow_gt, conf_gt = flowNet(data_list, epoch)
            data_list = [
                data['tgt_label'], data['tgt_image'], flow_gt, conf_gt
            ]
            data_ref_list = [data['ref_label'], data['ref_image']]
            data_prev = [None, None, None]

            ############## Forward Pass ######################
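            # process the sequence in chunks of n_frames_load frames,
            # carrying the previous chunk's outputs forward in data_prev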
            for t in range(0, n_frames_total, n_frames_load):
                data_list_t = get_data_t(data_list, n_frames_load,
                                         t) + data_ref_list + data_prev

                d_losses = model(data_list_t, mode='discriminator')
                d_losses = loss_backward(opt, d_losses, optimizer_D, 1)

                g_losses, generated, data_prev = model(
                    data_list_t, save_images=trainer.save, mode='generator')
                g_losses = loss_backward(opt, g_losses, optimizer_G, 0)

            loss_dict = dict(
                zip(model.module.lossCollector.loss_names,
                    g_losses + d_losses))

            if trainer.end_of_iter(loss_dict,
                                   generated + data_list + data_ref_list,
                                   model):
                break
        trainer.end_of_epoch(model)
Example #13
def main():
    with open('./train/train_opt.pkl', mode='rb') as f:
        opt = pickle.load(f)
        opt.checkpoints_dir = './checkpoints/'
        opt.dataroot = './train'
        opt.no_flip = True
        opt.label_nc = 0
        opt.batchSize = 2
        print(opt)

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    start_epoch, epoch_iter = 1, 0
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq
    best_loss = 999999
    epoch_loss = 9999999999
    model = create_model(opt)
    model = model.cuda()
    visualizer = Visualizer(opt)
    # niter = 20, niter_decay = 20
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                      Variable(data['image']), Variable(data['feat']), infer=save_fake)

            # sum per device losses
            losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)
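            # combined loss used later to decide whether to overwrite the 'latest' checkpoint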
            loss_DG = loss_D + loss_G

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()

            # call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ############## Display results and errors ##########

            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {k: v.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                       ('synthesized_image', util.tensor2im(generated.data[0])),
                                       ('real_image', util.tensor2im(data['image'][0]))])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta and loss_DG < best_loss:
                best_loss = loss_DG
                print('saving the latest model (epoch %d, total_steps %d, total loss %g)' % (epoch, total_steps, loss_DG.item()))
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        iter_end_time = time.time()
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:

            print('saving the model at the end of epoch %d, iters %d ' % (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

    torch.cuda.empty_cache()
Example #14
import time
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
import pdb
pdb.set_trace()

opt = TrainOptions().parse()

# data for training
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)

# data for test
from copy import deepcopy
test_opt = deepcopy(opt)
test_opt.phase = 'test'
test_opt.nThreads = 1  # test code only supports nThreads = 1
test_opt.batchSize = 1  # test code only supports batchSize = 1
test_opt.serial_batches = True  # no shuffle
test_opt.no_flip = True  # no flip
test_data_loader = CreateDataLoader(test_opt)
test_dataset = test_data_loader.load_data()
previous_score = 0.0

#########
model = create_model(opt)
visualizer = Visualizer(opt)
Example #15
from tensorboardX import SummaryWriter
from util.preprocess import preprocess_for_train
from util.extract_key_bases import extract_key_bases

if __name__ == "__main__":
    opt = TrainOptions().parse()

    # pre-process the data
    print('Pre-processing datasets for train ...')
    start = time.time()
    preprocess_for_train(opt.dataset_root)
    end = time.time()
    print('Pre-processing for train completed!')
    print('Pre-process time : %d min' % int((end - start) / 60))

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    # create validation set data loader if validation_on option is set
    if opt.validation_on:
        # temporarily set to val to load val data
        opt.mode = 'val'
        data_loader_val = CreateDataLoader(opt)
        dataset_val = data_loader_val.load_data()
        dataset_size_val = len(data_loader_val)
        print('#validation images = %d' % dataset_size_val)
        opt.mode = 'train'  #set it back

    writer = SummaryWriter(comment=opt.name)
Example #16
import time
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
import cv2
import copy
import pdb

opt = TrainOptions().parse()
opt2 = copy.copy(opt)
# change some opts
opt2.dataroot = opt.dataroot_more
#opt.which_direction = 'BtoA' # ensure the common domain A is photo
#opt2.which_direction = 'BtoA'
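# two parallel loaders over different data roots; epoch length is the smaller of the two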
data_loader = CreateDataLoader(opt)
data_loader2 = CreateDataLoader(opt2)

dataset = data_loader.load_data()
dataset2 = data_loader2.load_data()
dataset_size = min(len(data_loader), len(data_loader2))
#dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)
model = create_model(opt)
visualizer = Visualizer(opt)
total_steps = 0

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    epoch_iter = 0
Example #17
def train():
    opt = TrainOptions().parse()
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training videos = %d' % dataset_size)

    ### initialize models
    models = create_model(opt)
    modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T = create_optimizer(
        opt, models)

    ### set parameters
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
        start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(opt, modelG, modelD, dataset_size)
    visualizer = Visualizer(opt)

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0
            n_frames_total, n_frames_load, t_len = data_loader.dataset.init_data_params(
                data, n_gpus, tG)
            fake_B_last, frames_all = data_loader.dataset.init_data(t_scales)

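            # train over the video in sliding chunks of n_frames_load frames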
            for i in range(0, n_frames_total, n_frames_load):
                input_A, input_B, inst_A = data_loader.dataset.prepare_data(
                    data, i, input_nc, output_nc)

                ###################################### Forward Pass ##########################
                ####### generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(
                    input_A, input_B, inst_A, fake_B_last)

                ####### discriminator
                ### individual frame discriminator
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]  # previous and current real frames
                flow_ref, conf_ref = flowNet(real_B, real_B_prev)  # reference flows and confidences
                #flow_ref, conf_ref = util.remove_dummy_from_tensor([flow_ref, conf_ref])
                fake_B_prev = modelG.module.compute_fake_B_prev(
                    real_B_prev, fake_B_last, fake_B)

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                ### temporal discriminator
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = modelD.module.get_all_skipped_frames(frames_all, \
                        real_B, fake_B, flow_ref, conf_ref, t_scales, tD, n_frames_load, i, flowNet)

                # run discriminator for each temporal scale
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = modelD(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_G, loss_D, loss_D_T, t_scales_act = modelD.module.get_losses(
                    loss_dict, loss_dict_T, t_scales)

                ###################################### Backward Pass #################################
                # update generator weights
                loss_backward(opt, loss_G, optimizer_G)

                # update individual discriminator weights
                loss_backward(opt, loss_D, optimizer_D)

                # update temporal discriminator weights
                for s in range(t_scales_act):
                    loss_backward(opt, loss_D_T[s], optimizer_D_T[s])

                if i == 0:
                    fake_B_first = fake_B[0, 0]  # the first generated image in this sequence

            if opt.debug:
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, modelD)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            save_models(opt, epoch, epoch_iter, total_steps, visualizer,
                        iter_path, modelG, modelD)
            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch and update model params
        save_models(opt,
                    epoch,
                    epoch_iter,
                    total_steps,
                    visualizer,
                    iter_path,
                    modelG,
                    modelD,
                    end_of_epoch=True)
        update_models(opt, epoch, modelG, modelD, data_loader)
Example #18
def train():
    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except:
            start_epoch, epoch_iter = 1, 0
        # compute resume lr
        if start_epoch > opt.niter:
            lrd_unit = opt.lr / opt.niter_decay
            resume_lr = opt.lr - (start_epoch - opt.niter) * lrd_unit
            opt.lr = resume_lr
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    if opt.debug:
        opt.display_freq = 2
        opt.print_freq = 2
        opt.niter = 3
        opt.niter_decay = 0
        opt.max_dataset_size = 1
        opt.valSize = 1

    ## Loading data
    # train data
    data_loader = CreateDataLoader(opt, isVal=False)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('# training images = %d' % dataset_size)
    # validation data
    data_loader = CreateDataLoader(opt, isVal=True)
    valset = data_loader.load_data()
    print('# validation images = %d' % len(data_loader))

    ## Loading model
    model = create_model(opt)
    visualizer = Visualizer(opt)
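    # optional mixed-precision training via NVIDIA apex (opt_level O1)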
    if opt.fp16:
        from apex import amp
        model, [optimizer_G, optimizer_D] = amp.initialize(
            model, [model.optimizer_G, model.optimizer_D], opt_level='O1')
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    else:
        optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter

    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == print_delta:
                iter_start_time = time.time()

            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            model = model.train()
            losses, generated, metrics = model(data['A'],
                                               data['B'],
                                               data['geometry'],
                                               infer=False)

            # sum per device losses and metrics
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + opt.gan_feat_weight * loss_dict.get(
                'G_GAN_Feat', 0) + opt.vgg_weight * loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            optimizer_G.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_G, optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_G.backward()
            optimizer_G.step()

            # update discriminator weights
            optimizer_D.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_D, optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_D.backward()
            optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                metrics_ = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in metric_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)
                visualizer.print_current_metrics(epoch, epoch_iter, metrics_,
                                                 t)
                visualizer.plot_current_metrics(metrics_, total_steps)
                #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ### display output images
            if save_fake:
                if opt.task_type == 'specular':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=1)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=1))
                    ])
                elif opt.task_type == 'low':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=2)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=2))
                    ])
                elif opt.task_type == 'high':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=3)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=3))
                    ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ###########################################################################################
        # validation at the end of each epoch
        val_start_time = time.time()
        metrics_val = []
        for _, val_data in enumerate(valset):
            model = model.eval()
            # model.half()
            generated, metrics = model(val_data['A'],
                                       val_data['B'],
                                       val_data['geometry'],
                                       infer=True)
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            metrics_ = {
                k: v.data.item() if not isinstance(v, int) else v
                for k, v in metric_dict.items()
            }
            metrics_val.append(metrics_)
        # Print out losses
        metrics_val = visualizer.mean4dict(metrics_val)
        t = (time.time() - val_start_time) / opt.print_freq
        visualizer.print_current_metrics(epoch,
                                         epoch_iter,
                                         metrics_val,
                                         t,
                                         isVal=True)
        visualizer.plot_current_metrics(metrics_val, total_steps, isVal=True)
        # visualization
        if opt.task_type == 'specular':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=1)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=1))
            ])
        if opt.task_type == 'low':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=2)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=2))
            ])
        if opt.task_type == 'high':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=3)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=3))
            ])
        visualizer.display_current_results(visuals, epoch, epoch, isVal=True)
        ###########################################################################################

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()
Example #19
def train(opt):
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    if opt.continue_train:
        if opt.which_epoch == 'latest':
            try:
                start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                     delimiter=',',
                                                     dtype=int)
            except:
                start_epoch, epoch_iter = 1, 0
        else:
            start_epoch, epoch_iter = int(opt.which_epoch), 0

        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
        for update_point in opt.decay_epochs:
            if start_epoch < update_point:
                break

            opt.lr *= opt.decay_gamma
    else:
        start_epoch, epoch_iter = 0, 0

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)

    total_steps = start_epoch * dataset_size + epoch_iter

    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq
    bSize = opt.batchSize

    #in case there's no display sample one image from each class to test after every epoch
    if opt.display_id == 0:
        dataset.dataset.set_sample_mode(True)
        dataset.num_workers = 1
        for i, data in enumerate(dataset):
            if i * opt.batchSize >= opt.numClasses:
                break
            if i == 0:
                sample_data = data
            else:
                for key, value in data.items():
                    if torch.is_tensor(data[key]):
                        sample_data[key] = torch.cat(
                            (sample_data[key], data[key]), 0)
                    else:
                        sample_data[key] = sample_data[key] + data[key]
        dataset.num_workers = opt.nThreads
        dataset.dataset.set_sample_mode(False)

    for epoch in range(start_epoch, opt.epochs):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = 0
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = (total_steps % opt.display_freq
                         == display_delta) and (opt.display_id > 0)

            ############## Network Pass ########################
            model.set_inputs(data)
            disc_losses = model.update_D()
            gen_losses, gen_in, gen_out, rec_out, cyc_out = model.update_G(
                infer=save_fake)
            loss_dict = dict(gen_losses, **disc_losses)
            ##################################################

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.item()
                    if not (isinstance(v, float) or isinstance(v, int)) else v
                    for k, v in loss_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch + 1, epoch_iter, errors,
                                                t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch,
                        float(epoch_iter) / dataset_size, opt, errors)

            ### display output images
            if save_fake and opt.display_id > 0:
                class_a_suffix = ' class {}'.format(data['A_class'][0])
                class_b_suffix = ' class {}'.format(data['B_class'][0])
                classes = None

                visuals = OrderedDict()
                visuals_A = OrderedDict([('real image' + class_a_suffix,
                                          util.tensor2im(gen_in.data[0]))])
                visuals_B = OrderedDict([('real image' + class_b_suffix,
                                          util.tensor2im(gen_in.data[bSize]))])

                A_out_vis = OrderedDict([('synthesized image' + class_b_suffix,
                                          util.tensor2im(gen_out.data[0]))])
                B_out_vis = OrderedDict([('synthesized image' + class_a_suffix,
                                          util.tensor2im(gen_out.data[bSize]))])
                if opt.lambda_rec > 0:
                    A_out_vis.update([('reconstructed image' + class_a_suffix,
                                       util.tensor2im(rec_out.data[0]))])
                    B_out_vis.update([('reconstructed image' + class_b_suffix,
                                       util.tensor2im(rec_out.data[bSize]))])
                if opt.lambda_cyc > 0:
                    A_out_vis.update([('cycled image' + class_a_suffix,
                                       util.tensor2im(cyc_out.data[0]))])
                    B_out_vis.update([('cycled image' + class_b_suffix,
                                       util.tensor2im(cyc_out.data[bSize]))])

                visuals_A.update(A_out_vis)
                visuals_B.update(B_out_vis)
                visuals.update(visuals_A)
                visuals.update(visuals_B)

                ncols = len(visuals_A)
                visualizer.display_current_results(visuals, epoch, classes,
                                                   ncols)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch + 1, total_steps))
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')
                if opt.display_id == 0:
                    model.eval()
                    visuals = model.inference(sample_data)
                    visualizer.save_matrix_image(visuals, 'latest')
                    model.train()

        # end of epoch
        iter_end_time = time.time()
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch + 1, opt.epochs, time.time() - epoch_start_time))

        ### save model for this epoch
        if (epoch + 1) % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch + 1, total_steps))
            model.save('latest')
            model.save(epoch + 1)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')
            if opt.display_id == 0:
                model.eval()
                visuals = model.inference(sample_data)
                visualizer.save_matrix_image(visuals, epoch + 1)
                model.train()

        ### multiply learning rate by opt.decay_gamma after certain iterations
        if (epoch + 1) in opt.decay_epochs:
            model.update_learning_rate()
Example #20
def train():
    opt = TrainOptions().parse()
    if opt.distributed:
        init_dist()
        opt.batchSize = opt.batchSize // len(opt.gpu_ids)  # per-process batch size

    ### setup dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    ### setup trainer
    trainer = Trainer(opt, data_loader)

    ### setup models
    model, flowNet = create_model(opt, trainer.start_epoch)
    flow_gt = conf_gt = [None] * 3
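    # placeholders overwritten by flowNet when flow supervision is enabled;
    # both names are bound to the same list here, which is safe because they
    # are only ever rebound, never mutated in place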

    ref_idx_fix = torch.zeros([opt.batchSize])
    for epoch in tqdm(
            range(trainer.start_epoch, opt.niter + opt.niter_decay + 1)):
        trainer.start_of_epoch(epoch, model, data_loader)
        n_frames_total, n_frames_load = data_loader.dataset.n_frames_total, opt.n_frames_per_gpu
        for idx, data in enumerate(tqdm(dataset), start=trainer.epoch_iter):
            trainer.start_of_iter()

            if not opt.warp_ani:
                data.update({
                    'ani_image': None,
                    'ani_lmark': None,
                    'cropped_images': None,
                    'cropped_lmarks': None
                })

            if not opt.no_flow_gt:
                data_list = [
                    data['tgt_mask_images'], data['cropped_images'],
                    data['warping_ref'], data['ani_image']
                ]
                flow_gt, conf_gt = flowNet(data_list, epoch)
            data_list = [
                data['tgt_label'], data['tgt_image'], data['tgt_template'],
                data['cropped_images'], flow_gt, conf_gt
            ]
            data_ref_list = [data['ref_label'], data['ref_image']]
            data_prev = [None, None, None]
            data_ani = [
                data['warping_ref_lmark'], data['warping_ref'],
                data['ori_warping_refs'], data['ani_lmark'], data['ani_image']
            ]

            ############## Forward Pass ######################
            prevs = {"raw_images":[], "synthesized_images":[], \
                    "prev_warp_images":[], "prev_weights":[], \
                    "ani_warp_images":[], "ani_weights":[], \
                    "ref_warp_images":[], "ref_weights":[], \
                    "ref_flows":[], "prev_flows":[], "ani_flows":[], \
                    "ani_syn":[]}
            for t in range(0, n_frames_total, n_frames_load):

                data_list_t = get_data_t(data_list, n_frames_load, t) + data_ref_list + \
                              get_data_t(data_ani, n_frames_load, t) + data_prev

                g_losses, generated, data_prev, ref_idx = model(
                    data_list_t,
                    save_images=trainer.save,
                    mode='generator',
                    ref_idx_fix=ref_idx_fix)
                g_losses = loss_backward(opt, g_losses,
                                         model.module.optimizer_G)

                d_losses, _ = model(data_list_t,
                                    mode='discriminator',
                                    ref_idx_fix=ref_idx_fix)
                d_losses = loss_backward(opt, d_losses,
                                         model.module.optimizer_D)

                # store previous
                store_prev(generated, prevs)

            loss_dict = dict(
                zip(model.module.lossCollector.loss_names,
                    g_losses + d_losses))

            output_data_list = ([prevs, data['ref_image']] + data_ani +
                                data_list + [data['tgt_mask_images']])

            if trainer.end_of_iter(loss_dict, output_data_list, model):
                break

        trainer.end_of_epoch(model)
Example #21
opt_train = TrainOptions().parse()
opt_val = TrainOptions().parse()

# Random seed
opt_train.use_gpu = len(opt_train.gpu_ids) and torch.cuda.is_available()
if opt_train.manualSeed is None:
    opt_train.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt_train.manualSeed)
random.seed(opt_train.manualSeed)
np.random.seed(opt_train.manualSeed)
torch.manual_seed(opt_train.manualSeed)
if opt_train.use_gpu:
    torch.cuda.manual_seed_all(opt_train.manualSeed)

data_loader = CreateDataLoader(opt_train)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)

opt_val.phase = 'val'
opt_val.nThreads = 1  # test code only supports nThreads = 1
opt_val.batchSize = 1  # test code only supports batchSize = 1
opt_val.serial_batches = True  # no shuffle
opt_val.no_flip = True  # no flip
opt_val.no_rotate = True  # no rotate
if opt_val.valSize == 0:
    opt_val.valSize = opt_val.loadSize
opt_val.loadSize = opt_val.valSize
opt_val.fineSize = opt_val.valSize
data_loader_val = CreateDataLoader(opt_val)
Example #22
parser.add_argument('--manualSeed', type=int, default=None,
                    help='manual seed')
parser.add_argument(
    '--no_lsgan',
    action='store_true',
    help='do *not* use least square GAN, if false, use vanilla GAN')
parser.add_argument(
    '--pool_size',
    type=int,
    default=62,
    help='the size of image buffer that stores previously generated images')
opt = parser.parse_args()
print(opt)

face_dataset = CreateDataLoader(
    opt,
    csv_fileA='./dataset/celeba/fine_grained_attribute_testA.txt',
    root_dirA='./dataset/celeba/testA/',
    csv_fileB='./dataset/celeba/fine_grained_attribute_testB.txt',
    root_dirB='./dataset/celeba/testB/')
# face_dataset = CreateDataLoader(opt, csv_fileA='./dataset/lfw/fine_grained_attribute_testA.txt',root_dirA='./dataset/lfw/testA/',
# csv_fileB='./dataset/lfw/fine_grained_attribute_testB.txt',root_dirB='./dataset/lfw/testB/')
dataset_size = len(face_dataset)
print('#test images = %d' % dataset_size)

if opt.manualSeed is None:
    opt.manualSeed = np.random.randint(1, 10000)

np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
torch.cuda.manual_seed_all(opt.manualSeed)

model = create_model(opt)
Example #23
def train():
    opt = TrainOptions().parse()
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    # initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    if opt.dataset_mode == 'pose':
        print('#training frames = %d' % dataset_size)
    else:
        print('#training videos = %d' % dataset_size)

    # initialize models
    modelG, modelD, flowNet = create_model(opt)
    visualizer = Visualizer(opt)

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    # if continue training, recover previous states
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except Exception:
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
        if start_epoch > opt.niter:
            modelG.module.update_learning_rate(start_epoch - 1)
            modelD.module.update_learning_rate(start_epoch - 1)
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                start_epoch > opt.niter_fix_global):
            modelG.module.update_fixed_params()
        if start_epoch > opt.niter_step:
            data_loader.dataset.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
            modelG.module.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
    else:
        start_epoch, epoch_iter = 1, 0

    # set parameters
    # number of gpus used for generator for each batch
    n_gpus = opt.n_gpus_gen // opt.batchSize
    tG, tD = opt.n_frames_G, opt.n_frames_D
    tDB = tD * opt.output_nc
    s_scales = opt.n_scales_spatial
    t_scales = opt.n_scales_temporal
    input_nc = 1 if opt.label_nc != 0 else opt.input_nc
    output_nc = opt.output_nc

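    # make print_freq a multiple of batchSize (via lcm) so the
    # total_steps % opt.print_freq == 0 checks below can be hit exactly,
    # since total_steps only grows in batchSize increments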
    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    total_steps = total_steps // opt.print_freq * opt.print_freq

    # real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0

            # n_frames_total = n_frames_load * n_loadings + tG - 1
            _, n_frames_total, height, width = data['B'].size()
            n_frames_total = n_frames_total // opt.output_nc
            # number of total frames loaded into GPU at a time for each batch
            n_frames_load = opt.max_frames_per_gpu * n_gpus
            n_frames_load = min(n_frames_load, n_frames_total - tG + 1)
            # number of loaded frames plus previous frames
            t_len = n_frames_load + tG - 1

            # the last generated frame from previous training batch (which becomes input to the next batch)
            fake_B_last = None
            # all real/generated frames so far
            real_B_all, fake_B_all, flow_ref_all, conf_ref_all = None, None, None, None
            if opt.sparse_D:
                real_B_all, fake_B_all, flow_ref_all, conf_ref_all = (
                    [None] * t_scales, [None] * t_scales,
                    [None] * t_scales, [None] * t_scales)
            # temporally subsampled frames and flows
            real_B_skipped, fake_B_skipped = [None] * t_scales, [None] * t_scales
            flow_ref_skipped, conf_ref_skipped = [None] * t_scales, [None] * t_scales

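            # slide over the sequence in steps of n_frames_load; each window
            # spans t_len = n_frames_load + tG - 1 frames so the generator
            # keeps tG - 1 frames of context, with fake_B_last carrying
            # generated frames across windows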
            for i in range(0, n_frames_total - t_len + 1, n_frames_load):
                # 5D tensor: batchSize, # of frames, # of channels, height, width
                input_A = Variable(
                    data['A'][:, i * input_nc:(i + t_len) * input_nc, ...]
                ).view(-1, t_len, input_nc, height, width)
                input_B = Variable(
                    data['B'][:, i * output_nc:(i + t_len) * output_nc, ...]
                ).view(-1, t_len, output_nc, height, width)
                if len(data['inst'].size()) > 2:
                    inst_A = Variable(data['inst'][:, i:i + t_len, ...]).view(
                        -1, t_len, 1, height, width)
                else:
                    inst_A = None

                ###################################### Forward Pass ##########################
                # generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(
                    input_A, input_B, inst_A, fake_B_last)

                if i == 0:
                    # the first generated image in this sequence
                    fake_B_first = fake_B[0, 0]
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]

                # discriminator
                # individual frame discriminator
                # reference flows and confidences
                flow_ref, conf_ref = flowNet(real_B, real_B_prev)
                if fake_B_last is None:
                    fake_B_prev = real_B_prev[:, 0:1]
                else:
                    fake_B_prev = fake_B_last[0][:, -1:]
                if fake_B.size()[1] > 1:
                    fake_B_prev = torch.cat(
                        [fake_B_prev, fake_B[:, :-1].detach()], dim=1)

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                # temporal discriminator
                loss_dict_T = []
                # get skipped frames for each temporal scale
                if t_scales > 0:
                    if opt.sparse_D:
                        real_B_all, real_B_skipped = get_skipped_frames_sparse(
                            real_B_all, real_B, t_scales, tD, n_frames_load, i)
                        fake_B_all, fake_B_skipped = get_skipped_frames_sparse(
                            fake_B_all, fake_B, t_scales, tD, n_frames_load, i)
                        flow_ref_all, flow_ref_skipped = get_skipped_frames_sparse(
                            flow_ref_all,
                            flow_ref,
                            t_scales,
                            tD,
                            n_frames_load,
                            i,
                            is_flow=True)
                        conf_ref_all, conf_ref_skipped = get_skipped_frames_sparse(
                            conf_ref_all,
                            conf_ref,
                            t_scales,
                            tD,
                            n_frames_load,
                            i,
                            is_flow=True)
                    else:
                        real_B_all, real_B_skipped = get_skipped_frames(
                            real_B_all, real_B, t_scales, tD)
                        fake_B_all, fake_B_skipped = get_skipped_frames(
                            fake_B_all, fake_B, t_scales, tD)
                        flow_ref_all, conf_ref_all, flow_ref_skipped, conf_ref_skipped = get_skipped_flows(
                            flowNet, flow_ref_all, conf_ref_all,
                            real_B_skipped, flow_ref, conf_ref, t_scales, tD)

                # run discriminator for each temporal scale
                for s in range(t_scales):
                    if real_B_skipped[s] is not None:
                        losses = modelD(s + 1, [
                            real_B_skipped[s], fake_B_skipped[s],
                            flow_ref_skipped[s], conf_ref_skipped[s]
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
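                # the 0.5 factor follows the standard GAN convention of
                # averaging the real and fake discriminator terms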
                loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
                loss_G = loss_dict['G_GAN'] + \
                    loss_dict['G_GAN_Feat'] + loss_dict['G_VGG']
                loss_G += loss_dict['G_Warp'] + loss_dict['F_Flow'] + \
                    loss_dict['F_Warp'] + loss_dict['W']
                if opt.add_face_disc:
                    loss_G += loss_dict['G_f_GAN'] + loss_dict['G_f_GAN_Feat']
                    loss_D += (loss_dict['D_f_fake'] +
                               loss_dict['D_f_real']) * 0.5

                # collect temporal losses
                loss_D_T = []
                t_scales_act = min(t_scales, len(loss_dict_T))
                for s in range(t_scales_act):
                    loss_G += loss_dict_T[s]['G_T_GAN'] + \
                        loss_dict_T[s]['G_T_GAN_Feat'] + \
                        loss_dict_T[s]['G_T_Warp']
                    loss_D_T.append((loss_dict_T[s]['D_T_fake'] +
                                     loss_dict_T[s]['D_T_real']) * 0.5)

                ###################################### Backward Pass #################################
                optimizer_G = modelG.module.optimizer_G
                optimizer_D = modelD.module.optimizer_D
                # update generator weights
                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

                # update discriminator weights
                # individual frame discriminator
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
                # temporal discriminator
                for s in range(t_scales_act):
                    optimizer_D_T = getattr(modelD.module,
                                            'optimizer_D_T' + str(s))
                    optimizer_D_T.zero_grad()
                    loss_D_T[s].backward()
                    optimizer_D_T.step()

            if opt.debug:
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            # print out errors
            if total_steps % opt.print_freq == 0:
                t = (time.time() - iter_start_time) / opt.print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            # display output images
            if save_fake:
                if opt.label_nc != 0:
                    input_image = util.tensor2label(real_A[0, -1],
                                                    opt.label_nc)
                elif opt.dataset_mode == 'pose':
                    input_image = util.tensor2im(real_A[0, -1, :3])
                    if real_A.size()[2] == 6:
                        input_image2 = util.tensor2im(real_A[0, -1, 3:])
                        mask = input_image2 != 0
                        input_image[mask] = input_image2[mask]
                else:
                    c = 3 if opt.input_nc == 3 else 1
                    input_image = util.tensor2im(real_A[0, -1, :c],
                                                 normalize=False)
                if opt.use_instance:
                    edges = util.tensor2im(real_A[0, -1, -1:, ...],
                                           normalize=False)
                    input_image += edges[:, :, np.newaxis]

                if opt.add_face_disc:
                    ys, ye, xs, xe = modelD.module.get_face_region(
                        real_A[0, -1:])
                    if ys is not None:
                        # draw a white box around the detected face region
                        input_image[ys, xs:xe, :] = 255
                        input_image[ye, xs:xe, :] = 255
                        input_image[ys:ye, xs, :] = 255
                        input_image[ys:ye, xe, :] = 255

                visual_list = [
                    ('input_image', input_image),
                    ('fake_image', util.tensor2im(fake_B[0, -1])),
                    ('fake_first_image', util.tensor2im(fake_B_first)),
                    ('fake_raw_image', util.tensor2im(fake_B_raw[0, -1])),
                    ('real_image', util.tensor2im(real_B[0, -1])),
                    ('flow_ref', util.tensor2flow(flow_ref[0, -1])),
                    ('conf_ref',
                     util.tensor2im(conf_ref[0, -1], normalize=False))
                ]
                if flow is not None:
                    visual_list += [('flow', util.tensor2flow(flow[0, -1])),
                                    ('weight',
                                     util.tensor2im(weight[0, -1],
                                                    normalize=False))]
                visuals = OrderedDict(visual_list)
                visualizer.display_current_results(visuals, epoch, total_steps)

            # save latest model
            if total_steps % opt.save_latest_freq == 0:
                visualizer.vis_print(
                    'saving the latest model (epoch %d, total_steps %d)' %
                    (epoch, total_steps))
                modelG.module.save('latest')
                modelD.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        # save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            visualizer.vis_print(
                'saving the model at the end of epoch %d, iters %d' %
                (epoch, total_steps))
            modelG.module.save('latest')
            modelD.module.save('latest')
            modelG.module.save(epoch)
            modelD.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        # linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            modelG.module.update_learning_rate(epoch)
            modelD.module.update_learning_rate(epoch)

        # gradually grow training sequence length
        if (epoch % opt.niter_step) == 0:
            data_loader.dataset.update_training_batch(epoch // opt.niter_step)
            modelG.module.update_training_batch(epoch // opt.niter_step)

        # finetune all scales
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                epoch == opt.niter_fix_global):
            modelG.module.update_fixed_params()
Example #24
import time
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
import pudb

opt = TrainOptions().parse()
data_loader = CreateDataLoader(opt)  # inits CustomDatasetDataLoader
dataset = data_loader.load_data()  # returns CustomDatasetDataLoader obj
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)

model = create_model(opt)
visualizer = Visualizer(opt)
total_steps = 0

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    iter_data_time = time.time()
    epoch_iter = 0
    skip_count = 0

    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        if total_steps % opt.print_freq == 0:
            t_data = iter_start_time - iter_data_time
        visualizer.reset()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
Example #25
from options.encoder_decoder_options_v2 import TrainEncoderDecoderOptions_V2
from misc.visualizer import GANVisualizer_V3
# assumed module paths for two names used below but not imported here
from models.encoder_decoder_framework_DFN import EncoderDecoderFramework_DFN
from data.data_loader import CreateDataLoader

import util.io as io
import os
import sys
import time
import numpy as np
from collections import OrderedDict

opt = TrainEncoderDecoderOptions_V2().parse()
# create model
model = EncoderDecoderFramework_DFN()
model.initialize(opt)
# create data loader
train_loader = CreateDataLoader(opt, split='train')
val_loader = CreateDataLoader(opt, split='test')
# create visualizer
visualizer = GANVisualizer_V3(opt)

total_steps = 0

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    model.update_learning_rate()
    for i, data in enumerate(train_loader):
        total_steps += 1
        model.set_input(data)
        model.forward()
        model.optimize_parameters()

        if total_steps % opt.display_freq == 0:
Example #26
import time
from options.test_options import TestOptions
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer

# train
opt_train = TrainOptions().parse()
data_loader_train = CreateDataLoader(opt_train)
dataset_train = data_loader_train.load_data()
dataset_size_train = len(data_loader_train)
print('#training images = %d' % dataset_size_train)

# test
opt_test = TestOptions().parse()
opt_test.nThreads = 1  # test code only supports nThreads = 1
opt_test.batchSize = 1  # test code only supports batchSize = 1
opt_test.serial_batches = False  # shuffled order (set True for no shuffle)
opt_test.no_flip = True  # no flip
opt_test.how_many = 100
data_loader_test = CreateDataLoader(opt_test)
dataset_test = data_loader_test.load_data()
dataset_size_test = len(data_loader_test)
print('#test images = %d' % dataset_size_test)

model = create_model(opt_train)
visualizer = Visualizer(opt_train)
total_steps = 0

for epoch in range(opt_train.epoch_count,
Example #27
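# note: list multiplication repeats a single GramMSELoss instance; every
# entry aliases the same module, which is harmless for a stateless loss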
loss_fns = [GramMSELoss()] * len(loss_layer)
if torch.cuda.is_available():
    loss_fns = [loss_fn.cuda() for loss_fn in loss_fns]
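# per-layer weights scale as 1/N**2 (N = feature-channel count of the matched
# VGG layer), the usual Gatys-style normalization of Gram-matrix magnitudes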
style_weights = [1e3 / n**2 for n in [64, 128, 256, 512]]

optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), 0.0001,
    [0.9, 0.999])
optimizer_D = torch.optim.Adam(
    filter(lambda p: p.requires_grad, netD.parameters()), 0.0002, [0.9, 0.999])
print('# generator parameters:',
      sum(param.numel() for param in model.parameters()))
print('# discriminator parameters:',
      sum(param.numel() for param in netD.parameters()))
print()

for i in range(opt.start_training_step, 4):
    opt.nEpochs = training_settings[i - 1]['nEpochs']
    opt.lr = training_settings[i - 1]['lr']
    opt.step = training_settings[i - 1]['step']
    opt.lr_decay = training_settings[i - 1]['lr_decay']
    opt.lambda_db = training_settings[i - 1]['lambda_db']
    opt.gated = training_settings[i - 1]['gated']
    print(opt)
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
        lr = adjust_learning_rate(epoch - 1)
        trainloader = CreateDataLoader(opt)
        train(trainloader, model, criterion, optimizer, epoch, lr)
        if epoch % 5 == 0:
            checkpoint(i, epoch)
Example #28
# load options
opt = TestMMGANOptions_V3().parse()
opt.batch_size = batch_size
org_opt = io.load_json(os.path.join('checkpoints', opt.id, 'train_opt.json'))
preserved_opt = {'id', 'gpu_ids', 'batch_size', 'is_train', 'dataset_mode'}
for k, v in org_opt.items():
    if k in opt and (k not in preserved_opt):
        setattr(opt, k, v)
# create model
model = MultimodalDesignerGAN_V3()
model.initialize(opt)
for name, net in model.modules.items():
    for p in net.parameters():
        p.requires_grad = False
# create data loader
val_loader_iter = iter(CreateDataLoader(opt, split='vis'))
# create visualizer
visualizer = GANVisualizer_V3(opt)

for i in range(num_batch):
    print('[%s] test generation: %d / %d' % (opt.id, i + 1, num_batch))

    data = next(val_loader_iter)
    imgs_edge = data['img_edge'][0::7].cpu().clone()
    imgs_color = data['img_color'][0:7].cpu().clone()
    s_id = data['id'][0][0]

    model.set_input(data)
    model.test('normal')

    vis_dir = 'vis_gen'
Example #29
import time
import os
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from util import html

opt = TestOptions().parse()
opt.nThreads = 1  # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name,
                       '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(
    web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
    (opt.name, opt.phase, opt.which_epoch))
# test
for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    model.set_input(data)
    model.test()
    visuals = model.get_current_visuals_test()
Example #30
import torch.backends.cudnn as cudnn
from torchsummary import summary
from tqdm import tqdm

from data.data_loader import CreateDataLoader
from models.models import create_model
from options.train_options import TrainOptions

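# let cuDNN benchmark and cache the fastest convolution algorithms;
# effective when input shapes stay fixed across iterations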
cudnn.benchmark = True

opt = TrainOptions().parse()
data_loader = CreateDataLoader(
    opt, {
        'mode': 'Train',
        'manifestFn': '/home/caspardu/data/LipReadProject/LipNetData/manifestFiles/wellDone_train.list',
        'labelFn': '/home/caspardu/data/LipReadProject/LipNetData/manifestFiles/label.txt'
    })
testDataLoader = CreateDataLoader(
    opt, {
        'mode': 'Test',
        'manifestFn': '/home/caspardu/data/LipReadProject/LipNetData/manifestFiles/wellDone_test.list',
        'labelFn': '/home/caspardu/data/LipReadProject/LipNetData/manifestFiles/label.txt'
    })

dataset = data_loader.load_data()