Exemple #1
0
import copy
import pdb
# Test-time setup: run inference over two datasets in lockstep.
# NOTE(review): top-level script fragment; TestOptions, CreateDataLoader,
# create_model, Visualizer, html and os are assumed imported earlier in the
# file (only `copy` and `pdb` imports are visible here).
opt = TestOptions().parse()
opt.nThreads = 1  # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip
# Shallow-copy the options so opt2 shares everything with opt except the
# data root, which is pointed at the secondary dataset.
opt2 = copy.copy(opt)
opt2.dataroot = opt.dataroot_more

data_loader = CreateDataLoader(opt)
data_loader2 = CreateDataLoader(opt2)
# Only iterate as many steps as the smaller dataset provides.
dataset_size = min(len(data_loader), len(data_loader2))

dataset = data_loader.load_data()
dataset2 = data_loader2.load_data()
model = create_model(opt)
visualizer = Visualizer(opt)
# create website directory where result pages are written
web_dir = os.path.join(opt.results_dir, opt.name,
                       '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(
    web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
    (opt.name, opt.phase, opt.which_epoch))
# test: draw one sample from each dataset per step, in lockstep.
loader = enumerate(dataset)
loader2 = enumerate(dataset2)
for steps in range(dataset_size):
    i, data = next(loader)  # index `i` is overwritten by the line below
    i, data2 = next(loader2)
Exemple #2
0
def main():
    """Run pix2pixHD-style test inference and write an HTML results page.

    Parses test options, loads the dataset, builds the PyTorch model (or
    defers to a TensorRT engine / ONNX runtime when requested), generates
    one image per input up to ``opt.how_many``, and saves visuals via the
    visualizer into a per-experiment web directory.
    """
    opt = TestOptions().parse(save=False)
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)
    # create website directory where the result page is written
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    # test: only build the PyTorch model when not delegating inference to a
    # TensorRT engine or an ONNX runtime.
    if not opt.engine and not opt.onnx:
        model = create_model(opt)
        if opt.data_type == 16:
            model.half()
        elif opt.data_type == 8:
            model.type(torch.uint8)

        if opt.verbose:
            print(model)
    else:
        from run_engine import run_trt_engine, run_onnx

    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        # Cast inputs to match the model precision selected above.
        if opt.data_type == 16:
            data['label'] = data['label'].half()
            data['inst'] = data['inst'].half()
        elif opt.data_type == 8:
            # BUGFIX: torch.Tensor has no .uint8() method; .byte() is the
            # documented cast to torch.uint8.
            data['label'] = data['label'].byte()
            data['inst'] = data['inst'].byte()
        if opt.export_onnx:
            print("Exporting to ONNX: ", opt.export_onnx)
            # BUGFIX: check the full ".onnx" suffix, as the message states;
            # the old endswith("onnx") accepted any name ending in "onnx".
            assert opt.export_onnx.endswith(
                ".onnx"), "Export model file should end with .onnx"
            torch.onnx.export(model, [data['label'], data['inst']],
                              opt.export_onnx,
                              verbose=True)
            exit(0)
        minibatch = 1
        if opt.engine:
            generated = run_trt_engine(opt.engine, minibatch,
                                       [data['label'], data['inst']])
        elif opt.onnx:
            generated = run_onnx(opt.onnx, opt.data_type, minibatch,
                                 [data['label'], data['inst']])
        else:
            generated = model.inference(data['label'], data['inst'],
                                        data['image'])

        visuals = OrderedDict([
            ('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
            ('synthesized_image', util.tensor2im(generated.data[0]))
        ])
        img_path = data['path']
        print('process image... %s' % img_path)
        visualizer.save_images(webpage, visuals, img_path)

    webpage.save()
Exemple #3
0
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import ntpath
import os
from util import util
import shutil
import numpy as np
#from util.visualizer import Visualizer

# Parse training options (TrainOptions imported at the top of this snippet).
opt = TrainOptions().parse()


# Load dataset_train
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)

# Load dataset_test: reuse the same options object, forcing batch size 1
# and switching the phase to 'test' before building the second loader.
opt.batchSize = 1
opt.phase = 'test'
data_loader_test = CreateDataLoader(opt)
dataset_test = data_loader_test.load_data()
dataset_size = len(data_loader_test)  # NOTE: overwrites the training count
print('#testing images = %d' % dataset_size)

# Create or clear dir for saving generated samples
if os.path.exists(opt.testing_path):
    shutil.rmtree(opt.testing_path)
