Exemplo n.º 1
0
def train_gpu(rank, world_size, opt, dataset):
    """Per-process training loop for single- or multi-GPU (DDP) training.

    Intended to be spawned once per GPU. Only ``rank`` 0 logs, visualizes,
    prepares the validation batch and saves checkpoints.

    Args:
        rank: GPU/process index; 0 is the master process.
        world_size: total number of processes taking part in training.
        opt: parsed option namespace (frequencies, epoch counts, gpu_ids, ...).
        dataset: dataset object used to build this process's dataloader.
    """
    signal.signal(signal.SIGINT, signal_handler)  #to really kill the process
    if len(opt.gpu_ids) > 1:
        # multi-GPU: join the distributed process group before building anything
        setup(rank, world_size, opt.ddp_port)
    dataloader = create_dataloader(
        opt, rank,
        dataset)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    model = create_model(
        opt, rank)  # create a model given opt.model and other options

    # some models need one real batch to finish sizing their layers
    if hasattr(model, 'data_dependent_initialize'):
        data = next(iter(dataloader))
        model.data_dependent_initialize(data)

    model.setup(
        opt)  # regular setup: load and print networks; create schedulers

    if len(opt.gpu_ids) > 1:
        model.parallelize(rank)
    else:
        model.single_gpu()

    if rank == 0:
        visualizer = Visualizer(
            opt)  # create a visualizer that display/save images and plots
    total_iters = 0  # the total number of training iterations

    # master process keeps a fixed validation batch resident on its device
    if rank == 0:
        model.real_A_val, model.real_B_val = dataset.get_validation_set(
            opt.pool_size)
        model.real_A_val, model.real_B_val = model.real_A_val.to(
            model.device), model.real_B_val.to(model.device)

    # optionally render the network architecture diagrams once at startup
    if rank == 0 and opt.display_networks:
        data = next(iter(dataloader))
        for path in model.save_networks_img(data):
            visualizer.display_img(path + '.png')

    for epoch in range(
            opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1
    ):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch
        if rank == 0:
            visualizer.reset(
            )  # reset the visualizer: make sure it saves the results to HTML at least once every epoch

        for i, data in enumerate(
                dataloader):  # inner loop (minibatch) within one epoch

            iter_start_time = time.time(
            )  # timer for computation per iteration
            t_data_mini_batch = iter_start_time - iter_data_time

            model.set_input(
                data)  # unpack data from dataloader and apply preprocessing
            model.optimize_parameters(
            )  # calculate loss functions, get gradients, update network weights

            t_comp = (time.time() - iter_start_time) / opt.batch_size

            # samples seen this step across all GPUs; iteration counters
            # advance by this amount, hence the `% freq < batch_size` tests
            # below, which fire once per period even with coarse increments
            batch_size = model.get_current_batch_size() * len(opt.gpu_ids)
            total_iters += batch_size
            epoch_iter += batch_size

            if rank == 0:
                if total_iters % opt.display_freq < batch_size:  # display images on visdom and save images to a HTML file
                    save_result = total_iters % opt.update_html_freq == 0
                    model.compute_visuals()
                    visualizer.display_current_results(
                        model.get_current_visuals(),
                        epoch,
                        save_result,
                        params=model.get_display_param())

                if total_iters % opt.print_freq < batch_size:  # print training losses and save logging information to the disk
                    losses = model.get_current_losses()
                    visualizer.print_current_losses(epoch, epoch_iter, losses,
                                                    t_comp, t_data_mini_batch)
                    if opt.display_id > 0:
                        visualizer.plot_current_losses(
                            epoch,
                            float(epoch_iter) / dataset_size, losses)

                if total_iters % opt.save_latest_freq < batch_size:  # cache our latest model every <save_latest_freq> iterations
                    print(
                        'saving the latest model (epoch %d, total_iters %d)' %
                        (epoch, total_iters))
                    save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                    model.save_networks(save_suffix)

                # periodic FID evaluation (optional)
                if total_iters % opt.fid_every < batch_size and opt.compute_fid:
                    model.compute_fid(epoch, total_iters)
                    if opt.display_id > 0:
                        fids = model.get_current_fids()
                        visualizer.plot_current_fid(
                            epoch,
                            float(epoch_iter) / dataset_size, fids)

                # periodic discriminator-accuracy evaluation (optional)
                if total_iters % opt.D_accuracy_every < batch_size and opt.compute_D_accuracy:
                    model.compute_D_accuracy()
                    if opt.display_id > 0:
                        accuracies = model.get_current_D_accuracies()
                        visualizer.plot_current_D_accuracies(
                            epoch,
                            float(epoch_iter) / dataset_size, accuracies)

                # adaptive pseudo-augmentation probability tracking (optional)
                if total_iters % opt.display_freq < batch_size and opt.APA:
                    if opt.display_id > 0:
                        p = model.get_current_APA_prob()
                        visualizer.plot_current_APA_prob(
                            epoch,
                            float(epoch_iter) / dataset_size, p)

                # NOTE(review): only refreshed on rank 0, so on other ranks
                # t_data_mini_batch keeps growing — confirm this is intended.
                iter_data_time = time.time()

        if epoch % opt.save_epoch_freq == 0:  # cache our model every <save_epoch_freq> epochs
            if rank == 0:
                print('saving the model at the end of epoch %d, iters %d' %
                      (epoch, total_iters))
                model.save_networks('latest')
                model.save_networks(epoch)

        if rank == 0:
            print('End of epoch %d / %d \t Time Taken: %d sec' %
                  (epoch, opt.n_epochs + opt.n_epochs_decay,
                   time.time() - epoch_start_time))
        model.update_learning_rate(
        )  # update learning rates at the end of every epoch.
