Example #1
import time
import copy

from data import CreateDataLoader
from models import create_model
from options.train_options import TrainOptions
from util.visualizer import Visualizer, save_images

if __name__ == '__main__':
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

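    # optionally build a second loader over the validation split by cloning the parsed training options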
    if opt.validate_freq > 0:
        validate_opt = copy.deepcopy(opt)
        validate_opt.phase = 'val'
        validate_opt.serial_batches = True  # no shuffle
        val_data_loader = CreateDataLoader(validate_opt)
        val_dataset = val_data_loader.load_data()
        val_dataset_size = len(val_data_loader)
        print('#validation images = %d' % val_dataset_size)

    model = create_model(opt)
    model.setup(opt)
    visualizer = Visualizer(opt)
    total_steps = 0

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
Example #2
import os
from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
from util import html


if __name__ == '__main__':
    opt = TestOptions().parse()
    opt.nThreads = 1   # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.display_id = -1  # no visdom display
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    model = create_model(opt)
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join(opt.results_dir, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Phase = %s, Epoch = %s' % (opt.phase, opt.which_epoch))
    # test
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        print('%04d: process image...' % (i))
        visualizer.save_images(webpage, visuals, i, aspect_ratio=opt.aspect_ratio)

    webpage.save()
Example #3
import time

from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer


def main():
    # step1: opt
    opt = TrainOptions().parse()

    # step2: data
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = {}'.format(dataset_size))

    # step3: model
    model = create_model(opt)
    model.setup(opt)

    # step4: vis
    visualizer = Visualizer(opt)

    total_steps = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            if total_steps % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)
            model.optimize_parameters()

            # display images on visdom and html
            if total_steps % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(model.get_current_visuals(),
                                                   epoch, save_result)

            # display losses on visdom and console
            if total_steps % opt.print_freq == 0:
                losses = model.get_current_losses()
                t = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t,
                                                t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(
                        epoch,
                        float(epoch_iter) / dataset_size, opt, losses)
            # save by iter
            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch {}, total_steps {})'.format(
                    epoch, total_steps))
                save_suffix = 'iter_{}'.format(
                    total_steps) if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            iter_data_time = time.time()

        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch {}, iter {}'.format(
                epoch, total_steps))
            model.save_networks('latest')
            model.save_networks(epoch)

        print('End of epoch {} / {}\t Time Taken: {} sec'.format(
            epoch, opt.niter + opt.niter_decay,
            time.time() - epoch_start_time))

        model.update_learning_rate()
Example #4
    print("time used for validation: ", time.time() - start_time)
    return ret_metrics


if __name__ == '__main__':
    opt_test = TestOptions().parse()

    # hard-code some parameters for test
    opt_test.num_threads = 1  # test code only supports num_threads = 1
    opt_test.batch_size = 1  # test code only supports batch_size = 1
    opt_test.serial_batches = True  # no shuffle
    opt_test.no_flip = True  # no flip
    opt_test.display_id = -1  # no visdom display
    opt_test.dataset_mode = 'ms_3d'
    data_loader = CreateDataLoader(opt_test)
    dataset_test = data_loader.load_data()

    models = []
    models_indx = opt_test.load_str.split(',')
    models_weight = [1] * len(models_indx)
    for i in models_indx:
        current_model = create_model(opt_test, i)
        current_model.setup(opt_test)
        if opt_test.eval:
            current_model.eval()
        models.append(current_model)

    losses = model_test(models,
                        dataset_test,
                        opt_test,
                        len(data_loader),
Example #5
import time
from options.train_options import TrainOptions
from options.val_options import ValOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
import numpy as np
import math
from skimage.measure import compare_ssim

if __name__ == '__main__':
    opt = TrainOptions().parse()
    val = ValOptions().parse()
    data_loader = CreateDataLoader(opt)
    data_loader_val = CreateDataLoader(val)
    dataset = data_loader.load_data()
    dataset_val = data_loader_val.load_data()
    dataset_size = len(data_loader)
    dataset_val_size = len(data_loader_val)
    print('#training images = %d' % dataset_size)
    print('#validation images = %d' % dataset_val_size)

    val.nThreads = 1
    val.batchSize = 1
    model = create_model(opt)
    model_val = create_model(val)
    model.setup(opt)
    visualizer = Visualizer(opt)
    total_steps = 0

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
Example #6
import time
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer

# import random
# random.seed(666)

### landmark+bbox <-> face
### aligned data, implemented via the UnalignedDataset class
if __name__ == '__main__':
    print('start...')
    ### data
    opt = TrainOptions().parse()
    print('opt parse: success!')
    data_loader = CreateDataLoader(opt)
    print('initialize data_loader: success!')
    dataset = data_loader.load_data()  ### returns self
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    ### model
    model = create_model(opt)
    model.setup(opt)
    visualizer = Visualizer(opt)
    total_steps = 0

    ### train
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
Example #7
def convert():
    if (chkGpuVar.get() == 0):
        opt.gpu_ids.clear()
    #opt.remove_images = chkDelVar.get()
    opt.resize_or_crop = drpResizeOp.get()
    try:
        opt.epoch = txtEpoch.get()
    except Exception as e:
        print(e)

    if 'scale' in opt.resize_or_crop:
        for i in range(len(validSizes) - 2):
            if (sclFineVar.get() < validSizes[i + 1]
                    and sclFineVar.get() >= validSizes[i]):
                opt.fineSize = validSizes[i]

    print(testOptions.return_options(opt))
    try:
        data_loader = CreateDataLoader(opt)
        dataset = data_loader.load_data()
        model = create_model(opt)
        model.setup(opt)

        # test with eval mode. This only affects layers like batchnorm and dropout.
        # pix2pix: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode.
        # CycleGAN: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
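        # optional, hedged sketch assuming an opt.eval flag like the upstream test
        # script defines; getattr keeps it a no-op if this GUI's options lack it.
        if getattr(opt, 'eval', False):
            model.eval()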

        progressbar.configure(maximum=len(dataset))
        #progressbar.start(len(dataset))
        for i, data in enumerate(dataset):
            if i >= opt.num_test or not running:
                break
            model.set_input(data)
            model.test()
            visuals = model.get_current_visuals()
            img_path = model.get_image_paths()
            mess = 'processing (%04d)-th of %04d image... %s' % (
                i + 1, len(dataset), img_path[0])
            print(mess)

            # append the progress message to a log file
            with open('conversion_progress.txt', 'a') as file_object:
                file_object.write(mess + '\n')

            save_images(opt.results_dir,
                        visuals,
                        img_path,
                        save_both=opt.save_both,
                        aspect_ratio=opt.aspect_ratio)
            progress_var.set(i + 1)
            if opt.remove_images:
                os.remove(img_path[0])
                print('removed image', img_path[0])
    except KeyboardInterrupt:
        progress_var.set(0)
        print("==============Cancelled==============")
        raise
    except Exception as e:
        print(e)
        raise
Example #8
import time
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
from util.laplotter import LossAccPlotter
import pdb

if __name__ == '__main__':
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    # validate
    opt.phase = 'val'
    validation_loader = CreateDataLoader(opt)
    validation_dataset = validation_loader.load_data()
    validation_dataset_size = len(validation_loader)
    print('validate images = %d' % validation_dataset_size)

    model = create_model(opt)  # has been initialized
    visualizer = Visualizer(opt)
    total_steps = 0

    plotter = LossAccPlotter(save_to_filepath='./checkpoints/nyud_fcrn/')

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
Example #9
import numpy as np
import random
import torch
import cv2
from tensorboardX import SummaryWriter

from options.train_options import TrainOptions
from data import CreateDataLoader

if __name__ == '__main__':
    train_opt = TrainOptions().parse()

    np.random.seed(train_opt.seed)
    random.seed(train_opt.seed)
    torch.manual_seed(train_opt.seed)
    torch.cuda.manual_seed(train_opt.seed)

    train_data_loader = CreateDataLoader(train_opt)
    train_dataset = train_data_loader.load_data()
    train_dataset_size = len(train_data_loader)
    print('#training images = %d' % train_dataset_size)

    valid_opt = TrainOptions().parse()
    valid_opt.phase = 'val'
    valid_opt.batch_size = 1
    valid_opt.num_threads = 1
    valid_opt.serial_batches = True
    valid_opt.isTrain = False
    valid_data_loader = CreateDataLoader(valid_opt)
    valid_dataset = valid_data_loader.load_data()
    valid_dataset_size = len(valid_data_loader)
    print('#validation images = %d' % valid_dataset_size)

    writer = SummaryWriter()
Example #10
from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
from PIL import Image
import visdom
from util.util import sdmkdir
from util import util
import time
import os

test_opt = TestOptions().parse()
model = create_model(test_opt)
model.setup(test_opt)

test_data_loader = CreateDataLoader(test_opt)
test_set = test_data_loader.load_data()
test_save_path = os.path.join(test_opt.checkpoints_dir, 'test')

if not os.path.isdir(test_save_path):
    os.makedirs(test_save_path)

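# run the model in eval mode over the test set and save prediction / ground-truth image pairs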
model.eval()
idx = 0
for i, data in enumerate(test_set):
    idx += 1
    visuals = model.get_prediction(data)
    pred = visuals['final']
    gt = visuals['gt']
    im_name = data['imname'][0].split('.')[0]
    util.save_image(gt, os.path.join(test_save_path, im_name + '_gt.png'))
    util.save_image(pred, os.path.join(test_save_path,
Example #11
def sanity_check(opt):
    abort_file = "/mnt/raid/patrickradner/kill" + str(opt.gpu_ids[0]) if len(
        opt.gpu_ids) > 0 else "cpu"

    if os.path.exists(abort_file):
        os.remove(abort_file)
        exit("Abort using file: " + abort_file)

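    # shrink the run to a single training sample with very frequent logging so the whole pipeline can be exercised quickly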
    opt.max_dataset_size = 1
    opt.max_val_dataset_size = 1
    freq = 10
    opt.batch_size = 1
    opt.print_freq = freq
    opt.display_freq = freq
    opt.update_html_freq = freq
    opt.validation_freq = 50
    opt.niter = 500
    opt.niter_decay = 0
    opt.display_env = "sanity_check"
    opt.num_display_frames = 10
    opt.train_mode = "frame"
    #opt.reparse_data=True
    opt.lr = 0.004
    opt.pretrain_epochs = 0

    opt.verbose = True

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    t_min = 100000
    t_max = 0

    print(f"Length: Min: {t_min} Max: {t_max}")

    if opt.validation_freq > 0:
        phase = opt.phase
        opt.phase = opt.validation_set
        validation_loader = CreateDataLoader(opt)
        validation_set = validation_loader.load_data()
        opt.phase = phase

        validation_size = len(validation_loader)
        print('#training samples = %d' % dataset_size)
        print('#validation samples = %d' % validation_size)

    model = create_model(opt)
    model.setup(opt)

    visualizer = Visualizer(opt)
    total_steps = 0

    data = next(iter(dataset))

    for epoch in range(5000):
        # training loop

        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
        losses = {}

        if os.path.exists(abort_file):
            exit("Abort using file: " + abort_file)

        iter_start_time = time.time()
        if total_steps % opt.print_freq == 0:
            t_data = iter_start_time - iter_data_time
        visualizer.reset()
        total_steps += opt.batch_size
        epoch_iter += opt.batch_size

        model.set_input(data)
        model.optimize_parameters(epoch, verbose=opt.verbose)

        if total_steps % opt.display_freq == 0:
            save_result = total_steps % opt.update_html_freq == 0
            visualizer.display_current_results(model.get_current_visuals(),
                                               epoch, save_result)

        if total_steps % opt.print_freq == 0:
            losses = model.get_current_losses()
            t = (time.time() - iter_start_time) / opt.batch_size
            visualizer.print_current_losses(epoch, epoch_iter, losses, t,
                                            t_data)
            if opt.display_id > 0:
                visualizer.plot_current_losses(
                    epoch,
                    float(epoch_iter) / dataset_size, opt, losses)

        iter_data_time = time.time()

        if epoch % 50 == 0:
            print('End of sanity_check epoch %d / %d \t Time Taken: %d sec' %
                  (epoch, opt.niter + opt.niter_decay,
                   time.time() - epoch_start_time))

    print("SANITY CHECK DONE")
Example #12
def create_data_loader(opt_this_phase):
    data_loader = CreateDataLoader(opt_this_phase)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#%s images = %d' % (opt_this_phase.phase, dataset_size))
    return dataset, dataset_size
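
# Hedged usage sketch (assumes the TrainOptions-style opt used in the other examples
# and the standard copy module): reuse the helper for a validation split by cloning
# the parsed options and switching the phase.
#
#   import copy
#   train_set, train_size = create_data_loader(opt)
#   val_opt = copy.deepcopy(opt)
#   val_opt.phase = 'val'
#   val_opt.serial_batches = True  # no shuffle
#   val_set, val_size = create_data_loader(val_opt)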
Example #13
    def on_created(self, event):  # when file is created
        # do something, e.g. call your function to process the image
        print("Got G event for file %s" % event.src_path)

        try:
            go = os.path.abspath(os.path.join(event.src_path, os.pardir, "go"))

            if not os.path.isfile(go):
                return

            with open(go) as f:
                name = f.readlines()[0]

            print("starting to process %s" % name)

            if self.opt.mlabel_condition:
                self.opt.mlabel_dataroot = self.opt.dataroot.rstrip(
                    '/\\') + '_mlabels'

            if self.opt.metrics_condition or self.opt.empty_condition:
                self.opt.empty_dataroot = self.opt.dataroot.rstrip(
                    '/\\') + '_empty'

            self.model.opt = self.opt

            data_loader = CreateDataLoader(self.opt)
            dataset = data_loader.load_data()

            for i, data in enumerate(dataset):
                # try:

                zs = os.path.basename(data['A_paths'][0])[:-4].split("_")[1:]
                z = np.array([float(i) for i in zs], dtype=np.float32)

                self.model.set_input(data)

                _, real_A, fake_B, real_B, _ = self.model.test_simple(
                    z, encode_real_B=False)

                img_path = self.model.get_image_paths()
                print('%04d: process image... %s' % (i, img_path))

                save_image(
                    fake_B, "./output/%s/%s/%s" %
                    (self.directory, name, os.path.basename(img_path[0])))
                save_image(
                    real_A, "./output/%s/%s/%s_label" %
                    (self.directory, name, os.path.basename(img_path[0])))

                if self.fit_boxes is not None:
                    fit_boxes(
                        img=fake_B,
                        classes=self.fit_boxes[0],
                        fit_labels=self.fit_boxes[1],
                        json_path="./output/%s/%s/%s_boxes" %
                        (self.directory, name, os.path.basename(img_path[0])))

                if self.fit_circles is not None:
                    fit_circles(
                        img=fake_B,
                        classes=self.fit_circles[0],
                        fit_labels=self.fit_circles[1],
                        json_path="./output/%s/%s/%s_circles" %
                        (self.directory, name, os.path.basename(img_path[0])))

        except Exception as e:
            traceback.print_exc()
            print(e)

        try:
            rmrf('./input/%s/val/*' % self.directory)
            rmrf('./input/%s_empty/val/*' % self.directory)
            rmrf('./input/%s_mlabel/val/*' % self.directory)

            if os.path.isfile(go):
                os.remove(go)
        except Exception as e:
            traceback.print_exc()
            print(e)
Example #14
def train():
    try:
        if opt.continue_train == 1:
            opt.epoch = int(txtEpoch.get())
        opt.name = txtName.get()
        opt.loadSize = int(txtLoadSize.get())
        opt.fineSize = int(txtFineSize.get())
        opt.epoch_count = int(txtEpochCount.get())

    except ValueError as ve:
        print("\nPlease ensure all text boxes have only numbers")
        raise
    except Exception as e:
        print(e)
        raise
    if __name__ == '__main__':
        try:
            data_loader = CreateDataLoader(opt)
            dataset = data_loader.load_data()
            dataset_size = len(data_loader)
            print('#training images = %d' % dataset_size)

            model = create_model(opt)
            model.setup(opt)
            visualizer = Visualizer(opt)
            total_steps = 0

            print(trainOptions.return_options(opt))

            for epoch in range(opt.epoch_count,
                               opt.niter + opt.niter_decay + 1):
                epoch_start_time = time.time()
                iter_data_time = time.time()
                epoch_iter = 0

                for i, data in enumerate(dataset):
                    global running
                    if not running:
                        raise KeyboardInterrupt
                    iter_start_time = time.time()
                    if total_steps % opt.print_freq == 0:
                        t_data = iter_start_time - iter_data_time
                    visualizer.reset()
                    total_steps += opt.batch_size
                    epoch_iter += opt.batch_size
                    model.set_input(data)
                    model.optimize_parameters()

                    if total_steps % opt.display_freq == 0:
                        save_result = total_steps % opt.update_html_freq == 0
                        visualizer.display_current_results(
                            model.get_current_visuals(), epoch, save_result)

                    if total_steps % opt.print_freq == 0:
                        losses = model.get_current_losses()
                        t = (time.time() - iter_start_time) / opt.batch_size
                        visualizer.print_current_losses(
                            epoch, epoch_iter, losses, t, t_data)
                        if opt.display_id > 0:
                            visualizer.plot_current_losses(
                                epoch,
                                float(epoch_iter) / dataset_size, opt, losses)

                    if total_steps % opt.save_latest_freq == 0:
                        print(
                            'saving the latest model (epoch %d, total_steps %d)'
                            % (epoch, total_steps))
                        model.save_networks('latest')

                    iter_data_time = time.time()
                if epoch % opt.save_epoch_freq == 0:
                    print('saving the model at the end of epoch %d, iters %d' %
                          (epoch, total_steps))
                    model.save_networks('latest')
                    model.save_networks(epoch)

                print('End of epoch %d / %d \t Time Taken: %d sec' %
                      (epoch, opt.niter + opt.niter_decay,
                       time.time() - epoch_start_time))
                model.update_learning_rate()
        except KeyboardInterrupt:
            print("==============Cancelled==============")
            raise
        except Exception as e:
            print(e)
            raise
Example #15
    opt.dataroot = './traindata'
    opt.name = 'DANN_miter1step1'
    opt.batchSize = 64
    opt.lr = 0.00001
    opt.model = 'DANN_m_iter'
    opt.which_epochs_DA = 1
    opt.which_usename_DA = 'DANN_mstep1without'
    opt.which_epochs_Di = 10
    opt.which_usename_Di = 'DANN_mv3step2'
    opt.gpu_ids = [0]
    opt.save_epoch_freq = 100
    """

    mnist_data_loader, mnistm_data_loader, eval_data_loader = CreateDataLoader(
        opt)
    mnist_dataset, mnistm_dataset, eval_dataset = mnist_data_loader.load_data(
    ), mnistm_data_loader.load_data(), eval_data_loader.load_data()
    mnist_dataset_size = len(mnist_data_loader)
    mnistm_dataset_size = len(mnistm_data_loader)
    eval_dataset_size = len(eval_data_loader)

    print('#mnist training images = %d' % mnist_dataset_size)
    print('#mnistm training images = %d' % mnistm_dataset_size)
    print('#eval training images = %d' % eval_dataset_size)
    print('#eval training images = %d' % len(eval_dataset))

    model = create_model(opt)
    best_acc = 0
    total_steps = 0
    i = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
Example #16
def runMethodTest(destPath, opt, contEncPath, styleEncPath, decPath,
                  useConcat):
    if not os.path.exists(destPath):
        os.makedirs(destPath)

    data_loader = CreateDataLoader(opt)
    dataloader = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#testing images = %d' % dataset_size)

    cont_enc = _EncoderNoa(opt.imageSize)
    stl_enc = _EncoderNoa(opt.imageSize)
    dec = _DecoderNoa(opt.imageSize, useConcat)
    cont_enc.load_state_dict(torch.load(contEncPath))
    stl_enc.load_state_dict(torch.load(styleEncPath))
    dec.load_state_dict(torch.load(decPath))

    if opt.cuda:
        # feature_extractor.cuda()
        cont_enc.cuda()
        stl_enc.cuda()
        dec.cuda()

    torch.manual_seed(0)

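    # for each test pair, encode style and content separately, swap them, and decode the crossed combinations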
    for i, data in enumerate(dataloader, 0):

        img1 = data['A']
        img2 = data['B']
        img12 = data['A2']
        img21 = data['B2']

        if opt.cuda:
            img1 = img1.cuda()
            img2 = img2.cuda()
            img12 = img12.cuda()
            img21 = img21.cuda()

        stl1 = stl_enc(img1)
        stl2 = stl_enc(img2)
        cont1 = cont_enc(img1)
        cont2 = cont_enc(img2)

        # stl12 = stl_enc(img12)
        # stl21 = stl_enc(img21)
        # cont12 = cont_enc(img12)
        # cont21 = cont_enc(img21)

        if (useConcat):
            stl1cont2 = torch.cat((stl1, cont2), 1)
            stl2cont1 = torch.cat((stl2, cont1), 1)
            # stl1cont1 = torch.cat((stl1, cont1), 1)
            # stl2cont2 = torch.cat((stl2, cont2), 1)
        else:
            stl1cont2 = stl1 + cont2
            stl2cont1 = stl2 + cont1
            # stl1cont1 = stl1 + cont1
            # stl2cont2 = stl2 + cont2

        dec12 = dec(stl1cont2)
        dec21 = dec(stl2cont1)
        # dec11 = dec(stl1cont1)
        # dec22 = dec(stl2cont2)

        if i % 10 == 0:
            im1 = util.tensor2im(img1[0])
            im2 = util.tensor2im(img2[0])
            oim12 = util.tensor2im(img12[0])
            oim21 = util.tensor2im(img21[0])
            im12 = util.tensor2im(dec12[0])
            im21 = util.tensor2im(dec21[0])

            imageio.imwrite(
                os.path.join(destPath, '%d_style_1_cont_2.png' % (i)), im12)
            imageio.imwrite(
                os.path.join(destPath, '%d_style_2_cont_1.png' % (i)), im21)
            imageio.imwrite(
                os.path.join(destPath, '%d_style_1_cont_1_orig.png' % (i)),
                im1)
            imageio.imwrite(
                os.path.join(destPath, '%d_style_2_cont_2_orig.png' % (i)),
                im2)
            imageio.imwrite(
                os.path.join(destPath, '%d_style_1_cont_2_orig.png' % (i)),
                oim12)
            imageio.imwrite(
                os.path.join(destPath, '%d_style_2_cont_1_orig.png' % (i)),
                oim21)
Example #17
def train():
    import time
    from options.train_options import TrainOptions
    from data import CreateDataLoader
    from models import create_model
    from util.visualizer import Visualizer
    opt = TrainOptions().parse()
    model = create_model(opt)
    #Loading data
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('Training images = %d' % dataset_size)
    visualizer = Visualizer(opt)
    total_steps = 0
    #Starts training
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            if total_steps % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.optimize_parameters()
            #Save current images (real_A, real_B, fake_A, fake_B, rec_A, rec_B)
            if epoch_iter % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(model.get_current_visuals(),
                                                   epoch, epoch_iter,
                                                   save_result)
            #Save current errors
            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t,
                                                t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch,
                        float(epoch_iter) / dataset_size, opt, errors)
            #Save model based on the number of iterations
            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

            iter_data_time = time.time()
        #Save model based on the number of epochs
        print(opt.dataset_mode)
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))
        model.update_learning_rate()
Example #18
    def train(self):
        dataloader = CreateDataLoader(self.opt)
        dataset = dataloader.load_data()
        data_size = len(dataloader)

        writer = SummaryWriter(log_dir=self.log_dir)

        # now_epoch = self.step // self.

        for i in range(self.niter + self.niter_decay + 1):
            epoch_start = time.time()
            epoch_step = 0

            for j, data in enumerate(dataset):
                self.set_input(data)
                self.update_bicycleGAN()

                self.step = self.step + 1
                epoch_step += data['A'].size(0)

                # if self.step % self.print_freq == 0:
                print("[Epoch %d/%d] [Batch %d/%d] [D loss1: %f] [D loss2: %f] [Z loss: %f] [G loss: %f]" %
                      (i, self.epoch, j * self.batch_size, data_size,
                       self.d_loss1.item(), self.d_loss2.item(),
                       self.z_loss.item(), self.eg_loss.item()))

                if self.step % 200 == 0:
                    writer.add_scalar('loss/g1_loss', self.gan_loss1, self.step)
                    writer.add_scalar('loss/g2_loss', self.gan_loss2, self.step)
                    writer.add_scalar('loss/d1_loss', self.d_loss1, self.step)
                    writer.add_scalar('loss/d2_loss', self.d_loss2, self.step)
                    writer.add_scalar('loss/z_loss', self.z_loss, self.step)
                    writer.add_scalar('loss/l1_loss', self.l1_loss, self.step)
                    writer.add_scalar('loss/kl_loss', self.kl_loss, self.step)

                if self.step % 500 == 0:
                    torch.save({'generator': self.generator.state_dict(),
                                'discriminator1': self.discriminator1.state_dict(),
                                'discriminator2': self.discriminator2.state_dict(),
                                'encoder': self.encoder.state_dict(),
                                'opt_g': self.opt_g.state_dict(),
                                'opt_d1': self.opt_d1.state_dict(),
                                'opt_d2': self.opt_d2.state_dict(),
                                'opt_e': self.opt_e.state_dict(),
                                'step': self.step},
                               os.path.join(self.ckpt_dir, self.run_name + '.ckpt'))

                if self.step % self.sample_freq == 0:
                    encode_pair = torch.cat([self.realA_encode, self.realB_encode,
                                             self.fakeB_encode, self.fakeB_random],
                                            dim=0)
                    vutils.save_image(encode_pair,
                                      self.sample_dir + '/%d.png' % self.step,
                                      normalize=True)

                # write the logs
            #############################################
            epoch_end = time.time()
            self.update_learning_rate()

            if i % self.save_freq == 0:
                torch.save({'generator': self.generator.state_dict(),
                            'discriminator1': self.discriminator1.state_dict(),
                            'discriminator2': self.discriminator2.state_dict(),
                            'encoder': self.encoder.state_dict(),
                            'opt_g': self.opt_g.state_dict(),
                            'opt_d1': self.opt_d1.state_dict(),
                            'opt_d2': self.opt_d2.state_dict(),
                            'opt_e': self.opt_e.state_dict(),
                            'step': self.step},
                           os.path.join(self.ckpt_dir, self.run_name + '.ckpt'))
Example #19
        opt.batchSize = 200
        opt.serial_batches = True  # no shuffle
        opt.no_flip = True  # no flip
        opt.fid_count = True
        fid_data_length = 20000
        last_measure = 10
        save_path = opt.result_path
        save_result_flag = True
        if not os.path.isdir(save_path):
            os.mkdir(save_path)
        save_root = os.path.join(save_path, opt.name)
        txt_path = os.path.join(save_root, '%s.txt' % opt.name)

        opt.phase = 'test'
        data_loader = CreateDataLoader(opt)
        test_dataset = data_loader.load_data()
        model = create_model(opt)

        if not model.load_networks(opt.which_step):
            print('error no checkpoint')
            exit(0)
        # create data loader
        opt.phase = 'train'
        data_loader = CreateDataLoader(opt)
        dataset = data_loader.load_data()
        # create test data

        # ground truth create
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        Inception_model = InceptionV3([block_idx])
        Inception_model.cuda()
Example #20
import time
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer

if __name__ == '__main__':
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    model.setup(opt)
    visualizer = Visualizer(opt)
    total_steps = 0

    opt.dataset_mode = 'unaligned'
    opt.dataroot = opt.dataroot_unaligned
    data_loader_unaligned = CreateDataLoader(opt)
    dataset_unaligned = data_loader_unaligned.load_data()

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            if total_steps % opt.print_freq == 0:
Example #21
import random

import numpy as np
import torch
import cv2

from options.train_options import TrainOptions
from data import CreateDataLoader

if __name__ == '__main__':
    train_opt = TrainOptions().parse()

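    # seed every RNG (NumPy, Python, Torch CPU/CUDA) and make cuDNN deterministic so runs with the same seed are repeatable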
    np.random.seed(train_opt.seed)
    random.seed(train_opt.seed)
    torch.manual_seed(train_opt.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed(train_opt.seed)

    train_data_loader = CreateDataLoader(train_opt)
    train_dataset = train_data_loader.load_data()
    train_dataset_size = len(train_data_loader)
    print('#training images = %d' % train_dataset_size)

    test_opt = TrainOptions().parse()
    test_opt.phase = 'val'
    test_opt.batch_size = 1
    test_opt.num_threads = 1
    test_opt.serial_batches = True
    test_opt.no_flip = True
    test_opt.display_id = -1
    test_data_loader = CreateDataLoader(test_opt)
    test_dataset = test_data_loader.load_data()
    test_dataset_size = len(test_data_loader)
    print('#test images = %d' % test_dataset_size)
Example #22
import time
import copy
import torch
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer

if __name__ == '__main__':
    # training dataset
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)
    print('#training objects = %d' % opt.nTrainObjects)

    ## validation dataset
    if opt.compute_val:
        opt_validation = copy.copy(opt)  # create a clone
        opt_validation.phase = 'val'
        opt_validation.serial_batches = True
        opt_validation.isTrain = False
        data_loader_validation = CreateDataLoader(opt_validation)
        dataset_validation = data_loader_validation.load_data()
        dataset_size_validation = len(data_loader_validation)
        print('#validation images = %d' % dataset_size_validation)
        print('#validation objects = %d' % opt_validation.nValObjects)

    # model
    model = create_model(opt)
Example #23
opt.resizeFM = 'resize'
opt.isTrain = True
opt.no_flip = True
opt.gray = False
opt.serial_batches = False
opt.nThreads = 4
opt.max_dataset_size = float("inf")
#opt.which_direction = 'AToB'
opt.input_nc = 1
opt.nc = 1
opt.useConcat = useConcat
opt.classifyFonts = classifyFonts
opt.useFeatureMatchingLoss = useFeatureMatchingLoss
opt.useMSE = useMSE
data_loader = CreateDataLoader(opt)
dataloader = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)

lfile = os.path.join(opt.out, 'loss_log.txt')
lfile_handle = open(lfile, 'w')

optFile = os.path.join(opt.out, 'opt.txt')
with open(optFile, 'w') as optFile_handle:
    optFile_handle.write(str(opt))

modelFile = os.path.join(opt.out, 'models.txt')
modelFile_handle = open(modelFile, 'w')

with torch.no_grad():
Example #24
if __name__ == '__main__':
    opt = TrainOptions().parse()
    if opt.no_normalize:
        transform = tr.ToTensor()
    else:
        transform = tr.Compose([
            tr.ToTensor(),
            tr.Normalize(mean=opt.transform_mean, std=opt.transform_std)
        ])
    # mixed training: synthetic data
    if opt.train_type == 'mix':
        opt.batch_size = opt.batch_size // 2
        train_loader = CreateDataLoader(opt, dataroot=opt.dataroot,
                                        image_dir=opt.train_img_dir_syn,
                                        label_dir=opt.train_label_dir_syn,
                                        record_txt=opt.train_img_list_syn,
                                        transform=transform, is_aug=False)
        train_dataset = train_loader.load_data()
        dataset_size = len(train_loader)
        print('#Synthetic training images = %d, batchsize = %d' %
              (dataset_size, opt.batch_size))

    # training: real data
    train_loader_real = CreateDataLoader(opt, dataroot=opt.dataroot,
                                         image_dir=opt.train_img_dir_real,
                                         label_dir=opt.train_label_dir_real,
                                         record_txt=opt.train_img_list_real,
                                         transform=transform, is_aug=False)
    train_dataset_real = train_loader_real.load_data()
    dataset_size_real = len(train_loader_real)
    print('#Real training images = %d, batchsize = %d' %
          (dataset_size_real, opt.batch_size))

    # eval data
    if not opt.no_eval:
Example #25
def main(style):

    opt = TestOptions().parse()

    opt.dataroot = "datasets/own_data/testA"

    # pretrained styles (pick one of the names below)
    # opt.name = "style_ink_pretrained"
    # opt.name = "style_monet_pretrained"
    # opt.name = "style_cezanne_pretrained"
    # opt.name = "style_ukiyoe_pretrained"
    # opt.name = "style_vangogh_pretrained"


    # set original img size
    original_img = cv2.imread(opt.dataroot+"/temp.jpg")
    original_img_shape = tuple([item for item in original_img.shape[:-1]][::-1])
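    # the original resolution is kept so the generated image can be resized back to it at the end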

    opt.name = "style_%s_pretrained" % style
    # do not change
    opt.model = "test"

    cv2.imread("temp.jpg")

    opt.nThreads = 1   # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.display_id = -1  # no visdom display

    # need to overwrite (8-27); this part can be skipped
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    # create model
    model = create_model(opt)
    model.setup(opt)

    # create website
    # the website itself is not really needed, but the author writes the saved images into web_dir, so this is left unchanged

    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    print("web_dir", web_dir)
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
    print("webpage", webpage)
    # exit()

    # test
    for i, data in enumerate(dataset):
        # i is the index produced by enumerate
        # data is a dict: one key is 'A', a tensor of size [1, 3, 256, 256];
        # the other is 'A_paths', a str taken from the input path (including the
        # file name), e.g. datasets/own_data/testA/2test.jpg
        # default how_many is 50: at most 50 images of a dataset are processed

        # to overwrite "data", a dict of the same shape fed from an outside
        # listener should be enough
        if i >= opt.how_many:
            break
        model.set_input(data)

        model.test()
        visuals = model.get_current_visuals()
        img_path = model.get_image_paths()
        if i % 5 == 0:
            print('processing (%04d)-th image... %s' % (i, img_path))
        save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)

        generate_img = cv2.imread("results/generate_images/" + "temp.png")
        reshape_generate_img = cv2.resize(generate_img, original_img_shape, interpolation=cv2.INTER_CUBIC)

        cv2.imwrite("results/generate_images/" + "temp.png", reshape_generate_img)
Example #26
import time
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer

if __name__ == '__main__':
    opt = TrainOptions().parse()         # opt holds the parsed command-line arguments
    data_loader = CreateDataLoader(opt)     # data_loader is used to load the data
    dataset = data_loader.load_data()       # load the dataset
    dataset_size = len(data_loader)         # size of the dataset
    print('#training images = %d' % dataset_size)     # training data: 1096 images (train, trainA and trainB each contain 1096)

    model = create_model(opt)           # create the model; opt.model defaults to cycle_gan
    model.setup(opt)                    # the model reads the parameters in opt and runs the related initialization
    visualizer = Visualizer(opt)        # used to visualize the outputs
    total_steps = 0

    # train for 200 epochs
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            if total_steps % opt.print_freq == 0:   # print_freq = 100
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_steps += opt.batch_size      # default batch_size = 1
            epoch_iter += opt.batch_size
Example #27
import os
from shutil import copyfile

from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import save_images
from util import html


if __name__ == '__main__':
    opt = TestOptions().parse()
    opt.nThreads = 1   # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.display_id = -1  # no visdom display
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    model = create_model(opt)
    model.setup(opt)
    # create website
    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
    # test
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        img_path = model.get_image_paths()
        if i % 5 == 0:
            print('processing (%04d)-th image... %s' % (i, img_path))


def debug_layer_images(target_dir, target_layer):
    """Save each channel in target_layer as one image into target_dir.

	Args:
		target_dir: str, target directory to save result images.
		target_layer: int, target layer. None for all layers.
	"""

    opt = TestOptions().parse()
    # hard-code some parameters for test
    opt.num_threads = 1  # test code only supports num_threads = 1
    opt.batch_size = 1  # test code only supports batch_size = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.display_id = -1  # no visdom display
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    model = create_model(opt)
    model.setup(opt)

    # extract each layer in netG_A
    layers = list(
        list(list(list(model.netG_A.children())[0].children())[0].modules())
        [0].children())

    # print network
    print("--- Start network A->B ---")
    for i, layer in enumerate(layers):
        print("#{}: {}".format(i, layer))
    print("--- End network A->B ---")

    # prepare data: only use first data for test
    data_list_enu = enumerate(dataset)
    i, data = next(data_list_enu)
    print("--- Start data info ---")
    print("data[A].shape: ", data['A'].shape)
    print("A_paths: ", data['A_paths'])
    print("--- End data info ---")

    # run the input through each layer
    output = data['A']
    result = []
    for i in range(len(layers)):
        output = layers[i].cpu()(output)
        print("layer{} output shape: {}".format(i, output.shape))
        result.append(output.detach().numpy())

    # create target dir
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # save input image
    path = os.path.join(target_dir, 'input.jpg')
    copyfile(data['A_paths'][0], path)

    # save result image
    path = os.path.join(target_dir, 'output.jpg')
    save_final_layer_image(result[27][0], path)

    # save each channel of each layer as one image

    for i in range(len(layers)):

        if target_layer is not None and target_layer != i:
            continue

        print("Create images for layer_{}".format(i))

        layer_path = os.path.join(target_dir, "layer_{}".format(i))
        if not os.path.exists(layer_path):
            os.makedirs(layer_path)

        for target_channel in range(result[i].shape[1]):
            path = os.path.join(layer_path,
                                'channel_{}.jpg'.format(target_channel))
            save_channel_image(result[i][0, target_channel:target_channel + 1],
                               path)