os.makedirs(opt.testing_path)
Exemple #4
0
def train():
    """vid2vid-style training loop.

    Builds the dataset and the generator / discriminator / flow networks,
    then trains over epochs. Each batch's video is split into chunks of
    frames that fit on the GPU; the last generated frame of a chunk is fed
    into the next. Supports resuming from a saved iteration file, sparse
    and dense temporal discriminators, logging, visualization,
    checkpointing, learning-rate decay and growing the training sequence
    length.
    """
    opt = TrainOptions().parse()
    if opt.debug:
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    if opt.dataset_mode == 'pose':
        print('#training frames = %d' % dataset_size)
    else:
        print('#training videos = %d' % dataset_size)

    ### initialize models
    modelG, modelD, flowNet = create_model(opt)
    visualizer = Visualizer(opt)

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    ### if continue training, recover previous states
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        # BUGFIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit. Any failure to read/parse the iter
        # file simply means we start from scratch.
        except Exception:
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
        if start_epoch > opt.niter:
            modelG.module.update_learning_rate(start_epoch - 1)
            modelD.module.update_learning_rate(start_epoch - 1)
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                start_epoch > opt.niter_fix_global):
            modelG.module.update_fixed_params()
        if start_epoch > opt.niter_step:
            data_loader.dataset.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
            modelG.module.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
    else:
        start_epoch, epoch_iter = 1, 0

    ### set parameters
    n_gpus = opt.n_gpus_gen // opt.batchSize  # number of gpus used for generator for each batch
    tG, tD = opt.n_frames_G, opt.n_frames_D
    tDB = tD * opt.output_nc
    s_scales = opt.n_scales_spatial
    t_scales = opt.n_scales_temporal
    input_nc = 1 if opt.label_nc != 0 else opt.input_nc
    output_nc = opt.output_nc

    # Align print_freq with the batch size so the modulo checks below fire.
    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    total_steps = total_steps // opt.print_freq * opt.print_freq

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0

            _, n_frames_total, height, width = data['B'].size(
            )  # n_frames_total = n_frames_load * n_loadings + tG - 1
            n_frames_total = n_frames_total // opt.output_nc
            n_frames_load = opt.max_frames_per_gpu * n_gpus  # number of total frames loaded into GPU at a time for each batch
            n_frames_load = min(n_frames_load, n_frames_total - tG + 1)
            t_len = n_frames_load + tG - 1  # number of loaded frames plus previous frames

            fake_B_last = None  # the last generated frame from previous training batch (which becomes input to the next batch)
            real_B_all, fake_B_all, flow_ref_all, conf_ref_all = None, None, None, None  # all real/generated frames so far
            if opt.sparse_D:
                real_B_all, fake_B_all, flow_ref_all, conf_ref_all = [
                    None
                ] * t_scales, [None] * t_scales, [None] * t_scales, [
                    None
                ] * t_scales
            real_B_skipped, fake_B_skipped = [None] * t_scales, [
                None
            ] * t_scales  # temporally subsampled frames
            flow_ref_skipped, conf_ref_skipped = [None] * t_scales, [
                None
            ] * t_scales  # temporally subsampled flows

            for i in range(0, n_frames_total - t_len + 1, n_frames_load):
                # 5D tensor: batchSize, # of frames, # of channels, height, width
                input_A = Variable(
                    data['A'][:, i * input_nc:(i + t_len) * input_nc,
                              ...]).view(-1, t_len, input_nc, height, width)
                input_B = Variable(
                    data['B'][:, i * output_nc:(i + t_len) * output_nc,
                              ...]).view(-1, t_len, output_nc, height, width)
                inst_A = Variable(data['inst'][:, i:i + t_len, ...]).view(
                    -1, t_len, 1, height,
                    width) if len(data['inst'].size()) > 2 else None

                ###################################### Forward Pass ##########################
                ####### generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(
                    input_A, input_B, inst_A, fake_B_last)

                if i == 0:
                    fake_B_first = fake_B[
                        0, 0]  # the first generated image in this sequence
                real_B_prev, real_B = real_Bp[:, :
                                              -1], real_Bp[:,
                                                           1:]  # the collection of previous and current real frames

                ####### discriminator
                ### individual frame discriminator
                flow_ref, conf_ref = flowNet(
                    real_B, real_B_prev)  # reference flows and confidences
                fake_B_prev = real_B_prev[:, 0:
                                          1] if fake_B_last is None else fake_B_last[
                                              0][:, -1:]
                if fake_B.size()[1] > 1:
                    fake_B_prev = torch.cat(
                        [fake_B_prev, fake_B[:, :-1].detach()], dim=1)

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                ### temporal discriminator
                loss_dict_T = []
                # get skipped frames for each temporal scale
                if t_scales > 0:
                    if opt.sparse_D:
                        real_B_all, real_B_skipped = get_skipped_frames_sparse(
                            real_B_all, real_B, t_scales, tD, n_frames_load, i)
                        fake_B_all, fake_B_skipped = get_skipped_frames_sparse(
                            fake_B_all, fake_B, t_scales, tD, n_frames_load, i)
                        flow_ref_all, flow_ref_skipped = get_skipped_frames_sparse(
                            flow_ref_all,
                            flow_ref,
                            t_scales,
                            tD,
                            n_frames_load,
                            i,
                            is_flow=True)
                        conf_ref_all, conf_ref_skipped = get_skipped_frames_sparse(
                            conf_ref_all,
                            conf_ref,
                            t_scales,
                            tD,
                            n_frames_load,
                            i,
                            is_flow=True)
                    else:
                        real_B_all, real_B_skipped = get_skipped_frames(
                            real_B_all, real_B, t_scales, tD)
                        fake_B_all, fake_B_skipped = get_skipped_frames(
                            fake_B_all, fake_B, t_scales, tD)
                        flow_ref_all, conf_ref_all, flow_ref_skipped, conf_ref_skipped = get_skipped_flows(
                            flowNet, flow_ref_all, conf_ref_all,
                            real_B_skipped, flow_ref, conf_ref, t_scales, tD)

                # run discriminator for each temporal scale
                for s in range(t_scales):
                    if real_B_skipped[s] is not None:
                        losses = modelD(s + 1, [
                            real_B_skipped[s], fake_B_skipped[s],
                            flow_ref_skipped[s], conf_ref_skipped[s]
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
                loss_G = loss_dict['G_GAN'] + loss_dict[
                    'G_GAN_Feat'] + loss_dict['G_VGG']
                loss_G += loss_dict['G_Warp'] + loss_dict[
                    'F_Flow'] + loss_dict['F_Warp'] + loss_dict['W']
                if opt.add_face_disc:
                    loss_G += loss_dict['G_f_GAN'] + loss_dict['G_f_GAN_Feat']
                    loss_D += (loss_dict['D_f_fake'] +
                               loss_dict['D_f_real']) * 0.5

                # collect temporal losses
                loss_D_T = []
                t_scales_act = min(t_scales, len(loss_dict_T))
                for s in range(t_scales_act):
                    loss_G += loss_dict_T[s]['G_T_GAN'] + loss_dict_T[s][
                        'G_T_GAN_Feat'] + loss_dict_T[s]['G_T_Warp']
                    loss_D_T.append((loss_dict_T[s]['D_T_fake'] +
                                     loss_dict_T[s]['D_T_real']) * 0.5)

                ###################################### Backward Pass #################################
                optimizer_G = modelG.module.optimizer_G
                optimizer_D = modelD.module.optimizer_D
                # update generator weights
                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

                # update discriminator weights
                # individual frame discriminator
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
                # temporal discriminator
                for s in range(t_scales_act):
                    optimizer_D_T = getattr(modelD.module,
                                            'optimizer_D_T' + str(s))
                    optimizer_D_T.zero_grad()
                    loss_D_T[s].backward()
                    optimizer_D_T.step()

            if opt.debug:
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == 0:
                t = (time.time() - iter_start_time) / opt.print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                if opt.label_nc != 0:
                    input_image = util.tensor2label(real_A[0, -1],
                                                    opt.label_nc)
                elif opt.dataset_mode == 'pose':
                    input_image = util.tensor2im(real_A[0, -1, :3])
                    if real_A.size()[2] == 6:
                        input_image2 = util.tensor2im(real_A[0, -1, 3:])
                        input_image[input_image2 != 0] = input_image2[
                            input_image2 != 0]
                else:
                    c = 3 if opt.input_nc == 3 else 1
                    input_image = util.tensor2im(real_A[0, -1, :c],
                                                 normalize=True)
                if opt.use_instance:
                    edges = util.tensor2im(real_A[0, -1, -1:, ...],
                                           normalize=True)
                    input_image += edges[:, :, np.newaxis]

                if opt.add_face_disc:
                    ys, ye, xs, xe = modelD.module.get_face_region(real_A[0,
                                                                          -1:])
                    if ys is not None:
                        input_image[ys, xs:xe, :] = input_image[
                            ye, xs:xe, :] = input_image[
                                ys:ye, xs, :] = input_image[ys:ye, xe, :] = 255

                visual_list = [
                    ('input_image', input_image),
                    ('fake_image', util.tensor2im(fake_B[0, -1])),
                    ('fake_first_image', util.tensor2im(fake_B_first)),
                    ('fake_raw_image', util.tensor2im(fake_B_raw[0, -1])),
                    ('real_image', util.tensor2im(real_B[0, -1])),
                    ('flow_ref', util.tensor2flow(flow_ref[0, -1])),
                    ('conf_ref',
                     util.tensor2im(conf_ref[0, -1], normalize=False))
                ]
                if flow is not None:
                    visual_list += [('flow', util.tensor2flow(flow[0, -1])),
                                    ('weight',
                                     util.tensor2im(weight[0, -1],
                                                    normalize=False))]
                visuals = OrderedDict(visual_list)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == 0:
                visualizer.vis_print(
                    'saving the latest model (epoch %d, total_steps %d)' %
                    (epoch, total_steps))
                modelG.module.save('latest')
                modelD.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            visualizer.vis_print(
                'saving the model at the end of epoch %d, iters %d' %
                (epoch, total_steps))
            modelG.module.save('latest')
            modelD.module.save('latest')
            modelG.module.save(epoch)
            modelD.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            modelG.module.update_learning_rate(epoch)
            modelD.module.update_learning_rate(epoch)

        ### gradually grow training sequence length
        if (epoch % opt.niter_step) == 0:
            data_loader.dataset.update_training_batch(epoch // opt.niter_step)
            modelG.module.update_training_batch(epoch // opt.niter_step)

        ### finetune all scales
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                epoch == opt.niter_fix_global):
            modelG.module.update_fixed_params()
def train_pose2vid(target_dir, run_name, temporal_smoothing=False):
    """Train a pose-to-video pix2pixHD model.

    Args:
        target_dir: directory with the target person's training data.
        run_name: name of this training run (used in checkpoint paths).
        temporal_smoothing: when True, additionally feed the previous
            label/image pair into the model for temporally smoother output.

    NOTE(review): relies on module-level names defined elsewhere in the
    file (device, os, json, time, Variable, OrderedDict, util,
    CreateDataLoader, create_model, Visualizer, update_opt) — confirm they
    are in scope at the call site.
    """
    import src.config.train_opt as opt

    # `opt` is rebound from the config module to an options object.
    opt = update_opt(opt, target_dir, run_name, temporal_smoothing)

    # Training progress (epoch / iteration) is persisted as JSON so an
    # interrupted run can be resumed.
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.json')
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    if opt.load_pretrain != '':
        with open(iter_path, 'r') as f:
            iter_json = json.load(f)
    else:
        iter_json = {'start_epoch': 1, 'epoch_iter': 0}

    start_epoch = iter_json['start_epoch']
    epoch_iter = iter_json['epoch_iter']
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    # Phase offsets so the periodic display/print/save actions stay aligned
    # with the resumed step counter.
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    model = create_model(opt)
    model = model.to(device)
    visualizer = Visualizer(opt)

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            if temporal_smoothing:
                losses, generated = model(Variable(data['label']),
                                          Variable(data['inst']),
                                          Variable(data['image']),
                                          Variable(data['feat']),
                                          Variable(data['previous_label']),
                                          Variable(data['previous_image']),
                                          infer=save_fake)
            else:
                losses, generated = model(Variable(data['label']),
                                          Variable(data['inst']),
                                          Variable(data['image']),
                                          Variable(data['feat']),
                                          infer=save_fake)

            # sum per device losses
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get(
                'G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()

            ############## Display results and errors ##########

            print(f"Epoch {epoch} batch {i}:")
            print(f"loss_D: {loss_D}, loss_G: {loss_G}")
            print(
                f"loss_D_fake: {loss_dict['D_fake']}, loss_D_real: {loss_dict['D_real']}"
            )
            print(
                f"loss_G_GAN {loss_dict['G_GAN']}, loss_G_GAN_Feat: {loss_dict.get('G_GAN_Feat', 0)}, loss_G_VGG: {loss_dict.get('G_VGG', 0)}\n"
            )

            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                # errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([
                    ('input_label',
                     util.tensor2label(data['label'][0], opt.label_nc)),
                    ('synthesized_image', util.tensor2im(generated.data[0])),
                    ('real_image', util.tensor2im(data['image'][0]))
                ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')
                iter_json['start_epoch'] = epoch
                iter_json['epoch_iter'] = epoch_iter
                with open(iter_path, 'w') as f:
                    json.dump(iter_json, f)

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            iter_json['start_epoch'] = epoch + 1
            iter_json['epoch_iter'] = 0
            with open(iter_path, 'w') as f:
                json.dump(iter_json, f)

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

    # Release cached GPU memory once training completes.
    torch.cuda.empty_cache()
Exemple #6
0
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.create_model import create_model
from utils import visualizer

# Test-time inference script: run the generator over the test set and save
# the outputs as image grids.
opt = TestOptions().parse()

### set dataloader ###
print('### prepare DataLoader ###')
data_loader = CreateDataLoader(opt)
test_loader = data_loader.load_data()
print('test images     : {}'.format(len(data_loader)))
print('numof_iteration : {}'.format(len(test_loader)))

### define model ###
model = create_model(opt)
model.gen.eval()  # generator in eval mode for inference

print('### start test ! ###')
# BUGFIX: loop variable renamed from `iter`, which shadowed the builtin.
for batch_idx, data in enumerate(test_loader):
    model.set_variables(data)
    fake_image = model.forward()

    # Map generator output from [-1, 1] to [0, 1] before saving.
    fake_image = (fake_image + 1.0) / 2.0
    visualizer.save_test_images(opt, batch_idx, data['label'], fake_image, nrow=7)
Exemple #7
0
            else:
                break
    avg_loss = sum(losses) / len(losses)
    if opt.tensorboard:
        writer.add_scalar('data/val_loss', avg_loss, index)
    print('val loss: %.3f' % avg_loss)
    return avg_loss


#parse arguments
opt = TrainOptions().parse()
opt.device = torch.device("cuda")

#construct data loader
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training clips = %d' % dataset_size)

#create validation set data loader if validation_on option is set
if opt.validation_on:
    # temporarily switch mode to 'val' so the loader picks up validation data
    opt.mode = 'val'
    data_loader_val = CreateDataLoader(opt)
    dataset_val = data_loader_val.load_data()
    dataset_size_val = len(data_loader_val)
    print('#validation clips = %d' % dataset_size_val)
    opt.mode = 'train'  #set it back

if opt.tensorboard:
    from tensorboardX import SummaryWriter
Exemple #8
0
import torch.backends.cudnn as cudnn
from tqdm import tqdm

from data.data_loader import CreateDataLoader
from models.models import create_model
from options.train_options import TrainOptions

# Let cuDNN autotune convolution algorithms (inputs have fixed size).
cudnn.benchmark = True

opt = TrainOptions().parse()
ROOT = '/data1'
# Evaluation-split loader; swap with the commented line below to score the
# training split instead.
testDataLoader = CreateDataLoader(opt, {'mode':'Test', 'labelFn':'in5008.txt', 'rootPath':ROOT, 'subDir':'eval'})
#testDataLoader = CreateDataLoader(opt, {'mode':'Train', 'labelFn':'in5008.txt', 'rootPath':ROOT, 'subDir':'train'})


testDataset = testDataLoader.load_data()
print('#testing images = %d' % len(testDataLoader))

model = create_model(opt)

# Accumulate per-sample word/character error rates over the whole test set.
totalWer = list()
totalCer = list()
for i, data in tqdm(enumerate(testDataset), total=len(testDataset)):
    model.set_input(data)
    results = model.test()
    totalWer.append(results['wer'])
    totalCer.append(results['cer'])

# Average error rates, expressed as percentages.
cer = sum(totalCer) * 100 / len(totalCer)
wer = sum(totalWer) * 100 / len(totalWer)
print('############################')
Exemple #9
0
                    type=int,
                    default=100,
                    help='frequency of showing training results on console')
parser.add_argument('--ckpt_path', default='')

opt = parser.parse_args()
print(opt)

# Separate visualizers for the training and validation curves.
train_visual = Visualizer(opt.train_display_id, 'train', 5)
val_visual = Visualizer(opt.val_display_id, 'val', 5)

if not os.path.exists(opt.ckpt_path):
    os.makedirs(opt.ckpt_path)

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)

# Switch phase so the second loader reads the validation split.
opt.phase = 'val'
val_data_loader = CreateDataLoader(opt)
val_dataset = val_data_loader.load_data()

# val_loader = DataLoader(val_set, batch_size=1, num_workers=12, pin_memory=True)

## define models
# 1---256x256 stage
# 2---128x128 stage
# 3---64x64 stage
GA = nets.define_G(input_nc=3,
                   output_nc=3,
Exemple #10
0
def train():
    """Train the C2F-FWN clothing warper.

    Flat training driver: builds the dataset and models from TrainOptions,
    then runs the epoch/iteration loop. Each batch is a frame sequence that
    is processed in chunks of ``n_frames_load`` frames; per-frame losses and
    foreground temporal-consistency (FTC) losses at temporal scales l=1,3,9
    are computed and backpropagated. Periodically logs errors, displays
    visuals, and checkpoints the model.

    Relies on module-level helpers imported elsewhere in this file:
    CreateDataLoader, create_model_cloth, create_optimizer_cloth, init_params,
    loss_backward, save_models_cloth, update_models_cloth, reshape, util,
    Visualizer, call.
    """
    opt = TrainOptions().parse()
    if opt.debug:
        # Make every iteration log/display so problems surface immediately.
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training frames = %d' % dataset_size)

    ### initialize models
    models = create_model_cloth(opt)
    ClothWarper, ClothWarperLoss, flowNet, optimizer = create_optimizer_cloth(opt, models)

    ### set parameters
    n_gpus, tG, input_nc_1, input_nc_2, input_nc_3, start_epoch, epoch_iter, print_freq, total_steps, iter_path, tD, t_scales = init_params(opt, ClothWarper, data_loader)
    visualizer = Visualizer(opt)

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0
            n_frames_total, n_frames_load, t_len = data_loader.dataset.init_data_params_cloth(data, n_gpus, tG)
            flow_total_prev_last, frames_all = data_loader.dataset.init_data_cloth(t_scales)

            # Process the sequence in chunks of n_frames_load frames.
            for i in range(0, n_frames_total, n_frames_load):
                is_first_frame = flow_total_prev_last is None
                input_TParsing, input_TFG, input_SParsing, input_SFG, input_SFG_full = data_loader.dataset.prepare_data_cloth(data, i)

                ###################################### Forward Pass ##########################
                ####### C2F-FWN
                fg_tps, fg_dense, lo_tps, lo_dense, flow_tps, flow_dense, flow_totalp, real_input_1, real_input_2, real_SFG, real_SFG_fullp, flow_total_last = ClothWarper(input_TParsing, input_TFG, input_SParsing, input_SFG, input_SFG_full, flow_total_prev_last)
                real_SLO = real_input_2[:, :, -opt.label_nc_2:]
                ####### compute losses
                ### individual frame losses and FTC loss with l=1
                real_SFG_full_prev, real_SFG_full = real_SFG_fullp[:, :-1], real_SFG_fullp[:, 1:]   # the collection of previous and current real frames
                flow_optical_ref, conf_optical_ref = flowNet(real_SFG_full, real_SFG_full_prev)       # reference flows and confidences

                flow_total_prev, flow_total = flow_totalp[:, :-1], flow_totalp[:, 1:]
                if is_first_frame:
                    # No previous flow exists for the very first frame of a sequence.
                    flow_total_prev = flow_total_prev[:, 1:]

                # Carry the last total flow over to the next chunk of frames.
                flow_total_prev_last = flow_total_last

                losses, flows_sampled_0 = ClothWarperLoss(0, reshape([real_SFG, real_SLO, fg_tps, fg_dense, lo_tps, lo_dense, flow_tps, flow_dense, flow_total, flow_total_prev, flow_optical_ref, conf_optical_ref]), is_first_frame)
                losses = [torch.mean(x) if x is not None else 0 for x in losses]
                loss_dict = dict(zip(ClothWarperLoss.module.loss_names, losses))

                ### FTC losses with l=3,9
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = ClothWarperLoss.module.get_all_skipped_frames(frames_all, \
                        real_SFG_full, flow_total, flow_optical_ref, conf_optical_ref, real_SLO, t_scales, tD, n_frames_load, i, flowNet)

                # compute losses for l=3,9
                loss_dict_T = []
                for s in range(1, t_scales):
                    if frames_skipped[0][s] is not None and not opt.tps_only:
                        losses, flows_sampled_1 = ClothWarperLoss(s+1, [frame_skipped[s] for frame_skipped in frames_skipped], False)
                        losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
                        loss_dict_T.append(dict(zip(ClothWarperLoss.module.loss_names_T, losses)))

                # collect losses
                loss, _ = ClothWarperLoss.module.get_losses(loss_dict, loss_dict_T, t_scales-1)

                ###################################### Backward Pass #################################
                # update generator weights
                loss_backward(opt, loss, optimizer)

                if i == 0: fg_dense_first = fg_dense[0, 0]   # the first generated image in this sequence

            if opt.debug:
                call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
                for s in range(len(loss_dict_T)):
                    errors.update({k+str(s): v.data.item() if not isinstance(v, int) else v for k, v in loss_dict_T[s].items()})

                # FIX: the original abused a set comprehension for its side effect
                # (building a throwaway set of Nones) and shadowed the outer loop
                # variable `idx`; use a plain for-loop with a distinct index name.
                loss_names_vis = ClothWarperLoss.module.loss_names.copy()
                for t_idx in range(len(loss_dict_T)):
                    loss_names_vis.append(ClothWarperLoss.module.loss_names_T[0] + str(t_idx))
                visualizer.print_current_errors_new(epoch, epoch_iter, errors, loss_names_vis, t)
                visualizer.plot_current_errors(errors, total_steps)
            ### display output images
            if save_fake:
                visuals = util.save_all_tensors_cloth(opt, real_input_1, real_input_2, fg_tps, fg_dense, lo_tps, lo_dense, fg_dense_first, real_SFG, real_SFG_full, flow_tps, flow_dense, flow_total)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            save_models_cloth(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, ClothWarper)
            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch and update model params
        save_models_cloth(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, ClothWarper, end_of_epoch=True)
        update_models_cloth(opt, epoch, ClothWarper, data_loader)
Exemple #11
0
def train_target(param_dir, if_train, loadSize, continue_train = False):
    """Train or test a pix2pixHD model whose options were pickled to disk.

    Args:
        param_dir: path to a pickled options object (e.g. a saved TrainOptions
            result) that is loaded and patched below.
        if_train: when True, run the full training loop; when False, run
            inference over the dataset and save results into an HTML page.
        loadSize: image size written into both opt.loadSize and opt.fineSize.
        continue_train: resume from the latest checkpoint when True.

    NOTE(review): uses legacy PyTorch idioms (`Variable`, `v.data[0]`) — this
    code presumably targets PyTorch < 0.4; confirm before running on a modern
    version.
    """

    # import pix2pixHD module
    pix2pixhd_dir = Path('../src/pix2pixHD/')  # should download github code 'pix2pixHD'

    import sys
    sys.path.append(str(pix2pixhd_dir))

    #get_ipython().run_line_magic('load_ext', 'autoreload')
    #get_ipython().run_line_magic('autoreload', '2')

    # Imports are deferred until pix2pixHD is on sys.path.
    from options.train_options import TrainOptions
    from data.data_loader import CreateDataLoader
    from models.models import create_model
    import util.util as util
    from util.visualizer import Visualizer
    from util import html

    # NOTE(review): pickle.load executes arbitrary code on malicious input;
    # only load trusted option files.
    with open(param_dir, mode='rb') as f:
        opt = pickle.load(f)
        opt.loadSize = loadSize
        opt.fineSize = loadSize
        #opt.fine_size = loadSize
        #opt.loadsize = loadSize
        opt.continue_train = continue_train
        print('opt parameters: ', opt)

    # File recording (epoch, iteration) so interrupted runs can resume.
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    
    model = create_model(opt)
    visualizer = Visualizer(opt)
        
    if if_train:
    
        print('#training images = %d' % dataset_size)
        if opt.continue_train:
            print ('Resumming training ...')
        else:
            print ('Starting new training ...')

        start_epoch, epoch_iter = 1, 0
        total_steps = (start_epoch-1) * dataset_size + epoch_iter
        # Phase offsets so display/print/save fire on the same step schedule
        # even when resuming mid-epoch.
        display_delta = total_steps % opt.display_freq
        print_delta = total_steps % opt.print_freq
        save_delta = total_steps % opt.save_latest_freq
        
        for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):  # (1,20+20+1)
            epoch_start_time = time.time()
            if epoch != start_epoch:
                epoch_iter = epoch_iter % dataset_size
            for i, data in enumerate(dataset, start=epoch_iter):
                iter_start_time = time.time()
                total_steps += opt.batchSize
                epoch_iter += opt.batchSize

                # whether to collect output images
                save_fake = total_steps % opt.display_freq == display_delta
                
                ############## Forward Pass ######################
                losses, generated = model(Variable(data['label']), Variable(data['inst']), 
                    Variable(data['image']), Variable(data['feat']), infer=save_fake)
                
                # sum per device losses
                losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
                loss_dict = dict(zip(model.module.loss_names, losses))

                # calculate final loss scalar
                loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
                loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0)
                
                ############### Backward Pass ####################
                # update generator weights
                model.module.optimizer_G.zero_grad()
                loss_G.backward()
                model.module.optimizer_G.step()

                # update discriminator weights
                model.module.optimizer_D.zero_grad()
                loss_D.backward()
                model.module.optimizer_D.step()
                
                #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 

                ############## Display results and errors ##########
                ### print out errors
                if total_steps % opt.print_freq == print_delta:
                    errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
                    t = (time.time() - iter_start_time) / opt.batchSize
                    visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                    visualizer.plot_current_errors(errors, total_steps)

                ### display output images
                if save_fake:
                    visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                           ('synthesized_image', util.tensor2im(generated.data[0])),
                                           ('real_image', util.tensor2im(data['image'][0]))])
                    visualizer.display_current_results(visuals, epoch, total_steps)

                ### save latest model
                if total_steps % opt.save_latest_freq == save_delta:
                    print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                    model.module.save('latest')            
                    np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

                if epoch_iter >= dataset_size:
                    break
               
            # end of epoch 
            iter_end_time = time.time()
            print('End of epoch %d / %d \t Time Taken: %d sec' %
                  (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

            ### save model for this epoch
            if epoch % opt.save_epoch_freq == 0:  # opt.save_epoch_freq == 10
                print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))        
                model.module.save('latest')
                model.module.save(epoch)
                np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')

            ### instead of only training the local enhancer, train the entire network after certain iterations
            if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
                model.module.update_fixed_params()

            ### linearly decay learning rate after certain iterations
            if epoch > opt.niter:
                model.module.update_learning_rate()
                
        torch.cuda.empty_cache()
    else:
        
        # create website
        web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
        webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))

        # Inference-only path: synthesize every sample and record it in the page.
        for data in tqdm(dataset):
            minibatch = 1 
            generated = model.inference(data['label'], data['inst'])
                
            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                   ('synthesized_image', util.tensor2im(generated.data[0]))])
            img_path = data['path']
            visualizer.save_images(webpage, visuals, img_path)
        webpage.save()
        torch.cuda.empty_cache()