Exemplo n.º 2
0
def main():
    """Train a pix2pixHD-style model, checkpointing the best "latest" model.

    Loads pickled training options, overrides a few fields, builds the
    dataloader/model/visualizer, then runs the epoch loop: forward pass,
    separate G and D updates, periodic error printing/plotting, image
    display, and checkpointing (latest model saved only when the combined
    D+G loss improves on the best seen so far).
    """
    with open('./train/train_opt.pkl', mode='rb') as f:
        # restore the options saved by a previous run, then override a few
        opt = pickle.load(f)
        opt.checkpoints_dir = './checkpoints/'
        opt.dataroot = './train'
        opt.no_flip = True
        opt.label_nc = 0
        opt.batchSize = 2
        print(opt)

    # file recording (epoch, iteration) of the last checkpoint
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    start_epoch, epoch_iter = 1, 0
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    # phase offsets so periodic actions stay aligned with resumed step counts
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq
    best_loss = 999999  # best (lowest) combined D+G loss seen so far
    model = create_model(opt)
    model = model.cuda()
    visualizer = Visualizer(opt)
    #niter = 20,niter_decay = 20
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images this step
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                      Variable(data['image']), Variable(data['feat']), infer=save_fake)

            # sum per device losses (int entries are placeholders for disabled terms)
            losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalars
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)
            loss_DG = loss_D + loss_G

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()

            ############## Display results and errors ##########

            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                       ('synthesized_image', util.tensor2im(generated.data[0])),
                                       ('real_image', util.tensor2im(data['image'][0]))])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model when the combined loss improves
            if total_steps % opt.save_latest_freq == save_delta:
                # .item() detaches the scalar: storing the live loss tensor in
                # best_loss would keep the whole autograd graph alive between
                # checkpoints.
                current_loss = loss_DG.item()
                if current_loss < best_loss:
                    best_loss = current_loss
                    print('saving the latest model (epoch %d, total_steps %d ,total loss %g)' % (epoch, total_steps, current_loss))
                    model.save('latest')
                    np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d ' % (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

    torch.cuda.empty_cache()
Exemplo n.º 3
0
    def __init__(self, celebA_loader, rafd_loader, config):
        """Store data loaders and all configuration values, then build models.

        Copies every needed field off ``config`` onto the instance, creates
        the segmentation criterion, builds the networks (and optionally
        TensorBoard), and loads pretrained weights when requested.

        Args:
            celebA_loader: data loader for the CelebA dataset.
            rafd_loader: data loader for the RaFD dataset.
            config: parsed configuration object holding every hyper-parameter
                and path read below.
        """
        # Data loader
        self.celebA_loader = celebA_loader
        self.rafd_loader = rafd_loader
        self.visualizer = Visualizer()
        # Model hyper-parameters
        self.c_dim = config.c_dim
        self.s_dim = config.s_dim
        self.c2_dim = config.c2_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.d_train_repeat = config.d_train_repeat

        # Hyper-parameteres (loss weights, learning rates, Adam betas)
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp
        self.lambda_s = config.lambda_s
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.a_lr = config.a_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Criterion (2D cross-entropy for the segmentation branch, on GPU)
        self.criterion_s = CrossEntropyLoss2d(size_average=True).cuda()

        # Training settings
        self.dataset = config.dataset
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.batch_size = config.batch_size
        self.use_tensorboard = config.use_tensorboard
        self.pretrained_model = config.pretrained_model

        # Test settings
        self.test_model = config.test_model
        self.config = config

        # Path
        self.log_path = config.log_path
        self.sample_path = config.sample_path
        self.model_save_path = config.model_save_path
        self.result_path = config.result_path

        # Step size
        self.log_step = config.log_step
        self.visual_step = self.log_step  # visualize at the same cadence as logging
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        # Build tensorboard if use
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()
Exemplo n.º 4
0
def main():
    """Run pix2pixHD-style inference over a test set and build an HTML report.

    Depending on the options, inference runs through the PyTorch model, a
    TensorRT engine, or an ONNX runtime; results are written to a results
    webpage. Can also export the model to ONNX and exit.
    """
    opt = TestOptions().parse(save=False)
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    # test
    if not opt.engine and not opt.onnx:
        # plain PyTorch inference; optionally cast to fp16 or uint8
        model = create_model(opt)
        if opt.data_type == 16:
            model.half()
        elif opt.data_type == 8:
            model.type(torch.uint8)

        if opt.verbose:
            print(model)
    else:
        # TensorRT/ONNX path: helpers imported lazily only when needed
        from run_engine import run_trt_engine, run_onnx

    for i, data in enumerate(dataset):
        if i >= opt.how_many:  # cap the number of processed images
            break
        # cast inputs to match the model's data type chosen above
        if opt.data_type == 16:
            data['label'] = data['label'].half()
            data['inst'] = data['inst'].half()
        elif opt.data_type == 8:
            data['label'] = data['label'].uint8()
            data['inst'] = data['inst'].uint8()
        if opt.export_onnx:
            # export once using the first batch, then stop
            print("Exporting to ONNX: ", opt.export_onnx)
            assert opt.export_onnx.endswith(
                "onnx"), "Export model file should end with .onnx"
            torch.onnx.export(model, [data['label'], data['inst']],
                              opt.export_onnx,
                              verbose=True)
            exit(0)
        minibatch = 1
        if opt.engine:
            generated = run_trt_engine(opt.engine, minibatch,
                                       [data['label'], data['inst']])
        elif opt.onnx:
            generated = run_onnx(opt.onnx, opt.data_type, minibatch,
                                 [data['label'], data['inst']])
        else:
            generated = model.inference(data['label'], data['inst'],
                                        data['image'])

        # save label map + synthesized image side by side on the webpage
        visuals = OrderedDict([
            ('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
            ('synthesized_image', util.tensor2im(generated.data[0]))
        ])
        img_path = data['path']
        print('process image... %s' % img_path)
        visualizer.save_images(webpage, visuals, img_path)

    webpage.save()
Exemplo n.º 5
0
def train():
    """Train a vid2vid-style video GAN and plot running G/D losses per epoch.

    The model consists of a generator, a per-frame discriminator and temporal
    discriminators at several time scales. All configuration is read from
    ``TrainOptions`` and carried around as a plain dict (``opt``).
    """
    # headless plotting for the per-epoch loss curves; select the backend
    # once up front instead of re-importing/re-switching every epoch
    from matplotlib import pyplot as plt
    plt.switch_backend('agg')

    opt = vars(TrainOptions().parse())
    # Initialize dataset:==============================================================================================#
    data_loader = create_dataloader(**opt)
    dataset_size = len(data_loader)
    print(f'Number of training videos = {dataset_size}')
    # Initialize models:===============================================================================================#
    models = prepare_models(**opt)
    model_g, model_d, flow_net, optimizer_g, optimizer_d, optimizer_d_t = create_optimizer(
        models, **opt)
    # Set parameters:==================================================================================================#
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
    start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(model_g, model_d, data_loader, **opt)
    visualizer = Visualizer(**opt)

    # Running loss histories (one entry per generator step):===========================================================#
    losses_G = []
    losses_D = []
    # Training start:==================================================================================================#
    for epoch in range(start_epoch, opt['niter'] + opt['niter_decay'] + 1):
        epoch_start_time = time.time()
        for idx, video in enumerate(data_loader, start=epoch_iter):
            if not total_steps % print_freq:
                iter_start_time = time.time()  # start timing a print interval
            total_steps += opt['batch_size']
            epoch_iter += opt['batch_size']

            # whether to collect output images this step
            save_fake = total_steps % opt['display_freq'] == 0
            fake_B_prev_last = None
            real_B_all, fake_B_all, flow_ref_all, conf_ref_all = None, None, None, None  # all real/generated frames so far
            if opt['sparse_D']:
                # sparse discriminator: keep one frame buffer per temporal scale
                real_B_all, fake_B_all, flow_ref_all, conf_ref_all = [
                    None
                ] * t_scales, [None] * t_scales, [None] * t_scales, [
                    None
                ] * t_scales
            frames_all = real_B_all, fake_B_all, flow_ref_all, conf_ref_all

            for i, (input_A, input_B) in enumerate(VideoSeq(**video, **opt)):

                # Forward Pass:========================================================================================#
                # Generator:===========================================================================================#
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = model_g(
                    input_A, input_B, fake_B_prev_last)

                # Discriminator:=======================================================================================#
                # individual frame discriminator:==============================#
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]
                # reference flows and confidences
                flow_ref, conf_ref = flow_net(real_B, real_B_prev)
                fake_B_prev = model_g.compute_fake_B_prev(
                    real_B_prev, fake_B_prev_last, fake_B)
                fake_B_prev_last = fake_B_last  # carry hidden state to next chunk

                losses = model_d(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                # None marks a disabled loss term; substitute 0 so zip works
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(model_d.loss_names, losses))

                # Temporal Discriminator:======================================#
                # get skipped frames for each temporal scale
                # NOTE(review): `video` is **-unpacked above (a mapping), but
                # attribute access is used here — confirm it also exposes
                # `n_frames_load` as an attribute.
                frames_all, frames_skipped = \
                    model_d.get_all_skipped_frames(frames_all,
                                                  real_B,
                                                  fake_B,
                                                  flow_ref,
                                                  conf_ref,
                                                  t_scales,
                                                  tD,
                                                  video.n_frames_load,
                                                  i,
                                                  flow_net)
                # run discriminator for each temporal scale:===================#
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = model_d(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(model_d.loss_names_T, losses)))

                # Collect losses:==============================================#
                loss_G, loss_D, loss_D_T, t_scales_act = model_d.get_losses(
                    loss_dict, loss_dict_T, t_scales)

                losses_G.append(loss_G.item())
                losses_D.append(loss_D.item())

                ################################## Backward Pass ###########################
                # Update generator weights
                loss_backward(loss_G, optimizer_g)

                # update individual discriminator weights
                loss_backward(loss_D, optimizer_d)

                # update temporal discriminator weights
                # (fixed: this call previously passed an extra leading `opt`
                # argument, inconsistent with the two calls just above)
                for s in range(t_scales_act):
                    loss_backward(loss_D_T[s], optimizer_d_t[s])
                # the first generated image in this sequence
                if i == 0:
                    fake_B_first = fake_B[0, 0]

            # Display results and errors:==============================================#
            # Print out errors:================================================#
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {k: v.data.item() if not isinstance(v, int) \
                    else v for k, v in loss_dict.items()}
                for s in range(len(loss_dict_T)):
                    errors.update({k + str(s): v.data.item() \
                        if not isinstance(v, int) \
                        else v for k, v in loss_dict_T[s].items()})
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            # Display output images:===========================================#
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, model_d)
                visualizer.display_current_results(visuals, epoch, total_steps)

            # Save latest model:===============================================#
            save_models(epoch, epoch_iter, total_steps, visualizer, iter_path,
                        model_g, model_d, **opt)
            if epoch_iter > dataset_size - opt['batch_size']:
                epoch_iter = 0
                break

        # End of epoch:========================================================#
        visualizer.vis_print(
            f'End of epoch {epoch} / {opt["niter"] + opt["niter_decay"]} \t'
            f' Time Taken: {time.time() - epoch_start_time} sec')

        # save model for this epoch and update model params:=================#
        save_models(epoch,
                    epoch_iter,
                    total_steps,
                    visualizer,
                    iter_path,
                    model_g,
                    model_d,
                    end_of_epoch=True,
                    **opt)
        update_models(epoch, model_g, model_d, data_loader, **opt)

        print("Generator Loss: %f." % losses_G[-1])
        print("Discriminator Loss: %f." % losses_D[-1])
        # Plot cumulative loss curves after every epoch
        plt.plot(losses_G, '-b', label='losses_G')
        plt.plot(losses_D, '-r', label='losses_D')
        plot_name = 'checkpoints/' + opt['name'] + '/losses_plot.png'
        plt.savefig(plot_name)
        plt.close()