def prepare_face_enhancer_data(target_dir, run_name):
    """Build the directory layout and images needed by the face enhancer.

    Copies real training frames into test_real/test_img/test_label folders,
    runs pix2pixHD inference to produce synthesized frames for test_sync, and
    copies previously synthesized transfer results into face_transfer.

    Args:
        target_dir: root directory of the target person's data
            (expects train/train_img and train/train_label subfolders).
        run_name: experiment name used for checkpoints and result folders.

    NOTE(review): depends on module-level globals defined elsewhere in this
    file — `device`, `dir_name`, `cv2`, `CreateDataLoader`, `Visualizer`,
    `html`, `create_model`, `util`, `tqdm`, `torch` — confirm they are in
    scope before calling.
    """

    # Create the output directory tree (idempotent thanks to exist_ok=True).
    face_sync_dir = os.path.join(target_dir, 'face')
    os.makedirs(face_sync_dir, exist_ok=True)
    test_sync_dir = os.path.join(face_sync_dir, 'test_sync')
    os.makedirs(test_sync_dir, exist_ok=True)
    test_real_dir = os.path.join(face_sync_dir, 'test_real')
    os.makedirs(test_real_dir, exist_ok=True)
    test_img = os.path.join(target_dir, 'test_img')
    os.makedirs(test_img, exist_ok=True)
    test_label = os.path.join(target_dir, 'test_label')
    os.makedirs(test_label, exist_ok=True)

    transfer_face_sync_dir = os.path.join(target_dir, 'face_transfer')
    os.makedirs(transfer_face_sync_dir, exist_ok=True)
    transfer_test_sync_dir = os.path.join(transfer_face_sync_dir, 'test_sync')
    os.makedirs(transfer_test_sync_dir, exist_ok=True)
    transfer_test_real_dir = os.path.join(transfer_face_sync_dir, 'test_real')
    os.makedirs(transfer_test_real_dir, exist_ok=True)

    train_dir = os.path.join(target_dir, 'train', 'train_img')
    label_dir = os.path.join(target_dir, 'train', 'train_label')

    # Copy each real training frame (and its label) into the test folders;
    # filenames are assumed to be zero-padded 5-digit indices like 00042.png.
    print('Prepare test_real....')
    for img_file in tqdm(sorted(os.listdir(train_dir))):
        img_idx = int(img_file.split('.')[0])
        img = cv2.imread(os.path.join(train_dir, '{:05}.png'.format(img_idx)))
        label = cv2.imread(os.path.join(label_dir,
                                        '{:05}.png'.format(img_idx)))
        cv2.imwrite(os.path.join(test_real_dir, '{:05}.png'.format(img_idx)),
                    img)
        cv2.imwrite(
            os.path.join(transfer_test_real_dir, '{:05}.png'.format(img_idx)),
            img)
        cv2.imwrite(os.path.join(test_img, '{:05}.png'.format(img_idx)), img)
        cv2.imwrite(os.path.join(test_label, '{:05}.png'.format(img_idx)),
                    label)

    print('Prepare test_sync....')

    # Patch the test options module in place for this run.
    import src.config.test_opt as opt
    if device == torch.device('cpu'):
        opt.gpu_ids = []
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = "0"

    opt.checkpoints_dir = os.path.join(dir_name, '../../checkpoints/')
    opt.dataroot = target_dir
    opt.name = run_name
    opt.nThreads = 0
    opt.results_dir = os.path.join(dir_name, '../../face_enhancer_results/')

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)

    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    model = create_model(opt)

    # Run pix2pixHD inference over every sample and record results in the page.
    for data in tqdm(dataset):
        minibatch = 1
        generated = model.inference(data['label'], data['inst'])

        visuals = OrderedDict([('synthesized_image',
                                util.tensor2im(generated.data[0]))])
        img_path = data['path']
        visualizer.save_images(webpage, visuals, img_path)
    webpage.save()
    torch.cuda.empty_cache()

    # Rename/copy the freshly synthesized frames into test_sync.
    print(f'Copy the synthesized images in {test_sync_dir}...')
    synthesized_image_dir = os.path.join(dir_name,
                                         '../../face_enhancer_results',
                                         run_name, 'test_latest/images/')
    img_list = [
        f for f in os.listdir(synthesized_image_dir)
        if f.endswith('synthesized_image.jpg')
    ]
    for img_file in tqdm(sorted(img_list)):
        img_idx = int(img_file.split('_')[0])
        img = cv2.imread(
            os.path.join(synthesized_image_dir,
                         '{:05}_synthesized_image.jpg'.format(img_idx)))
        cv2.imwrite(os.path.join(test_sync_dir, '{:05}.png'.format(img_idx)),
                    img)

    # Copy results from a previous transfer run into face_transfer/test_sync.
    print('Copy transfer_test_sync')
    previous_run_img_dir = os.path.join(dir_name, '../../results', run_name,
                                        'test_latest/images/')
    img_list = [
        f for f in os.listdir(previous_run_img_dir)
        if f.endswith('synthesized_image.jpg')
    ]
    for img_file in tqdm(sorted(img_list)):
        img_idx = int(img_file.split('_')[0])
        img = cv2.imread(
            os.path.join(previous_run_img_dir,
                         '{:05}_synthesized_image.jpg'.format(img_idx)))
        cv2.imwrite(
            os.path.join(transfer_test_sync_dir, '{:05}.png'.format(img_idx)),
            img)
Exemple #13
0
                                             dtype=int)
    except:
        start_epoch, epoch_iter = 1, 0

    try:
        best_iou = np.loadtxt(ioupath_path, dtype=float)
    except:
        best_iou = 0.
    print('Resuming from epoch %d at iteration %d, previous best IoU %f' %
          (start_epoch, epoch_iter, best_iou))
else:
    start_epoch, epoch_iter = 1, 0
    best_iou = 0.

data_loader = CreateDataLoader(opt)
dataset, dataset_val = data_loader.load_data()
dataset_size = len(dataset)
print('#training images = %d' % dataset_size)

model = create_model(opt, dataset.dataset)
# print (model)
visualizer = Visualizer(opt)
total_steps = (start_epoch - 1) * dataset_size + epoch_iter
for epoch in range(start_epoch, opt.nepochs):
    epoch_start_time = time.time()
    if epoch != start_epoch:
        epoch_iter = epoch_iter % dataset_size

    model.model.train()
    for i, data in enumerate(dataset, start=epoch_iter):
        iter_start_time = time.time()
Exemple #14
0
def main():
    """Train a pix2pixHD model using options pickled at ./train/train_opt.pkl.

    Loads and patches the saved options, builds the data loader, model and
    visualizer, then runs the standard pix2pixHD training loop: forward pass,
    separate generator/discriminator backward passes, periodic logging,
    visualization, and checkpointing (the 'latest' checkpoint is only written
    when the combined G+D loss improves).

    NOTE(review): uses legacy PyTorch idioms (`Variable`, `v.data[0]`) — this
    presumably targets PyTorch < 0.4; confirm before running on a modern
    version.
    """
    # NOTE(review): pickle.load executes arbitrary code on malicious input;
    # only load trusted option files.
    with open('./train/train_opt.pkl', mode='rb') as f:
        opt = pickle.load(f)
        opt.checkpoints_dir = './checkpoints/'
        opt.dataroot = './train'
        opt.no_flip = True
        opt.label_nc = 0
        opt.batchSize = 2
        print(opt)

    # File recording (epoch, iteration) so interrupted runs can resume.
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    start_epoch, epoch_iter = 1, 0
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    # Phase offsets so display/print/save fire on a consistent step schedule.
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq
    # FIX: use float('inf') instead of the magic 999999 sentinel, and drop the
    # unused `epoch_loss` variable the original declared and never read.
    best_loss = float('inf')
    model = create_model(opt)
    model = model.cuda()
    visualizer = Visualizer(opt)
    #niter = 20,niter_decay = 20
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                      Variable(data['image']), Variable(data['feat']), infer=save_fake)

            # sum per device losses
            losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)
            loss_DG = loss_D + loss_G

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()

            # call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ############## Display results and errors ##########

            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                       ('synthesized_image', util.tensor2im(generated.data[0])),
                                       ('real_image', util.tensor2im(data['image'][0]))])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model (only when the combined loss improved)
            if total_steps % opt.save_latest_freq == save_delta and loss_DG < best_loss:
                best_loss = loss_DG
                print('saving the latest model (epoch %d, total_steps %d ,total loss %g)' % (epoch, total_steps, loss_DG.item()))
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        iter_end_time = time.time()
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:

            print('saving the model at the end of epoch %d, iters %d ' % (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

    torch.cuda.empty_cache()
Exemple #15
0
def test(opt):
    """Run inference with a trained model and save the outputs.

    In traverse/deploy mode, processes the explicit image list in
    opt.image_path_list and writes videos/row images/individual results.
    Otherwise iterates the dataset (up to opt.how_many batches) and saves
    visuals into an HTML results page.

    Args:
        opt: parsed test options; mutated in place to force single-threaded,
            batch-size-1, unshuffled, unflipped evaluation.
    """
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#test batches = %d' % (int(dataset_size / len(opt.sort_order))))
    visualizer = Visualizer(opt)
    model = create_model(opt)
    model.eval()

    # create webpage; include the seed in the folder name when one was fixed
    if opt.random_seed != -1:
        exp_dir = '%s_%s_seed%s' % (opt.phase, opt.which_epoch,
                                    str(opt.random_seed))
    else:
        exp_dir = '%s_%s' % (opt.phase, opt.which_epoch)
    web_dir = os.path.join(opt.results_dir, opt.name, exp_dir)

    if opt.traverse or opt.deploy:
        if opt.traverse:
            out_dirname = 'traversal'
        else:
            out_dirname = 'deploy'
        output_dir = os.path.join(web_dir, out_dirname)
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        for image_path in opt.image_path_list:
            print(image_path)
            data = dataset.dataset.get_item_from_path(image_path)
            visuals = model.inference(data)
            if opt.traverse and opt.make_video:
                # full traversal rendered as a video
                out_path = os.path.join(
                    output_dir,
                    os.path.splitext(os.path.basename(image_path))[0] + '.mp4')
                visualizer.make_video(visuals, out_path)
            elif opt.traverse or (opt.deploy and opt.full_progression):
                if opt.traverse and opt.compare_to_trained_outputs:
                    out_path = os.path.join(
                        output_dir,
                        os.path.splitext(os.path.basename(image_path))[0] +
                        '_compare_to_{}_jump_{}.png'.format(
                            opt.compare_to_trained_class,
                            opt.trained_class_jump))
                else:
                    out_path = os.path.join(
                        output_dir,
                        os.path.splitext(os.path.basename(image_path))[0] +
                        '.png')
                visualizer.save_row_image(visuals,
                                          out_path,
                                          traverse=opt.traverse)
            else:
                out_path = os.path.join(output_dir,
                                        os.path.basename(image_path[:-4]))
                visualizer.save_images_deploy(visuals, out_path)
    else:
        webpage = html.HTML(
            web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
            (opt.name, opt.phase, opt.which_epoch))

        # test
        for i, data in enumerate(dataset):
            if i >= opt.how_many:
                break

            visuals = model.inference(data)
            img_path = data['Paths']
            rem_ind = []
            # FIX: the original reused `i` here, clobbering the outer batch
            # index and corrupting the `i >= opt.how_many` stop condition;
            # use a distinct name for the inner index.
            for path_idx, path in enumerate(img_path):
                if path != '':
                    print('process image... %s' % path)
                else:
                    rem_ind.append(path_idx)

            # delete empty paths back-to-front so earlier indices stay valid
            for ind in reversed(rem_ind):
                del img_path[ind]

            visualizer.save_images(webpage, visuals, img_path)

            webpage.save()
def train():
    """Train a video-to-video synthesis model (vid2vid-style).

    Builds the dataset, Trainer, generator/discriminator and flow network,
    then runs the epoch loop: optionally computes ground-truth flow with
    flowNet, splits each sequence into chunks of opt.n_frames_per_gpu frames,
    and alternates discriminator and generator updates per chunk. The Trainer
    object owns epoch/iteration bookkeeping, logging and checkpointing.

    Relies on module-level helpers imported elsewhere in this file:
    TrainOptions, init_dist, CreateDataLoader, Trainer, create_model,
    get_data_t, loss_backward, torch.
    """
    opt = TrainOptions().parse()

    if opt.distributed:
        # multi-process training: initialize the process group first
        init_dist()
        print('batch size per GPU: %d' % opt.batchSize)
    torch.backends.cudnn.benchmark = True

    ### setup dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    # pose-driven datasets use labels (not images) as flowNet input below
    pose = 'pose' in opt.dataset_mode

    ### setup trainer
    trainer = Trainer(opt, data_loader)

    ### setup models
    model, flowNet, [optimizer_G,
                     optimizer_D] = create_model(opt, trainer.start_epoch)
    # placeholders used when opt.no_flow_gt is set (no reference flow computed)
    flow_gt = conf_gt = [None] * 2

    for epoch in range(trainer.start_epoch, opt.niter + opt.niter_decay + 1):
        if opt.distributed:
            # keep per-process shuffling in sync across epochs
            dataset.sampler.set_epoch(epoch)
        trainer.start_of_epoch(epoch, model, data_loader)
        n_frames_total, n_frames_load = data_loader.dataset.n_frames_total, opt.n_frames_per_gpu
        for idx, data in enumerate(dataset, start=trainer.epoch_iter):
            data = trainer.start_of_iter(data)

            if not opt.no_flow_gt:
                # reference flow/confidence from labels (pose) or images
                data_list = [
                    data['tgt_label'], data['ref_label']
                ] if pose else [data['tgt_image'], data['ref_image']]
                flow_gt, conf_gt = flowNet(data_list, epoch)
            data_list = [
                data['tgt_label'], data['tgt_image'], flow_gt, conf_gt
            ]
            data_ref_list = [data['ref_label'], data['ref_image']]
            # previous-frame state carried between chunks (None at sequence start)
            data_prev = [None, None, None]

            ############## Forward Pass ######################
            # process the sequence in chunks of n_frames_load frames
            for t in range(0, n_frames_total, n_frames_load):
                data_list_t = get_data_t(data_list, n_frames_load,
                                         t) + data_ref_list + data_prev

                # discriminator update, then generator update, per chunk
                d_losses = model(data_list_t, mode='discriminator')
                d_losses = loss_backward(opt, d_losses, optimizer_D, 1)

                g_losses, generated, data_prev = model(
                    data_list_t, save_images=trainer.save, mode='generator')
                g_losses = loss_backward(opt, g_losses, optimizer_G, 0)

            loss_dict = dict(
                zip(model.module.lossCollector.loss_names,
                    g_losses + d_losses))

            # end_of_iter handles logging/visuals and signals early epoch exit
            if trainer.end_of_iter(loss_dict,
                                   generated + data_list + data_ref_list,
                                   model):
                break
        trainer.end_of_epoch(model)
def train():
    """Full training entry point: data, model, epoch loop, validation.

    Parses training options, builds the train/validation loaders and the
    model (optionally wrapped with apex AMP when ``--fp16`` is set), then
    alternates generator/discriminator updates with periodic logging,
    image dumps, checkpointing, and a validation pass after every epoch.
    Takes no arguments and returns nothing; progress is written via the
    Visualizer and checkpoints under ``opt.checkpoints_dir``.
    """
    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except (OSError, ValueError):
            # iter.txt missing or malformed -> start from scratch.  The
            # narrow clause replaces a bare `except:`, which would also
            # have swallowed KeyboardInterrupt/SystemExit.
            start_epoch, epoch_iter = 1, 0
        # compute resume lr: replay the linear decay already applied to
        # reach `start_epoch` so training resumes at the right rate.
        if start_epoch > opt.niter:
            lrd_unit = opt.lr / opt.niter_decay
            resume_lr = opt.lr - (start_epoch - opt.niter) * lrd_unit
            opt.lr = resume_lr
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    # Make print_freq a multiple of batchSize so the modulo checks below
    # line up with the per-iteration step increments.
    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    if opt.debug:
        # Tiny settings for a quick smoke run.
        opt.display_freq = 2
        opt.print_freq = 2
        opt.niter = 3
        opt.niter_decay = 0
        opt.max_dataset_size = 1
        opt.valSize = 1

    ## Loading data
    # train data
    data_loader = CreateDataLoader(opt, isVal=False)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('# training images = %d' % dataset_size)
    # validation data (the loader variable is reused; only `valset` is kept)
    data_loader = CreateDataLoader(opt, isVal=True)
    valset = data_loader.load_data()
    print('# validation images = %d' % len(data_loader))

    ## Loading model
    model = create_model(opt)
    visualizer = Visualizer(opt)
    if opt.fp16:
        # Mixed precision: AMP wraps the model and both optimizers.
        from apex import amp
        model, [optimizer_G, optimizer_D
                ] = amp.initialize(model,
                                   [model.optimizer_G, model.optimizer_D],
                                   opt_level='O1')
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    else:
        optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter

    # Phase offsets so periodic events stay aligned after a mid-epoch resume.
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            # epoch_iter = epoch_iter % dataset_size
            epoch_iter = 0
        for i, data in enumerate(dataset, start=epoch_iter):
            # The first iteration always satisfies this (print_delta was
            # computed from the same total_steps), so iter_start_time is
            # bound before its use below.
            if total_steps % opt.print_freq == print_delta:
                iter_start_time = time.time()

            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            model = model.train()
            losses, generated, metrics = model(data['A'],
                                               data['B'],
                                               data['geometry'],
                                               infer=False)

            # sum per device losses and metrics (ints are disabled terms)
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get(
                'G_GAN_Feat', 0) + opt.vgg_weight * loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            optimizer_G.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_G, optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_G.backward()
            optimizer_G.step()

            # update discriminator weights
            optimizer_D.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_D, optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_D.backward()
            optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                metrics_ = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in metric_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)
                visualizer.print_current_metrics(epoch, epoch_iter, metrics_,
                                                 t)
                visualizer.plot_current_metrics(metrics_, total_steps)
                #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ### display output images
            # `type` selects the tonemapping used by tensor2im_exr per task.
            if save_fake:
                if opt.task_type == 'specular':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=1)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=1))
                    ])
                elif opt.task_type == 'low':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=2)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=2))
                    ])
                elif opt.task_type == 'high':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=3)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=3))
                    ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ###########################################################################################
        # validation at the end of each epoch
        val_start_time = time.time()
        metrics_val = []
        for _, val_data in enumerate(valset):
            model = model.eval()
            # model.half()
            generated, metrics = model(val_data['A'],
                                       val_data['B'],
                                       val_data['geometry'],
                                       infer=True)
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            metrics_ = {
                k: v.data.item() if not isinstance(v, int) else v
                for k, v in metric_dict.items()
            }
            metrics_val.append(metrics_)
        # Print out losses (averaged over the whole validation set)
        metrics_val = visualizer.mean4dict(metrics_val)
        t = (time.time() - val_start_time) / opt.print_freq
        visualizer.print_current_metrics(epoch,
                                         epoch_iter,
                                         metrics_val,
                                         t,
                                         isVal=True)
        visualizer.plot_current_metrics(metrics_val, total_steps, isVal=True)
        # visualization of the last validation sample; `elif` keeps the
        # branches mutually exclusive, consistent with the training branch.
        # NOTE(review): `visuals` would be unbound for any other task_type —
        # confirm task_type is restricted to these three values.
        if opt.task_type == 'specular':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=1)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=1))
            ])
        elif opt.task_type == 'low':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=2)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=2))
            ])
        elif opt.task_type == 'high':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=3)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=3))
            ])
        visualizer.display_current_results(visuals, epoch, epoch, isVal=True)
        ###########################################################################################

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()
from pdb import set_trace as st
from util import html