Exemplo n.º 6
0
# Demo script: load a CasGAN model, encode a single image into noise and
# label codes, visualize the codes, then decode and display the result.
# (Training-related setup is left commented out.)
# data_loader = getLoader(opt)
# dataset_size = len(data_loader)
# print('#training images = %d' % dataset_size)

model = CasGAN()
model.initialize(opt)
# visualizer = Visualizer(opt)
# total_steps = 0
# fixed_noise = torch.FloatTensor(opt.valBatchSize, opt.hidden_size, 1, 1).uniform_(-1, 1)
# fixed_noise = fixed_noise.cuda()

# load one image from disk and add a batch dimension
image = rtf(opt.img_path, opt)
image = image.unsqueeze(0)
print(image.size())
# encode into a noise component and a label/content component
n_fake, c_fake = model.encode(image)
vis = Visualizer(opt)

# perturb one element of the label code and plot both versions side by side
c_fake_m = c_fake.clone()
c_fake_m.data[0, 0, 0, 0] = 1
vis.plot_current_label((c_fake, c_fake_m), 1)

# decode the (unmodified) codes back to an image and display it
img_generated = model.decode(n_fake, c_fake)
img_generated = tensor2im(img_generated)
from matplotlib import pyplot as plt
plt.imshow(img_generated, interpolation='nearest')
plt.show()
# img_generated = model.decode(n_fake,c_fake_m)
# img_generated = tensor2im(img_generated)
# from matplotlib import pyplot as plt
# plt.imshow(img_generated, interpolation='nearest')
# plt.show()
Exemplo n.º 7
0
def train():
    """Train a vid2vid model: generator plus frame and temporal discriminators.

    Builds the dataset, models and optimizers from ``TrainOptions``, then for
    every epoch iterates over videos, processes each video in chunks of
    ``n_frames_load`` frames (threading generator hidden state between
    chunks), updates G, D and the temporal Ds, and periodically logs errors,
    displays images, saves checkpoints and plots running loss curves.
    """
    opt = TrainOptions().parse()
    if opt.debug:
        # debug mode: log every step, single worker
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    # Initialize dataset:======================================================#
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('Number of training videos = %d' % dataset_size)

    # Initialize models:=======================================================#
    models = prepare_models(opt)
    modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T = \
            create_optimizer(opt, models)

    # Set parameters:==========================================================#
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
    start_epoch, epoch_iter, print_freq, total_steps, iter_path = \
    init_params(opt, modelG, modelD, data_loader)
    visualizer = Visualizer(opt)

    # Initialize loss list:====================================================#
    losses_G = []
    losses_D = []
    losses_D_T = []
    losses_t_scales = []

    # Real training starts here:===============================================#
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % print_freq == 0:
                iter_start_time = time.time()  # start timing a print interval
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size

            # whether to collect output images this step
            save_fake = total_steps % opt.display_freq == 0
            n_frames_total, n_frames_load, t_len = \
                data_loader.dataset.init_data_params(data, n_gpus, tG)
            fake_B_prev_last, frames_all = data_loader.dataset.init_data(
                t_scales)

            # process the video in chunks of n_frames_load frames
            for i in range(0, n_frames_total, n_frames_load):
                input_A, input_B, input_C, inst_A = \
                        data_loader.dataset.prepare_data(data, i, input_nc, output_nc)

                ############################### Forward Pass ###############################
                ####### Generator:=========================================================#
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = \
                        modelG(input_A, input_B, inst_A, fake_B_prev_last)

                ####### Discriminator:=====================================================#
                # individual frame discriminator:==============================#
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]
                # reference flows and confidences
                flow_ref, conf_ref = flowNet(real_B, real_B_prev)
                fake_B_prev = modelG.module.compute_fake_B_prev(
                    real_B_prev, fake_B_prev_last, fake_B)
                # carry generator hidden state to the next chunk
                fake_B_prev_last = fake_B_last

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref, input_C
                    ]))
                # None marks a disabled loss term; substitute 0 so zip works
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                # Temporal Discriminator:======================================#
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = \
                modelD.module.get_all_skipped_frames(frames_all,
                real_B,
                                                     fake_B,
                                                     flow_ref,
                                                     conf_ref,
                                                     t_scales,
                                                     tD,
                                                     n_frames_load,
                                                     i,
                                                     flowNet)
                # run discriminator for each temporal scale:===================#
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = modelD(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # Collect losses:==============================================#
                loss_G, loss_D, loss_D_T, t_scales_act = \
                                modelD.module.get_losses(loss_dict, loss_dict_T, t_scales)

                losses_G.append(loss_G.item())
                losses_D.append(loss_D.item())

                ################################## Backward Pass ###########################
                # Update generator weights
                loss_backward(opt, loss_G, optimizer_G)

                # update individual discriminator weights
                loss_backward(opt, loss_D, optimizer_D)

                # update temporal discriminator weights
                for s in range(t_scales_act):
                    loss_backward(opt, loss_D_T[s], optimizer_D_T[s])
                # the first generated image in this sequence
                if i == 0: fake_B_first = fake_B[0, 0]

            if opt.debug:
                # report GPU memory usage while debugging
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            # Display results and errors:======================================#
            # Print out errors:================================================#
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {k: v.data.item() if not isinstance(v, int) \
                                           else v for k, v in loss_dict.items()}
                for s in range(len(loss_dict_T)):
                    errors.update({k + str(s): v.data.item() \
                                   if not isinstance(v, int) \
                                   else v for k, v in loss_dict_T[s].items()})
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            # Display output images:===========================================#
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, modelD)
                visualizer.display_current_results(visuals, epoch, total_steps)

            # Save latest model:===============================================#
            save_models(opt, epoch, epoch_iter, total_steps, visualizer,
                        iter_path, modelG, modelD)
            if epoch_iter > dataset_size - opt.batch_size:
                epoch_iter = 0
                break

        # End of epoch:========================================================#
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' % \
       (epoch, opt.niter + opt.niter_decay,
                             time.time() - epoch_start_time))

        ### save model for this epoch and update model params:=================#
        save_models(opt,
                    epoch,
                    epoch_iter,
                    total_steps,
                    visualizer,
                    iter_path,
                    modelG,
                    modelD,
                    end_of_epoch=True)
        update_models(opt, epoch, modelG, modelD, data_loader)

        # Plot cumulative loss curves after every epoch (headless backend)
        from matplotlib import pyplot as plt
        plt.switch_backend('agg')
        print("Generator Loss: %f." % losses_G[-1])
        print("Discriminator loss: %f." % losses_D[-1])
        #Plot Losses
        plt.plot(losses_G, '-b', label='losses_G')
        plt.plot(losses_D, '-r', label='losses_D')
        # plt.plot(losses_D_T, '-r', label='losses_D_T')
        plot_name = 'checkpoints/' + opt.name + '/losses_plot.png'
        plt.savefig(plot_name)
        plt.close()