opt = TrainOptions().parse()

# Fix the RNG seed for this run so results are reproducible.
opt.manualSeed = random.randint(1, 10000) # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

# avi2pngs(opt): presumably extracts video files to frame images under
# dataroot — TODO confirm in its definition.
avi2pngs(opt)
# Identity comparison with None (`is not None`) instead of `!= None`.
if opt.file_list is not None:
    opt.dataroot = opt.dataroot + '/split'

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training samples = %d' % dataset_size)

# Remember the training batch size before opt is mutated for testing below.
train_batch_size = opt.batchSize

model = create_model(opt)
visualizer = Visualizer(opt)

# Switch options to deterministic single-sample mode for the test split.
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip
opt.dataroot = opt.testroot
opt.file_list = opt.test_file_list
Exemple #19
0
def train():
    """vid2vid-style training loop.

    Trains a video generator together with a per-frame discriminator and
    multi-scale temporal discriminators, using FlowNet-computed reference
    flows.  Each sampled sequence is processed in windows of
    ``n_frames_load`` frames, carrying generated state between windows.
    Reads all configuration from TrainOptions; returns nothing.
    """
    opt = TrainOptions().parse()
    if opt.debug:
        # Minimal frequencies/threads for a quick smoke run.
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training videos = %d' % dataset_size)

    ### initialize models
    models = create_model(opt)
    modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T = create_optimizer(
        opt, models)

    ### set parameters
    # NOTE(review): exact meanings come from init_params; tG/tD presumably
    # are the temporal windows of G/D and t_scales the number of temporal
    # discriminator scales — confirm against init_params.
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
        start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(opt, modelG, modelD, data_loader)
    visualizer = Visualizer(opt)

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0
            # Per-sequence bookkeeping for chunked processing.
            n_frames_total, n_frames_load, t_len = data_loader.dataset.init_data_params(
                data, n_gpus, tG)
            fake_B_prev_last, frames_all = data_loader.dataset.init_data(
                t_scales)

            # Walk the sequence in windows of n_frames_load frames.
            for i in range(0, n_frames_total, n_frames_load):
                input_A, input_B, inst_A = data_loader.dataset.prepare_data(
                    data, i, input_nc, output_nc)

                ###################################### Forward Pass ##########################
                ####### generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(
                    input_A, input_B, inst_A, fake_B_prev_last)

                ####### discriminator
                ### individual frame discriminator
                real_B_prev, real_B = real_Bp[:, :
                                              -1], real_Bp[:,
                                                           1:]  # the collection of previous and current real frames
                flow_ref, conf_ref = flowNet(
                    real_B, real_B_prev)  # reference flows and confidences
                fake_B_prev = modelG.module.compute_fake_B_prev(
                    real_B_prev, fake_B_prev_last, fake_B)
                # Carry the last generated frames into the next window.
                fake_B_prev_last = fake_B_last

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                # mean over devices; a None loss counts as 0
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                ### temporal discriminator
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = modelD.module.get_all_skipped_frames(frames_all, \
                        real_B, fake_B, flow_ref, conf_ref, t_scales, tD, n_frames_load, i, flowNet)

                # run discriminator for each temporal scale
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = modelD(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_G, loss_D, loss_D_T, t_scales_act = modelD.module.get_losses(
                    loss_dict, loss_dict_T, t_scales)

                ###################################### Backward Pass #################################
                # update generator weights
                loss_backward(opt, loss_G, optimizer_G)

                # update individual discriminator weights
                loss_backward(opt, loss_D, optimizer_D)

                # update temporal discriminator weights
                for s in range(t_scales_act):
                    loss_backward(opt, loss_D_T[s], optimizer_D_T[s])

                if i == 0:
                    fake_B_first = fake_B[
                        0, 0]  # the first generated image in this sequence

            if opt.debug:
                # Log GPU memory usage while debugging.
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                # Append the temporal-scale index to disambiguate names.
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, modelD)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            save_models(opt, epoch, epoch_iter, total_steps, visualizer,
                        iter_path, modelG, modelD)
            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch and update model params
        save_models(opt,
                    epoch,
                    epoch_iter,
                    total_steps,
                    visualizer,
                    iter_path,
                    modelG,
                    modelD,
                    end_of_epoch=True)
        update_models(opt, epoch, modelG, modelD, data_loader)
Exemple #20
0
def main():
    """Train a pix2pix or PAN model.

    Parses training options, builds the data loader and the selected
    model, then runs the standard loop: optimize each batch, periodically
    print losses, save the latest weights, and save/decay at epoch ends.
    Raises ValueError if ``opt.model`` is not 'pix2pix' or 'pan'.
    """
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    # read pix2pix/PAN model
    if opt.model == 'pix2pix':
        assert (opt.dataset_mode == 'aligned')
        from models.pix2pix_model import Pix2PixModel
        model = Pix2PixModel()
        model.initialize(opt)
    elif opt.model == 'pan':
        from models.pan_model import PanModel
        model = PanModel()
        model.initialize(opt)
    else:
        # Fail fast: previously an unknown model fell through and caused a
        # NameError on `model` at first use below.
        raise ValueError(
            "unsupported model '%s': expected 'pix2pix' or 'pan'" % opt.model)

    total_steps = 0

    # Cache frequently-read options as locals.
    batch_size = opt.batchSize
    print_freq = opt.print_freq
    epoch_count = opt.epoch_count
    niter = opt.niter
    niter_decay = opt.niter_decay
    display_freq = opt.display_freq
    save_latest_freq = opt.save_latest_freq
    save_epoch_freq = opt.save_epoch_freq

    for epoch in range(epoch_count, niter + niter_decay + 1):
        epoch_start_time = time.time()
        epoch_iter = 0

        for i, data in enumerate(dataset):
            # data --> (1, 3, 256, 256)
            iter_start_time = time.time()
            total_steps += batch_size
            epoch_iter += batch_size
            model.set_input(data)
            model.optimize_parameters()

            if total_steps % print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / batch_size

                message = '(epoch: %d, iters: %d, time: %.3f) ' % (
                    epoch, epoch_iter, t)
                for k, v in errors.items():
                    message += '%s: %.3f ' % (k, v)
                print(message)

            # save latest weights
            if total_steps % save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

        # save weights periodicaly
        if epoch % save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, niter + niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()