Exemplo n.º 8
0
def main() -> None:
    """Run the pix2pixHD-style GAN training loop.

    Reads all configuration from the module-level ``opt``; builds the data
    loader, model, and visualizer; then trains for
    ``opt.niter + opt.niter_decay`` epochs, periodically printing losses,
    displaying result images, and saving checkpoints.

    Takes no arguments and returns nothing — all effects are side effects
    (checkpoint files, the ``iter.txt`` resume file, logs, visualizations).
    """
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    # Text file recording (epoch, iteration) so training can be resumed.
    data_loader = CreateDataLoader(opt)
    # Build the data loader matching the current options.

    dataset = data_loader.load_data()
    # Iterable dataset obtained from the data loader.
    dataset_size = len(data_loader)
    # Number of training images.
    print('#training images = %d' % dataset_size)

    start_epoch, epoch_iter = 1, 0
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    # Phase offsets so display/print/save events keep firing on the same
    # step boundaries even if total_steps did not start at zero.
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    model = create_model(opt)
    # model = model.cuda()
    visualizer = Visualizer(opt)
    # Visualizer prints/plots the training progress for the current options.

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        # One pass over the dataset per epoch; opt.niter + opt.niter_decay
        # epochs in total.
        epoch_start_time = time.time()
        if epoch != start_epoch:
            # Carry over any partial-epoch offset from a resumed run.
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            # Model returns per-device loss list plus generated images
            # (images only materialized when infer=True).
            losses, generated = model(Variable(data['label']),
                                      Variable(data['inst']),
                                      Variable(data['image']),
                                      Variable(data['feat']),
                                      infer=save_fake)

            # sum per device losses (ints are placeholder zeros for
            # losses disabled by the options)
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get(
                'G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.data if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                # Per-image time for this iteration.
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([
                    ('input_label',
                     util.tensor2label(data['label'][0], opt.label_nc)),
                    ('synthesized_image', util.tensor2im(generated.data[0])),
                    ('real_image', util.tensor2im(data['image'][0]))
                ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')
                # Record where we stopped so a resumed run picks up here.
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            # Next resume starts cleanly at the following epoch.
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

    torch.cuda.empty_cache()
Exemplo n.º 9
0
def train() -> None:
    """Train a pix2pixHD-style model with optional apex fp16, plus an
    end-of-epoch validation pass.

    Parses options via ``TrainOptions``, optionally resumes from
    ``iter.txt`` (restoring epoch/iteration and recomputing the decayed
    learning rate), then loops for ``opt.niter + opt.niter_decay`` epochs,
    alternating generator/discriminator updates and running validation
    metrics and visualizations at the end of every epoch.

    Returns nothing; all effects are side effects (checkpoints, logs,
    the ``iter.txt`` resume file, visualizer output).
    """
    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        # NOTE(review): bare except swallows every error (including
        # KeyboardInterrupt); prefer `except (OSError, ValueError):` for
        # the missing/garbled-file cases this is meant to cover.
        except:
            start_epoch, epoch_iter = 1, 0
        # compute resume lr: replay the linear decay that would have
        # happened over the epochs already completed past opt.niter.
        if start_epoch > opt.niter:
            lrd_unit = opt.lr / opt.niter_decay
            resume_lr = opt.lr - (start_epoch - opt.niter) * lrd_unit
            opt.lr = resume_lr
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    # Align print_freq to the batch size so the print condition below can
    # actually be hit (total_steps advances in batchSize increments).
    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    if opt.debug:
        # Tiny run for smoke-testing the pipeline.
        opt.display_freq = 2
        opt.print_freq = 2
        opt.niter = 3
        opt.niter_decay = 0
        opt.max_dataset_size = 1
        opt.valSize = 1

    ## Loading data
    # train data
    data_loader = CreateDataLoader(opt, isVal=False)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('# training images = %d' % dataset_size)
    # validation data (data_loader name is reused; only valset is kept)
    data_loader = CreateDataLoader(opt, isVal=True)
    valset = data_loader.load_data()
    print('# validation images = %d' % len(data_loader))

    ## Loading model
    model = create_model(opt)
    visualizer = Visualizer(opt)
    if opt.fp16:
        # apex amp O1 mixed precision; model is wrapped in DataParallel
        # afterwards, so attribute access below goes through model.module.
        from apex import amp
        model, [optimizer_G, optimizer_D
                ] = amp.initialize(model,
                                   [model.optimizer_G, model.optimizer_D],
                                   opt_level='O1')
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    else:
        # assumes create_model already returned a DataParallel-wrapped
        # model here — TODO confirm against create_model.
        optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter

    # Phase offsets so periodic events stay aligned after a resume.
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            # epoch_iter = epoch_iter % dataset_size
            epoch_iter = 0
        for i, data in enumerate(dataset, start=epoch_iter):
            # iter_start_time is only refreshed on print iterations, so t
            # below measures the time of a whole print_freq window.
            if total_steps % opt.print_freq == print_delta:
                iter_start_time = time.time()

            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            model = model.train()
            losses, generated, metrics = model(data['A'],
                                               data['B'],
                                               data['geometry'],
                                               infer=False)

            # sum per device losses and metrics (ints are placeholder
            # zeros for disabled losses)
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get(
                'G_GAN_Feat', 0) + opt.vgg_weight * loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights (loss scaling only under fp16)
            optimizer_G.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_G, optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_G.backward()
            optimizer_G.step()

            # update discriminator weights
            optimizer_D.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_D, optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_D.backward()
            optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                metrics_ = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in metric_dict.items()
                }
                # Average time per step over the print window.
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)
                visualizer.print_current_metrics(epoch, epoch_iter, metrics_,
                                                 t)
                visualizer.plot_current_metrics(metrics_, total_steps)
                #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ### display output images (EXR conversion type depends on the
            ### task: 1=specular, 2=low, 3=high)
            if save_fake:
                if opt.task_type == 'specular':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=1)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=1))
                    ])
                elif opt.task_type == 'low':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=2)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=2))
                    ])
                elif opt.task_type == 'high':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=3)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=3))
                    ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ###########################################################################################
        # validation at the end of each epoch: average metrics over the
        # whole valset and visualize the last batch.
        val_start_time = time.time()
        metrics_val = []
        for _, val_data in enumerate(valset):
            model = model.eval()
            # model.half()
            generated, metrics = model(val_data['A'],
                                       val_data['B'],
                                       val_data['geometry'],
                                       infer=True)
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            metrics_ = {
                k: v.data.item() if not isinstance(v, int) else v
                for k, v in metric_dict.items()
            }
            metrics_val.append(metrics_)
        # Print out losses
        metrics_val = visualizer.mean4dict(metrics_val)
        t = (time.time() - val_start_time) / opt.print_freq
        visualizer.print_current_metrics(epoch,
                                         epoch_iter,
                                         metrics_val,
                                         t,
                                         isVal=True)
        visualizer.plot_current_metrics(metrics_val, total_steps, isVal=True)
        # visualization (same EXR type mapping as the training branch;
        # val_data/generated here are from the last validation batch)
        if opt.task_type == 'specular':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=1)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=1))
            ])
        if opt.task_type == 'low':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=2)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=2))
            ])
        if opt.task_type == 'high':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=3)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=3))
            ])
        visualizer.display_current_results(visuals, epoch, epoch, isVal=True)
        ###########################################################################################

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()