def main():  # no inputs, no return value
    """pix2pixHD-style training loop; reads the module-level `opt`."""
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    # path of the file tracking the resume epoch/iteration
    data_loader = CreateDataLoader(opt)
    # create the data loader for these options

    dataset = data_loader.load_data()
    # fetch the dataset from the data loader
    dataset_size = len(data_loader)
    # record the dataset size
    print('#training images = %d' % dataset_size)

    start_epoch, epoch_iter = 1, 0
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq
    # phase offsets used by the periodic display/print/save checks

    model = create_model(opt)
    # model = model.cuda()
    visualizer = Visualizer(opt)
    # visualizer reports training progress for the current options

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        # main epoch loop (niter + niter_decay epochs in total)
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']),
                                      Variable(data['inst']),
                                      Variable(data['image']),
                                      Variable(data['feat']),
                                      infer=save_fake)

            # sum per device losses (ints are disabled loss terms)
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get(
                'G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.data if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([
                    ('input_label',
                     util.tensor2label(data['label'][0], opt.label_nc)),
                    ('synthesized_image', util.tensor2im(generated.data[0])),
                    ('real_image', util.tensor2im(data['image'][0]))
                ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

    # release cached GPU memory once training completes
    torch.cuda.empty_cache()
Exemple #22
0
import time
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model


# Minimal training driver: load data, build model, optimize per batch.
opt = TrainOptions().parse()
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)

model = create_model(opt)
total_steps = 0  # cumulative number of samples seen across all epochs

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    epoch_iter = 0

    for i, data in enumerate(dataset):
        iter_start_time = time.time()

        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
        model.set_input(data)
        model.optimize_parameters()

        if total_steps % opt.print_freq == 0:
            # NOTE(review): this example appears truncated by the scrape —
            # `errors` and `t` are computed but never printed or plotted.
            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
Exemple #23
0
import time
from options.train_options import TrainOptions
opt = TrainOptions().parse()  # parsed before the remaining imports run

from data.data_loader import CreateDataLoader
from models.models import create_model
from util.logger import Logger

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data(
)  # dataset is actually a torch.utils.data.DataLoader object
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)

logger = Logger(opt)

model = create_model(opt)

total_steps = 0  # cumulative number of samples seen across all epochs
for epoch in range(1, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()

    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter = total_steps - dataset_size * (
            epoch - 1)  # iter index in current epoch

        model.set_input(data)
        model.optimize_parameters()
        # NOTE(review): the example appears truncated here — no logging or
        # checkpointing follows in this fragment.
import time
import os
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from util import html

# Inference-time settings: single worker, batch of one, deterministic order.
opt = TestOptions().parse()
opt.nThreads = 1   # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)
visualizer = Visualizer(opt)

# Build the results website for this phase/epoch.
epoch_tag = '%s_%s' % (opt.phase, opt.which_epoch)
web_dir = os.path.join(opt.results_dir, opt.name, epoch_tag)
page_title = 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase,
                                                          opt.which_epoch)
webpage = html.HTML(web_dir, page_title)

# Run inference on at most opt.how_many samples and record each result.
# NOTE(review): webpage.save() is never called in this fragment — confirm
# the HTML index actually gets written.
for sample_idx, data in enumerate(dataset):
    if sample_idx >= opt.how_many:
        break
    model.set_input(data)
    model.test()
    img_path = model.get_image_paths()
    visuals = model.get_current_visuals()
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)
Exemple #25
0
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.create_model import create_model
from utils import visualizer


opt = TrainOptions().parse()

### set dataloader ###
print('### prepare DataLoader ###')
data_loader = CreateDataLoader(opt)
train_loader = data_loader.load_data()
print('training images : {}'.format(len(data_loader)))
print('numof_iteration : {}'.format(len(train_loader)))

### define model ###
model = create_model(opt)

### training loop ###
print('### start training ! ###')
for epoch in range(opt.epoch+opt.epoch_decay+1):
    for iter, data in enumerate(train_loader):

        model.set_variables(data)
        model.optimize_parameters()

        if iter % opt.print_iter_freq == 0:
            losses = model.get_current_losses()
            visualizer.print_current_losses(epoch, iter, losses)

            image = (data['image'] + 1.0) / 2.0