Exemplo n.º 1
0
    # --- One training epoch (fragment: enclosing epoch loop/setup not visible) ---
    epoch_start_time = time.time()
    epoch_iter = 0  # samples processed within the current epoch
    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
        # forward/backward and optimizer step for this batch
        model.set_input(data)
        model.optimize_parameters()

        # periodically push the model's current visuals to the visualizer
        if total_steps % opt.display_freq == 0:
            visualizer.display_current_results(model.get_current_visuals(), epoch)

        # periodically log losses; t is the per-sample wall time of this iteration
        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            if opt.display_id > 0:
                visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)

    # checkpoint the model every opt.save_epoch_freq epochs
    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
        model.save(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

    # linear learning-rate decay once past the first opt.niter epochs
    if epoch > opt.niter:
        decay_frac = (epoch - opt.niter) / opt.niter_decay
        new_lr = opt.lr * (1 - decay_frac)
        model.update_learning_rate(new_lr)
Exemplo n.º 2
0
        # --- One training iteration (fragment: enclosing loop not visible) ---
        iter_counter.record_one_iteration()

        # Training
        # train generator (only every opt.D_steps_per_G iterations)
        if i % opt.D_steps_per_G == 0:
            trainer.run_generator_one_step(data_i)

        # train discriminator (every iteration)
        trainer.run_discriminator_one_step(data_i)

        # Visualizations
        # cadence decisions are delegated to the IterationCounter
        if iter_counter.needs_printing():
            losses = trainer.get_latest_losses()
            visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                            losses, iter_counter.time_per_iter)
            visualizer.plot_current_errors(losses,
                                           iter_counter.total_steps_so_far)

        if iter_counter.needs_displaying():
            # show input label map, the latest generated sample, and ground truth
            visuals = OrderedDict([('input_label', data_i['label']),
                                   ('synthesized_image',
                                    trainer.get_latest_generated()),
                                   ('real_image', data_i['image'])])
            visualizer.display_current_results(visuals, epoch,
                                               iter_counter.total_steps_so_far)

        if iter_counter.needs_saving():
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            # persist the iteration counter so 'latest' can be resumed exactly
            iter_counter.record_current_iter()
Exemplo n.º 3
0
def train(opt):
    """Run the full training loop for a two-domain image translation model.

    Handles resume-from-checkpoint bookkeeping, data/model/visualizer setup,
    per-iteration G/D updates, visualization, periodic checkpointing, and
    stepped learning-rate decay at the epochs listed in ``opt.decay_epochs``.
    """
    # file used to persist (epoch, iteration) so training can be resumed
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')

    if opt.continue_train:
        if opt.which_epoch == 'latest':
            try:
                start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                     delimiter=',',
                                                     dtype=int)
            # np.loadtxt raises OSError for a missing file and ValueError for
            # malformed contents; anything else should propagate, not be hidden.
            except (OSError, ValueError):
                start_epoch, epoch_iter = 1, 0
        else:
            start_epoch, epoch_iter = int(opt.which_epoch), 0

        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
        # re-apply every LR decay that already happened before the resume point
        for update_point in opt.decay_epochs:
            if start_epoch < update_point:
                break

            opt.lr *= opt.decay_gamma
    else:
        start_epoch, epoch_iter = 0, 0

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)

    total_steps = (start_epoch) * dataset_size + epoch_iter

    # remainders so resumed runs keep the same display/print/save cadence
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq
    bSize = opt.batchSize

    #in case there's no display sample one image from each class to test after every epoch
    if opt.display_id == 0:
        dataset.dataset.set_sample_mode(True)
        dataset.num_workers = 1
        for i, data in enumerate(dataset):
            if i * opt.batchSize >= opt.numClasses:
                break
            if i == 0:
                sample_data = data
            else:
                # concatenate tensor entries batch-wise; append everything else
                for key, value in data.items():
                    if torch.is_tensor(data[key]):
                        sample_data[key] = torch.cat(
                            (sample_data[key], data[key]), 0)
                    else:
                        sample_data[key] = sample_data[key] + data[key]
        dataset.num_workers = opt.nThreads
        dataset.dataset.set_sample_mode(False)

    for epoch in range(start_epoch, opt.epochs):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = 0
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = (total_steps % opt.display_freq
                         == display_delta) and (opt.display_id > 0)

            ############## Network Pass ########################
            model.set_inputs(data)
            disc_losses = model.update_D()
            gen_losses, gen_in, gen_out, rec_out, cyc_out = model.update_G(
                infer=save_fake)
            loss_dict = dict(gen_losses, **disc_losses)
            ##################################################

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                # convert tensor losses to plain floats for logging
                errors = {
                    k: v.item()
                    if not (isinstance(v, float) or isinstance(v, int)) else v
                    for k, v in loss_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch + 1, epoch_iter, errors,
                                                t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(
                        epoch,
                        float(epoch_iter) / dataset_size, opt, errors)

            ### display output images
            if save_fake and opt.display_id > 0:
                # batch layout: first bSize samples are domain A, next are domain B
                class_a_suffix = ' class {}'.format(data['A_class'][0])
                class_b_suffix = ' class {}'.format(data['B_class'][0])
                classes = None

                visuals = OrderedDict()
                visuals_A = OrderedDict([('real image' + class_a_suffix,
                                          util.tensor2im(gen_in.data[0]))])
                visuals_B = OrderedDict([('real image' + class_b_suffix,
                                          util.tensor2im(gen_in.data[bSize]))])

                A_out_vis = OrderedDict([('synthesized image' + class_b_suffix,
                                          util.tensor2im(gen_out.data[0]))])
                B_out_vis = OrderedDict([('synthesized image' + class_a_suffix,
                                          util.tensor2im(gen_out.data[bSize]))
                                         ])
                # reconstruction / cycle images only exist when their losses are on
                if opt.lambda_rec > 0:
                    A_out_vis.update([('reconstructed image' + class_a_suffix,
                                       util.tensor2im(rec_out.data[0]))])
                    B_out_vis.update([('reconstructed image' + class_b_suffix,
                                       util.tensor2im(rec_out.data[bSize]))])
                if opt.lambda_cyc > 0:
                    A_out_vis.update([('cycled image' + class_a_suffix,
                                       util.tensor2im(cyc_out.data[0]))])
                    B_out_vis.update([('cycled image' + class_b_suffix,
                                       util.tensor2im(cyc_out.data[bSize]))])

                visuals_A.update(A_out_vis)
                visuals_B.update(B_out_vis)
                visuals.update(visuals_A)
                visuals.update(visuals_B)

                ncols = len(visuals_A)
                visualizer.display_current_results(visuals, epoch, classes,
                                                   ncols)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch + 1, total_steps))
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')
                if opt.display_id == 0:
                    # no live display: render the sampled per-class grid instead
                    model.eval()
                    visuals = model.inference(sample_data)
                    visualizer.save_matrix_image(visuals, 'latest')
                    model.train()

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch + 1, opt.epochs, time.time() - epoch_start_time))

        ### save model for this epoch
        if (epoch + 1) % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch + 1, total_steps))
            model.save('latest')
            model.save(epoch + 1)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')
            if opt.display_id == 0:
                model.eval()
                visuals = model.inference(sample_data)
                visualizer.save_matrix_image(visuals, epoch + 1)
                model.train()

        ### multiply learning rate by opt.decay_gamma after certain iterations
        if (epoch + 1) in opt.decay_epochs:
            model.update_learning_rate()
def main():
    """pix2pixHD-style training entry point.

    Relies on a module-level ``opt`` (parsed options) defined elsewhere in the
    file.  Runs the full train loop: alternating G/D updates, logging,
    visualization, and checkpointing, with LR decay after ``opt.niter`` epochs.
    """
    # file used to persist (epoch, iteration) for resuming
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    start_epoch, epoch_iter = 1, 0
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    # remainders keep the display/print/save cadence aligned across resumes
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    model = create_model(opt)
    model = model.cuda()
    visualizer = Visualizer(opt)

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            # model returns per-term losses and (optionally) a generated image
            losses, generated = model(Variable(data['label']), Variable(data['inst']),
                                      Variable(data['image']), Variable(data['feat']), infer=save_fake)

            # sum per device losses
            losses = [torch.mean(x) if not isinstance(x, int) else x for x in losses]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()


            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {k: v.data if not isinstance(v, int) else v for k, v in loss_dict.items()}
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                       ('synthesized_image', util.tensor2im(generated.data[0])),
                                       ('real_image', util.tensor2im(data['image'][0]))])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

            # stop once a full pass over the dataset has been counted
            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

    torch.cuda.empty_cache()
Exemplo n.º 5
0
def train():
    """vid2vid-style video translation training loop.

    Parses options, builds the generator/discriminator/flow networks and their
    optimizers, then trains over sliding windows of video frames with both a
    per-frame discriminator and multi-scale temporal discriminators.
    """
    opt = TrainOptions().parse()
    if opt.debug:
        # force frequent logging/display in debug runs
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    ### initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training videos = %d' % dataset_size)

    ### initialize models
    models = create_model(opt)
    modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T = create_optimizer(
        opt, models)

    ### set parameters
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
        start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(opt, modelG, modelD, data_loader)
    visualizer = Visualizer(opt)

    ### real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            # only stamp the timer on iterations that will be printed
            if total_steps % print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0
            n_frames_total, n_frames_load, t_len = data_loader.dataset.init_data_params(
                data, n_gpus, tG)
            fake_B_prev_last, frames_all = data_loader.dataset.init_data(
                t_scales)

            # slide over the sequence in chunks of n_frames_load frames
            for i in range(0, n_frames_total, n_frames_load):
                input_A, input_B, inst_A = data_loader.dataset.prepare_data(
                    data, i, input_nc, output_nc)

                ###################################### Forward Pass ##########################
                ####### generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(
                    input_A, input_B, inst_A, fake_B_prev_last)

                ####### discriminator
                ### individual frame discriminator
                real_B_prev, real_B = real_Bp[:, :
                                              -1], real_Bp[:,
                                                           1:]  # the collection of previous and current real frames
                flow_ref, conf_ref = flowNet(
                    real_B, real_B_prev)  # reference flows and confidences
                fake_B_prev = modelG.module.compute_fake_B_prev(
                    real_B_prev, fake_B_prev_last, fake_B)
                fake_B_prev_last = fake_B_last

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                ### temporal discriminator
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = modelD.module.get_all_skipped_frames(frames_all, \
                        real_B, fake_B, flow_ref, conf_ref, t_scales, tD, n_frames_load, i, flowNet)

                # run discriminator for each temporal scale
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = modelD(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_G, loss_D, loss_D_T, t_scales_act = modelD.module.get_losses(
                    loss_dict, loss_dict_T, t_scales)

                ###################################### Backward Pass #################################
                # update generator weights
                loss_backward(opt, loss_G, optimizer_G)

                # update individual discriminator weights
                loss_backward(opt, loss_D, optimizer_D)

                # update temporal discriminator weights
                for s in range(t_scales_act):
                    loss_backward(opt, loss_D_T[s], optimizer_D_T[s])

                if i == 0:
                    fake_B_first = fake_B[
                        0, 0]  # the first generated image in this sequence

            if opt.debug:
                # dump GPU memory stats in debug mode
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                # temporal-discriminator losses get a scale-index suffix
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, modelD)
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            save_models(opt, epoch, epoch_iter, total_steps, visualizer,
                        iter_path, modelG, modelD)
            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        ### save model for this epoch and update model params
        save_models(opt,
                    epoch,
                    epoch_iter,
                    total_steps,
                    visualizer,
                    iter_path,
                    modelG,
                    modelD,
                    end_of_epoch=True)
        update_models(opt, epoch, modelG, modelD, data_loader)
Exemplo n.º 6
0
            # --- Tail of one training iteration (fragment: enclosing loop not visible) ---
            # display images on visdom and save images
            if total_iteration % opt.display_freq == 0:
                visualizer.display_current_results(model.get_current_visuals(),
                                                   model.get_current_text(),
                                                   epoch)
                visualizer.plot_current_distribution(model.get_current_dis())

            # print training loss and save logging information to the disk
            if total_iteration % opt.print_freq == 0:
                losses = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, total_iteration, losses,
                                                t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(total_iteration, losses)

            # save the latest model every <save_latest_freq> iterations to the disk
            if total_iteration % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_iteration))
                model.save_networks('latest')

            # save the model every <save_iter_freq> iterations to the disk
            if total_iteration % opt.save_iters_freq == 0:
                print('saving the model of iterations %d' % total_iteration)
                model.save_networks(total_iteration)

            # stop training once the iteration budget is exhausted
            if total_iteration > max_iteration:
                keep_training = False
                break
Exemplo n.º 7
0
def do_train(opt):
    """SPADE-style training loop with optional FID-based model selection.

    When ``opt.train_eval`` is set, a validation dataloader and an InceptionV3
    network are built; every ``opt.eval_epoch_freq`` epochs the FID between
    generated and precomputed real-image statistics is evaluated and the best
    model is checkpointed as ``'best'``.
    """
    # print options to help debugging
    print(' '.join(sys.argv))

    # load the dataset
    dataloader = data.create_dataloader(opt)

    # create trainer for our model
    trainer = Pix2PixTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))

    # create tool for visualization
    visualizer = Visualizer(opt)

    if opt.train_eval:
        # temporarily flip opt into test mode to build the validation loader
        original_flip = opt.no_flip
        opt.no_flip = True
        opt.phase = 'test'
        opt.isTrain = False
        dataloader_val = data.create_dataloader(opt)
        val_visualizer = Visualizer(opt)
        # create a webpage that summarizes all validation results
        web_dir = os.path.join(opt.results_dir, opt.name,
                            '%s_%s' % (opt.phase, opt.which_epoch))
        webpage = html.HTML(web_dir,
                            'Experiment = %s, Phase = %s, Epoch = %s' %
                            (opt.name, opt.phase, opt.which_epoch))
        opt.phase = 'train'
        opt.isTrain = True
        opt.no_flip = original_flip
        # process for calculating FID scores
        from inception import InceptionV3
        # BUGFIX: the loop below calls calculate_activation_statistics and
        # calculate_frechet_distance; they must be imported (the previous
        # import of calculate_fid_given_paths was unused and left these
        # names undefined, raising NameError at evaluation time).
        from fid_score import calculate_activation_statistics, calculate_frechet_distance
        import pathlib
        # define the inceptionV3
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[opt.eval_dims]
        eval_model = InceptionV3([block_idx]).cuda()
        # load real-image distributions precomputed on the training set
        mu_np_root = os.path.join('datasets/train_mu_si',opt.dataset_mode,'m.npy')
        st_np_root = os.path.join('datasets/train_mu_si',opt.dataset_mode,'s.npy')
        m0, s0 = np.load(mu_np_root), np.load(st_np_root)
        # load previous best FID
        if opt.continue_train:
            fid_record_dir = os.path.join(opt.checkpoints_dir, opt.name, 'fid.txt')
            FID_score, _ = np.loadtxt(fid_record_dir, delimiter=',', dtype=float)
        else:
            FID_score = 1000
    else:
        FID_score = 1000

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)
        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # Training
            # train generator (only every opt.D_steps_per_G iterations)
            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)

            # train discriminator
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                if opt.train_eval:
                    visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                    losses, iter_counter.time_per_iter, FID_score)
                else:
                    visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                    losses, iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far)

            # if iter_counter.needs_displaying():
            #     visuals = OrderedDict([('input_label', data_i['label']),
            #                            ('synthesized_image', trainer.get_latest_generated()),
            #                            ('real_image', data_i['image'])])
            #     visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far)

            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                    (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter(FID_score)

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if epoch % opt.eval_epoch_freq == 0 and opt.train_eval:
            # generate fake images on the validation set
            trainer.pix2pix_model.eval()
            print('start evalidation .... ')
            # disable the VAE path during inference; restore afterwards
            if opt.use_vae:
                flag = True
                opt.use_vae = False
            else:
                flag = False
            for i, data_i in enumerate(dataloader_val):
                if data_i['label'].size()[0] != opt.batchSize:
                    if opt.batchSize > 2*data_i['label'].size()[0]:
                        print('batch size is too large')
                        break
                    data_i = repair_data(data_i, opt.batchSize)
                generated = trainer.pix2pix_model(data_i, mode='inference')
                img_path = data_i['path']
                for b in range(generated.shape[0]):
                    visuals = OrderedDict([('input_label', data_i['label'][b]),
                                        ('synthesized_image', generated[b])])
                    val_visualizer.save_images(webpage, visuals, img_path[b:b + 1])
            webpage.save()
            trainer.pix2pix_model.train()
            if flag:
                opt.use_vae = True
            # compute FID over the freshly saved synthesized images
            fake_path = pathlib.Path(os.path.join(web_dir, 'images/synthesized_image/'))
            files = list(fake_path.glob('*.jpg')) + list(fake_path.glob('*.png'))
            m1, s1 = calculate_activation_statistics(files, eval_model, 1, opt.eval_dims, True, images=None)
            fid_value = calculate_frechet_distance(m0, s0, m1, s1)
            visualizer.print_eval_fids(epoch, fid_value, FID_score)
            # save the best model if necessary
            if fid_value < FID_score:
                FID_score = fid_value
                trainer.save('best')

        if epoch % opt.save_epoch_freq == 0 or \
        epoch == iter_counter.total_epochs:
            print('saving the model at the end of epoch %d, iters %d' %
                (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)

    print('Training was successfully finished.')
Exemplo n.º 8
0
    # --- One training epoch (fragment: enclosing epoch loop/setup not visible) ---
    epoch_start_time = time.time()
    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        total_steps += opt.batchSize
        # epoch_iter wraps around the training-set size
        epoch_iter = total_steps % num_train
        model.set_input(data)
        model.optimize_parameters()

        if total_steps % opt.display_freq == 0:
            visualizer.display_current_results(model.get_current_visuals(), epoch)

        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            # NOTE(review): passes the raw timestamp iter_start_time as the
            # time argument; sibling loops pass elapsed time per sample
            # ((time.time() - iter_start_time) / batchSize) — confirm intended.
            visualizer.print_current_errors(epoch, epoch_iter, errors, iter_start_time)
            if opt.display_id > 0:
                visualizer.plot_current_errors(epoch, float(epoch_iter)/num_train, opt, errors)

        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

    # checkpoint the model every opt.save_epoch_freq epochs
    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, total_steps))
        model.save('latest')
        model.save(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
def train_pose2vid(target_dir, run_name, temporal_smoothing=False):
    """Train the pose-to-video translation model for one target/run.

    Loads the base options module, specializes it for this run, then runs
    a pix2pixHD-style GAN training loop with resumable state persisted to
    ``iter.json`` in the checkpoint directory.

    Args:
        target_dir: directory containing this target person's training data.
        run_name: name of this run (used for the checkpoint subfolder).
        temporal_smoothing: if True, the model additionally conditions on
            the previous frame's label/image pair.
    """
    import src.config.train_opt as opt

    # `opt` starts as a config module and is replaced by an options object
    # customized for this target/run.
    opt = update_opt(opt, target_dir, run_name, temporal_smoothing)

    # Resumption state (start epoch, iteration within epoch) lives here.
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.json')
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    # When fine-tuning from a pretrained run, restore where training left off.
    if opt.load_pretrain != '':
        with open(iter_path, 'r') as f:
            iter_json = json.load(f)
    else:
        iter_json = {'start_epoch': 1, 'epoch_iter': 0}

    start_epoch = iter_json['start_epoch']
    epoch_iter = iter_json['epoch_iter']
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    # Phase offsets so the periodic display/print/save triggers fire at the
    # same cadence after a resume (total_steps need not start at a multiple
    # of each frequency).
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    model = create_model(opt)
    # NOTE(review): `device` is not defined in this function — presumably a
    # module-level torch.device; confirm.
    model = model.to(device)
    visualizer = Visualizer(opt)

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            # Keep only the within-epoch remainder when starting a new epoch.
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            # With temporal smoothing the model also receives the previous
            # frame's label/image pair.
            if temporal_smoothing:
                losses, generated = model(Variable(data['label']),
                                          Variable(data['inst']),
                                          Variable(data['image']),
                                          Variable(data['feat']),
                                          Variable(data['previous_label']),
                                          Variable(data['previous_image']),
                                          infer=save_fake)
            else:
                losses, generated = model(Variable(data['label']),
                                          Variable(data['inst']),
                                          Variable(data['image']),
                                          Variable(data['feat']),
                                          infer=save_fake)

            # sum per device losses
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get(
                'G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()

            ############## Display results and errors ##########

            # Per-batch console log of the raw loss tensors.
            print(f"Epoch {epoch} batch {i}:")
            print(f"loss_D: {loss_D}, loss_G: {loss_G}")
            print(
                f"loss_D_fake: {loss_dict['D_fake']}, loss_D_real: {loss_dict['D_real']}"
            )
            print(
                f"loss_G_GAN {loss_dict['G_GAN']}, loss_G_GAN_Feat: {loss_dict.get('G_GAN_Feat', 0)}, loss_G_VGG: {loss_dict.get('G_VGG', 0)}\n"
            )

            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                # errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
                # Average wall-clock time per sample for this iteration.
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([
                    ('input_label',
                     util.tensor2label(data['label'][0], opt.label_nc)),
                    ('synthesized_image', util.tensor2im(generated.data[0])),
                    ('real_image', util.tensor2im(data['image'][0]))
                ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')
                # Persist resume state alongside the checkpoint.
                iter_json['start_epoch'] = epoch
                iter_json['epoch_iter'] = epoch_iter
                with open(iter_path, 'w') as f:
                    json.dump(iter_json, f)

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            # A resume after this point starts cleanly at the next epoch.
            iter_json['start_epoch'] = epoch + 1
            iter_json['epoch_iter'] = 0
            with open(iter_path, 'w') as f:
                json.dump(iter_json, f)

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

    # Release cached GPU memory once training is done.
    torch.cuda.empty_cache()
Exemplo n.º 10
0
        for j in range(opt.TEST_ITERS):
            iter_start_time = time.time()
            visualizer.reset()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.optimize_noise()
            if total_steps % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(epoch, float(epoch_iter)/epoch_length, opt, errors)

            if total_steps % opt.save_model_freq == 0:
                print('saving the at (total_steps %d)' %
                      (total_steps))
                model.save(webpage, img_path, total_steps)
            if total_steps % epoch_length == 0:
                break

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()
    visuals = model.get_current_visuals()
    print('%04d: process image... %s' % (i, img_path))
    visualizer.save_images(webpage, visuals, img_path)
Exemplo n.º 11
0
                print("Incomplete Frame: {0} Size: {1} Word: {2}".format(
                    frame_idx, img.size(), trans))
                init_tensor = True
                continue

            if frame_idx % 40 == 0:
                init_tensor = True

            model.set_input(frame)
            pred_frame = model.test(init_tensor)
            init_tensor = False

            writer.writeFrame(pred_frame)
            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(vid_idx, frame_idx, errors, t,
                                            t_data)

            if opt.display_id > 0:
                visualizer.plot_current_errors(vid_idx,
                                               float(frame_idx) / len(video),
                                               opt, errors)

        writer.close()
        visuals = model.get_current_visuals()
        #img_path = model.get_image_paths()
        print('%04d: process video... %s' % (vid_idx, vid_path))
        #visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio)

    webpage.save()
Exemplo n.º 12
0
        iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter = total_steps % num_train

	
        model.set_input(data)
        model.optimize_parameters()

        if total_steps % opt.display_freq == 0:
            visualizer.display_current_results(model.get_current_visuals(), epoch)

        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            visualizer.print_current_errors(epoch, epoch_iter, errors, iter_start_time)
            if opt.display_id > 0:
                visualizer.plot_current_errors(epoch, epoch_iter, opt, errors)

        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, total_steps))
        model.save('latest')
        model.save(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
Exemplo n.º 13
0
def train():
    """Run the full vid2vid training loop (dict-options variant).

    Builds the dataloader, generator (``model_g``), discriminators
    (``model_d``), flow network and their optimizers from the parsed option
    dict, then trains for ``niter + niter_decay`` epochs.  Each video is
    consumed chunk by chunk: the generator is run forward, the per-frame and
    multi-scale temporal discriminators score the result, and the three
    losses are back-propagated through their respective optimizers.  Errors
    are logged every ``print_freq`` steps, sample images displayed every
    ``display_freq`` steps, and checkpoints plus a cumulative loss plot are
    saved at epoch boundaries.
    """
    opt = vars(TrainOptions().parse())
    # Initialize dataset:==============================================================================================#
    data_loader = create_dataloader(**opt)
    dataset_size = len(data_loader)
    print(f'Number of training videos = {dataset_size}')
    # Initialize models:===============================================================================================#
    models = prepare_models(**opt)
    model_g, model_d, flow_net, optimizer_g, optimizer_d, optimizer_d_t = create_optimizer(
        models, **opt)
    # Set parameters:==================================================================================================#
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
    start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(model_g, model_d, data_loader, **opt)
    visualizer = Visualizer(**opt)

    # Configure the headless plotting backend once, instead of re-importing
    # and re-switching it at the end of every epoch.
    from matplotlib import pyplot as plt
    plt.switch_backend('agg')

    # Initialize loss list:============================================================================================#
    losses_G = []
    losses_D = []
    # Training start:==================================================================================================#
    for epoch in range(start_epoch, opt['niter'] + opt['niter_decay'] + 1):
        epoch_start_time = time.time()
        for idx, video in enumerate(data_loader, start=epoch_iter):
            # Timestamp only at print boundaries; `t` below is therefore the
            # average time per step over the last print interval.
            if not total_steps % print_freq:
                iter_start_time = time.time()
            total_steps += opt['batch_size']
            epoch_iter += opt['batch_size']

            # whether to collect output images
            save_fake = total_steps % opt['display_freq'] == 0
            fake_B_prev_last = None
            real_B_all, fake_B_all, flow_ref_all, conf_ref_all = None, None, None, None  # all real/generated frames so far
            if opt['sparse_D']:
                real_B_all, fake_B_all, flow_ref_all, conf_ref_all = [
                    None
                ] * t_scales, [None] * t_scales, [None] * t_scales, [
                    None
                ] * t_scales
            frames_all = real_B_all, fake_B_all, flow_ref_all, conf_ref_all

            for i, (input_A, input_B) in enumerate(VideoSeq(**video, **opt)):

                # Forward Pass:========================================================================================#
                # Generator:===========================================================================================#
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = model_g(
                    input_A, input_B, fake_B_prev_last)

                # Discriminator:=======================================================================================#
                # individual frame discriminator:==============================#
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]
                # reference flows and confidences
                flow_ref, conf_ref = flow_net(real_B, real_B_prev)
                fake_B_prev = model_g.compute_fake_B_prev(
                    real_B_prev, fake_B_prev_last, fake_B)
                fake_B_prev_last = fake_B_last

                losses = model_d(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                # Missing (None) losses become 0; others are averaged across devices.
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(model_d.loss_names, losses))

                # Temporal Discriminator:======================================#
                # get skipped frames for each temporal scale
                # NOTE(review): `video` is spread as **video above (mapping
                # semantics) but read with attribute access here — confirm
                # the dataloader yields a namespace-like object.
                frames_all, frames_skipped = \
                    model_d.get_all_skipped_frames(frames_all,
                                                  real_B,
                                                  fake_B,
                                                  flow_ref,
                                                  conf_ref,
                                                  t_scales,
                                                  tD,
                                                  video.n_frames_load,
                                                  i,
                                                  flow_net)
                # run discriminator for each temporal scale:===================#
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = model_d(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(model_d.loss_names_T, losses)))

                # Collect losses:==============================================#
                loss_G, loss_D, loss_D_T, t_scales_act = model_d.get_losses(
                    loss_dict, loss_dict_T, t_scales)

                losses_G.append(loss_G.item())
                losses_D.append(loss_D.item())

                ################################## Backward Pass ###########################
                # Update generator weights
                loss_backward(loss_G, optimizer_g)

                # update individual discriminator weights
                loss_backward(loss_D, optimizer_d)

                # update temporal discriminator weights
                # (signature is loss_backward(loss, optimizer), matching the
                # two calls above — the stray leading `opt` argument in the
                # original call was a bug)
                for s in range(t_scales_act):
                    loss_backward(loss_D_T[s], optimizer_d_t[s])
                # the first generated image in this sequence
                if i == 0:
                    fake_B_first = fake_B[0, 0]

            # Display results and errors:==============================================#
            # Print out errors:================================================#
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {k: v.data.item() if not isinstance(v, int) \
                    else v for k, v in loss_dict.items()}
                for s in range(len(loss_dict_T)):
                    errors.update({k + str(s): v.data.item() \
                        if not isinstance(v, int) \
                        else v for k, v in loss_dict_T[s].items()})
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            # Display output images:===========================================#
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, model_d)
                visualizer.display_current_results(visuals, epoch, total_steps)

            # Save latest model:===============================================#
            save_models(epoch, epoch_iter, total_steps, visualizer, iter_path,
                        model_g, model_d, **opt)
            if epoch_iter > dataset_size - opt['batch_size']:
                epoch_iter = 0
                break

        # End of epoch:========================================================#
        visualizer.vis_print(
            f'End of epoch {epoch} / {opt["niter"] + opt["niter_decay"]} \t'
            f' Time Taken: {time.time() - epoch_start_time} sec')

        # save model for this epoch and update model params:=================#
        save_models(epoch,
                    epoch_iter,
                    total_steps,
                    visualizer,
                    iter_path,
                    model_g,
                    model_d,
                    end_of_epoch=True,
                    **opt)
        update_models(epoch, model_g, model_d, data_loader, **opt)

        # Plot the cumulative generator/discriminator loss curves.
        print("Generator Loss: %f." % losses_G[-1])
        print("Discriminator Loss: %f." % losses_D[-1])
        # Plot Losses
        plt.plot(losses_G, '-b', label='losses_G')
        plt.plot(losses_D, '-r', label='losses_D')
        # plt.plot(losses_D_T, '-r', label='losses_D_T')
        plot_name = 'checkpoints/' + opt['name'] + '/losses_plot.png'
        plt.savefig(plot_name)
        plt.close()
Exemplo n.º 14
0
def train():
    """Train the vid2vid model (wrapped-model variant).

    Builds the dataloader, generator/discriminator/flow networks and their
    optimizers from an argparse-style options object, then for every epoch
    iterates each video in chunks of ``n_frames_load`` frames: generator
    forward pass, per-frame discriminator, multi-scale temporal
    discriminators, and a backward pass per loss.  Errors are logged every
    ``print_freq`` steps, images displayed every ``display_freq`` steps, and
    checkpoints plus a cumulative loss plot are saved at epoch boundaries.
    The networks are accessed through ``.module`` — presumably
    ``nn.DataParallel`` wrappers (confirm against prepare_models).
    """
    opt = TrainOptions().parse()
    if opt.debug:
        # Debug mode: log/display on every step, single data worker.
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    # Initialize dataset:======================================================#
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('Number of training videos = %d' % dataset_size)

    # Initialize models:=======================================================#
    models = prepare_models(opt)
    modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T = \
            create_optimizer(opt, models)

    # Set parameters:==========================================================#
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
    start_epoch, epoch_iter, print_freq, total_steps, iter_path = \
    init_params(opt, modelG, modelD, data_loader)
    visualizer = Visualizer(opt)

    # Initialize loss list:====================================================#
    losses_G = []
    losses_D = []
    # NOTE(review): the two lists below are never appended to; only the
    # commented-out plot line at the bottom references losses_D_T.
    losses_D_T = []
    losses_t_scales = []

    # Real training starts here:===============================================#
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        # start=epoch_iter resumes mid-epoch after a restart.
        for idx, data in enumerate(dataset, start=epoch_iter):
            # Timestamp only at print boundaries; `t` below is the average
            # per-step time over the last print interval.
            if total_steps % print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0
            n_frames_total, n_frames_load, t_len = \
                data_loader.dataset.init_data_params(data, n_gpus, tG)
            fake_B_prev_last, frames_all = data_loader.dataset.init_data(
                t_scales)

            # Consume this video in chunks of n_frames_load frames.
            for i in range(0, n_frames_total, n_frames_load):
                input_A, input_B, input_C, inst_A = \
                        data_loader.dataset.prepare_data(data, i, input_nc, output_nc)

                ############################### Forward Pass ###############################
                ####### Generator:=========================================================#
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = \
                        modelG(input_A, input_B, inst_A, fake_B_prev_last)

                ####### Discriminator:=====================================================#
                # individual frame discriminator:==============================#
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]
                # reference flows and confidences
                flow_ref, conf_ref = flowNet(real_B, real_B_prev)
                fake_B_prev = modelG.module.compute_fake_B_prev(
                    real_B_prev, fake_B_prev_last, fake_B)
                fake_B_prev_last = fake_B_last

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref, input_C
                    ]))
                # Missing (None) losses become 0; others averaged across devices.
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                # Temporal Discriminator:======================================#
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = \
                modelD.module.get_all_skipped_frames(frames_all,
                real_B,
                                                     fake_B,
                                                     flow_ref,
                                                     conf_ref,
                                                     t_scales,
                                                     tD,
                                                     n_frames_load,
                                                     i,
                                                     flowNet)
                # run discriminator for each temporal scale:===================#
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = modelD(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # Collect losses:==============================================#
                loss_G, loss_D, loss_D_T, t_scales_act = \
                                modelD.module.get_losses(loss_dict, loss_dict_T, t_scales)

                losses_G.append(loss_G.item())
                losses_D.append(loss_D.item())

                ################################## Backward Pass ###########################
                # Update generator weights
                loss_backward(opt, loss_G, optimizer_G)

                # update individual discriminator weights
                loss_backward(opt, loss_D, optimizer_D)

                # update temporal discriminator weights
                for s in range(t_scales_act):
                    loss_backward(opt, loss_D_T[s], optimizer_D_T[s])
                # the first generated image in this sequence
                # (kept for display in save_all_tensors below)
                if i == 0: fake_B_first = fake_B[0, 0]

            if opt.debug:
                # Report GPU memory usage while debugging.
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

    # Display results and errors:==============================================#
    # Print out errors:================================================#
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {k: v.data.item() if not isinstance(v, int) \
                                           else v for k, v in loss_dict.items()}
                for s in range(len(loss_dict_T)):
                    errors.update({k + str(s): v.data.item() \
                                   if not isinstance(v, int) \
                                   else v for k, v in loss_dict_T[s].items()})
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            # Display output images:===========================================#
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, modelD)
                visualizer.display_current_results(visuals, epoch, total_steps)

            # Save latest model:===============================================#
            save_models(opt, epoch, epoch_iter, total_steps, visualizer,
                        iter_path, modelG, modelD)
            if epoch_iter > dataset_size - opt.batch_size:
                epoch_iter = 0
                break

        # End of epoch:========================================================#
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' % \
       (epoch, opt.niter + opt.niter_decay,
                             time.time() - epoch_start_time))

        ### save model for this epoch and update model params:=================#
        save_models(opt,
                    epoch,
                    epoch_iter,
                    total_steps,
                    visualizer,
                    iter_path,
                    modelG,
                    modelD,
                    end_of_epoch=True)
        update_models(opt, epoch, modelG, modelD, data_loader)

        # Re-imported every epoch; harmless since module imports are cached.
        from matplotlib import pyplot as plt
        plt.switch_backend('agg')
        print("Generator Loss: %f." % losses_G[-1])
        print("Discriminator loss: %f." % losses_D[-1])
        #Plot Losses
        plt.plot(losses_G, '-b', label='losses_G')
        plt.plot(losses_D, '-r', label='losses_D')
        # plt.plot(losses_D_T, '-r', label='losses_D_T')
        plot_name = 'checkpoints/' + opt.name + '/losses_plot.png'
        plt.savefig(plot_name)
        plt.close()
Exemplo n.º 15
0
        # print('setting input data to model!')
        model.set_input(x, noise, cond_c, cond_d)
        model.optimize_parameters()

        if total_steps % opt.display_freq == 0:
            visualizer.display_current_results(model.get_current_visuals(), epoch)
            # visualizer.plot_current_label(model.get_current_labels(), epoch)
            # visualizer.display_current_results(model.get_current_visuals(), epoch_iter)

        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            if opt.display_id > 0:
                visualizer.plot_current_errors(epoch, float(epoch_iter)/(dataset_size * opt.batchSize), opt, errors)

        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, total_steps))
        model.save('latest')
        model.save(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
################################################################
Exemplo n.º 16
0
        cost.backward()
        optimizer.step()

        step_num += 1

        if np.mod(step_num, opt.print_freq) == 0:
            elapsed_time = timer() - t
            print('%s: %s / %s, ... elapsed time: %f (s)' %
                  (epoch, step_num, int(
                      len(dataset) / opt.batchSize), elapsed_time))
            print(inv_depths_mean)
            t = timer()
            visualizer.plot_current_errors(
                step_num, 1, opt,
                OrderedDict([
                    ('photometric_cost', photometric_cost.data.cpu()[0]),
                    ('smoothness_cost', smoothness_cost.data.cpu()[0]),
                    ('cost', cost.data.cpu()[0])
                ]))

        if np.mod(step_num, opt.display_freq) == 0:
            # frame_vis = frames.data[:,1,:,:,:].permute(0,2,3,1).contiguous().view(-1,opt.imW, 3).cpu().numpy().astype(np.uint8)
            # depth_vis = vis_depthmap(inv_depths.data[:,1,:,:].contiguous().view(-1,opt.imW).cpu()).numpy().astype(np.uint8)
            frame_vis = frames.data.permute(
                1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
            depth_vis = vis_depthmap(inv_depths.data.cpu()).numpy().astype(
                np.uint8)
            visualizer.display_current_results(
                OrderedDict([('%s frame' % (opt.name), frame_vis),
                             ('%s inv_depth' % (opt.name), depth_vis)]), epoch)
            sio.savemat(
Exemplo n.º 17
0
        visualizer.reset()
        total_steps += 1
        model.set_input(data)
        model.optimize_parameters()

        if step % opt.display_step == 0:
            save_result = step % opt.update_html_freq == 0
            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, step, training_iters, errors, t, 'Train')

        if step % opt.plot_step == 0:
            if opt.display_id > 0:
                visualizer.plot_current_errors(epoch, step / float(training_iters), opt, errors)

        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, total_steps))
        model.save('latest')
        model.save(epoch)
    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

    model.update_learning_rate()
Exemplo n.º 18
0
class Trainer():
    """Bookkeeping helper for a distributed training loop.

    Tracks iteration/epoch counters, decides when to print/plot losses and
    display images, and delegates checkpointing to the external
    ``save_models``/``update_models`` helpers.  Resume state is persisted in
    ``<checkpoints_dir>/<name>/iter.txt`` as ``epoch,iteration``.
    """

    def __init__(self, opt, data_loader):
        # Text file recording (epoch, iteration) so --continue_train can resume.
        iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
        start_epoch, epoch_iter = 1, 0
        ### if continue training, recover previous states
        if opt.continue_train:
            if os.path.exists(iter_path):
                start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                     delimiter=',',
                                                     dtype=int)
            print('Resuming from epoch %d at iteration %d' %
                  (start_epoch, epoch_iter))

        # Align print_freq to the batch size so the modulo checks in
        # start_of_iter/end_of_iter fire exactly on batch boundaries.
        print_freq = lcm(opt.print_freq, opt.batchSize)
        total_steps = (start_epoch - 1) * len(data_loader) + epoch_iter
        # Round down to a print_freq multiple so iteration timing starts cleanly.
        total_steps = total_steps // print_freq * print_freq

        self.opt = opt
        self.epoch_iter, self.print_freq, self.total_steps, self.iter_path = epoch_iter, print_freq, total_steps, iter_path
        # NOTE(review): epoch_iter is assigned twice (here and on the line
        # above); the second assignment is redundant but harmless.
        self.start_epoch, self.epoch_iter = start_epoch, epoch_iter
        self.dataset_size = len(data_loader)
        self.visualizer = Visualizer(opt)

    def start_of_iter(self):
        """Advance step counters; call once at the top of every iteration."""
        # Start the timer only at print_freq boundaries: the elapsed time is
        # later divided by print_freq to report an average per-step time.
        if self.total_steps % self.print_freq == 0:
            self.iter_start_time = time.time()
        self.total_steps += self.opt.batchSize
        self.epoch_iter += self.opt.batchSize
        # Whether this iteration's outputs should be visualized.
        self.save = self.total_steps % self.opt.display_freq == 0

    def end_of_iter(self, loss_dicts, output_list, model):
        """Log losses, display visuals, and save the latest checkpoint.

        Returns True when fewer than one batch of samples remains in the
        epoch (caller should break out of the inner loop), else False.
        Requires start_of_epoch() to have been called (sets ``self.epoch``).
        """
        opt = self.opt
        epoch, epoch_iter, print_freq, total_steps = self.epoch, self.epoch_iter, self.print_freq, self.total_steps
        ############## Display results and errors ##########
        ### print out errors
        if is_master() and total_steps % print_freq == 0:
            # Average per-step time over the print_freq window.
            t = (time.time() - self.iter_start_time) / print_freq
            # Detach loss tensors to plain Python floats; ints pass through.
            errors = {
                k: v.data.item() if not isinstance(v, int) else v
                for k, v in loss_dicts.items()
            }
            self.visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            self.visualizer.plot_current_errors(errors, total_steps)

        ### display output images
        if is_master() and self.save:
            visuals = save_all_tensors(opt, output_list, model)
            self.visualizer.display_current_results(visuals, epoch,
                                                    total_steps)

        if is_master() and opt.print_mem:
            # Debug aid: report GPU memory usage via nvidia-smi.
            call([
                "nvidia-smi", "--format=csv",
                "--query-gpu=memory.used,memory.free"
            ])

        ### save latest model
        save_models(opt, epoch, epoch_iter, total_steps, self.visualizer,
                    self.iter_path, model)
        if epoch_iter > self.dataset_size - opt.batchSize:
            return True
        return False

    def start_of_epoch(self, epoch, model, data_loader):
        """Record the epoch start time and refresh epoch-dependent state."""
        self.epoch = epoch
        self.epoch_start_time = time.time()
        if self.opt.distributed:
            # Re-seed the distributed sampler so shuffling differs per epoch.
            data_loader.dataloader.sampler.set_epoch(epoch)
        # update model params
        update_models(self.opt, epoch, model, data_loader)

    def end_of_epoch(self, model):
        """Log the epoch duration and save an end-of-epoch checkpoint."""
        opt = self.opt
        iter_end_time = time.time()
        self.visualizer.vis_print(
            opt, 'End of epoch %d / %d \t Time Taken: %d sec' %
            (self.epoch, opt.niter + opt.niter_decay,
             time.time() - self.epoch_start_time))

        ### save model for this epoch
        save_models(opt,
                    self.epoch,
                    self.epoch_iter,
                    self.total_steps,
                    self.visualizer,
                    self.iter_path,
                    model,
                    end_of_epoch=True)
        self.epoch_iter = 0
Exemplo n.º 19
0
        epoch_iter = total_steps - dataset_size * (epoch - 1)
        model.set_input(data)
        model.optimize_parameters()
        model.accum_accs()

        if total_steps % opt_train.display_freq == 0:
            visualizer.display_current_results(model.get_current_visuals(),
                                               epoch)

        if total_steps % opt_train.print_freq == 0:
            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt_train.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            if opt_train.display_id > 0:
                visualizer.plot_current_errors(
                    epoch,
                    float(epoch_iter) / dataset_size, opt_train, errors)
        if total_steps % opt_train.print_freq == 0:
            accs = model.get_current_accs()
            visualizer_acc.plot_current_errors(
                epoch,
                float(epoch_iter) / dataset_size, opt_train, accs)

        if total_steps % opt_train.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

    ####################################################################################################################
    # eval val after every epoch
    model.reset_accs()
Exemplo n.º 20
0
class Tester:
    """Evaluation driver: runs a trained model over a non-training split.

    Wraps a dataloader for ``dataset_key`` (``test``, ``val`` or ``train``),
    computes per-image MSE against the resized targets, optionally writes an
    HDF5 error log plus side-by-side visualisations, and can dump raw
    predictions as ``.npy`` files for the (label-free) test split.
    """

    def __init__(self, opt, dataset_key='test', visualizer=None):
        # Deep-copy so evaluation-only tweaks never leak back into training.
        self.opt = deepcopy(opt)

        # Deterministic evaluation: fixed batch order, no flip augmentation.
        self.opt.serial_batches = True
        self.opt.no_flip = True
        self.opt.isTrain = False
        self.opt.dataset_key = dataset_key

        # NOTE(review): assumes ``opt`` supports the ``in`` membership test
        # (a dict-like options object) — confirm against the options class.
        if 'results_dir' not in self.opt:
            self.opt.results_dir = 'results/'

        self.dataloader = data.create_dataloader(self.opt)

        self.visualizer = Visualizer(
            self.opt) if visualizer is None else visualizer

        # Make checkpoints_dir absolute relative to the current working dir.
        base_path = os.getcwd()
        if self.opt.checkpoints_dir.startswith("./"):
            self.opt.checkpoints_dir = os.path.join(
                base_path, self.opt.checkpoints_dir[2:])
        else:
            self.opt.checkpoints_dir = os.path.join(base_path,
                                                    self.opt.checkpoints_dir)

        # Only splits with ground truth can be scored.
        self.is_validation = self.opt.dataset_key in ["val", "train"]
        # Total number of samples in the split.
        self.N = self.dataloader.dataset.N

        self.results_dir = os.path.join(opt.checkpoints_dir, self.opt.name,
                                        self.opt.results_dir,
                                        self.opt.dataset_key)
        if not os.path.exists(self.results_dir):
            os.makedirs(self.results_dir)

    def forward(self, model, data_i):
        """Run inference on one batch.

        Returns (fake, fake_resized): the raw generated batch on CPU and the
        same batch resized/rescaled to 0–255 as a tensor.
        """
        fake = model.forward(data_i, mode="inference").detach().cpu()
        fake_resized = ImageProcessor.to_255resized_imagebatch(fake,
                                                               as_tensor=True)
        return fake, fake_resized

    def get_iterator(self, dataloader, indices=None):
        """Yield batches from *dataloader*.

        Args:
            indices: a list of indices that should be loaded from dataloader. If it is none, the iterator iterates
                over the entire dataset.

        Returns: iterator
        """
        if indices is None:
            for data_i in dataloader:
                yield data_i
        else:
            for i_val in indices:
                data_i = dataloader.dataset.get_particular(i_val)
                yield data_i

    def _prepare_error_log(self):
        """Create the HDF5 error log with per-sample datasets, sized to N."""
        error_log = h5py.File(
            os.path.join(self.results_dir,
                         f"error_log_{self.opt.dataset_key}.h5"), "w")
        # np.float was removed in NumPy 1.24; use the explicit float64 dtype.
        error_log.create_dataset("error", shape=(self.N, ), dtype=np.float64)
        # Fixed-width byte strings: 4-char user ids, 13-char filenames.
        error_log.create_dataset("user", shape=(self.N, ), dtype='S4')
        error_log.create_dataset("filename", shape=(self.N, ), dtype='S13')
        error_log.create_dataset("visualisation",
                                 shape=(self.N, 1, 380, 1000),
                                 dtype=np.uint8)
        return error_log

    def _write_error_log_batch(self, error_log, data_i, i, fake, errors):
        """Write one batch's ids, errors and visualisations into the log."""
        visualisation_data = {**data_i, "fake": fake}
        visuals = visualize_sidebyside(visualisation_data, error_list=errors)

        # We add the entire batch to the output file
        idx_from, idx_to = i * self.opt.batchSize, i * self.opt.batchSize + self.opt.batchSize
        error_log["user"][idx_from:idx_to] = np.array(data_i["user"],
                                                      dtype='S4')
        error_log["filename"][idx_from:idx_to] = np.array(data_i["filename"],
                                                          dtype='S13')
        error_log["error"][idx_from:idx_to] = errors
        vis = np.array([np.copy(v) for k, v in visuals.items()])
        # vis are all floats in [-1, 1]; map to [0, 256) for uint8 storage.
        vis = (vis + 1) * 128
        error_log["visualisation"][idx_from:idx_to] = vis
        return error_log

    def run_batch(self, data_i, model):
        """Infer one batch and score it against the original targets.

        Returns (errors, fake, fake_resized, target_image) where errors is a
        per-sample MSE array.
        """
        fake, fake_resized = self.forward(model, data_i)
        target_image = ImageProcessor.as_batch(data_i["target_original"],
                                               as_tensor=True)
        errors = np.array(
            MSECalculator.calculate_mse_for_images(fake_resized, target_image))
        return errors, fake, fake_resized, target_image

    def run_validation(self,
                       model,
                       generator,
                       limit=-1,
                       write_error_log=False):
        """Score batches from *generator* until *limit* samples are exceeded.

        Returns the flat list of per-sample errors.  NOTE(review): with the
        default limit=-1 the loop breaks on the first batch — callers are
        expected to resolve limit first (see run()).
        """
        print(f"write error log: {write_error_log}")
        assert self.is_validation, "Must be in validation mode"
        if write_error_log:
            error_log = self._prepare_error_log()

        all_errors = list()

        counter = 0
        for i, data_i in enumerate(generator):
            counter += data_i['label'].shape[0]
            if counter > limit:
                break
            if i % 10 == 9:
                print(f"Processing batch {i}")
                # NOTE(review): 1471 looks like a dataset-specific scaling
                # constant — confirm its provenance.
                print(
                    f"Error so far: {np.sum(all_errors) / len(all_errors) * 1471}"
                )
            errors, fake, fake_resized, target_image = self.run_batch(
                data_i, model)
            all_errors += list(errors)
            if write_error_log:
                error_log = self._write_error_log_batch(
                    error_log, data_i, i, fake, errors)

        if write_error_log:
            error_log.close()
        return all_errors

    def print_results(self,
                      all_errors,
                      errors_dict,
                      epoch='n.a.',
                      n_steps="n.a."):
        """Pretty-print aggregate validation statistics to stdout."""
        print("Validation Results")
        print("------------------")
        print(
            f"Error calculated on {len(all_errors)} / {self.dataloader.dataset.N} samples"
        )
        for k in sorted(errors_dict):
            print(f"  {k}, {errors_dict[k]:.2f}")
        print(
            f"  dataset_key: {self.opt.dataset_key}, model: {self.opt.name}, epoch: {epoch}, n_steps: {n_steps}"
        )

    def run_visual_validation(self, model, mode, epoch, n_steps, limit):
        """Generate side-by-side visualisations for a few samples."""
        print(f"Visualizing images for mode '{mode}'...")
        indices = self._get_validation_indices(mode, limit)
        generator = self.get_iterator(self.dataloader, indices=indices)

        result_list = list()
        error_list = list()
        for data_i in generator:
            errors, fake, fake_resized, target_image = self.run_batch(
                data_i, model)
            data_i['fake'] = fake
            result_list.append(data_i)
            error_list.append(errors)
        error_list = np.array(error_list)
        error_list = error_list.reshape(-1)
        # Re-group the per-batch dicts into one dict of lists, then
        # concatenate the tensor-valued entries along the batch dimension.
        result = {
            k: [rl[k] for rl in result_list]
            for k in result_list[0].keys()
        }
        for key in [
                "style_image", "target", "target_original", "fake", "label"
        ]:
            result[key] = torch.cat(result[key], dim=0)

        visuals = visualize_sidebyside(
            result,
            log_key=f"{self.opt.dataset_key}/{mode}",
            w=200,
            h=320,
            error_list=error_list)
        self.visualizer.display_current_results(visuals, epoch, n_steps)

    def _get_validation_indices(self, mode, limit):
        """Map a mode string to dataset indices (None means full dataset)."""
        if 'rand' in mode:
            validation_indices = self.dataloader.dataset.get_random_indices(
                limit)
        elif 'fix' in mode:
            # Use fixed validation indices
            validation_indices = self.dataloader.dataset.get_validation_indices(
            )[:limit]
        elif 'full' in mode:
            validation_indices = None
        else:
            raise ValueError(f"Invalid mode: {mode}")
        return validation_indices

    def run(self,
            model,
            mode,
            epoch=None,
            n_steps=None,
            limit=-1,
            write_error_log=False,
            log=False):
        """Run a full scored validation pass and report statistics."""
        print(f"Running validation for mode '{mode}'...")
        # Resolve the -1 sentinel to "entire dataset" before looping.
        limit = limit if limit > 0 else self.dataloader.dataset.N
        indices = self._get_validation_indices(mode, limit)
        generator = self.get_iterator(self.dataloader, indices=indices)
        all_errors = self.run_validation(model,
                                         generator,
                                         limit=limit,
                                         write_error_log=write_error_log)

        errors_dict = MSECalculator.calculate_error_statistics(
            all_errors, mode=mode, dataset_key=self.opt.dataset_key)
        self.print_results(all_errors, errors_dict, epoch, n_steps)

        if log:
            self.log_visualizer(errors_dict, epoch, n_steps)

    def log_visualizer(self, errors_dict, epoch=0, total_steps_so_far=0):
        """Forward the error statistics to the visualizer (print + plot)."""
        self.visualizer.print_current_errors(epoch,
                                             total_steps_so_far,
                                             errors_dict,
                                             t=0)
        self.visualizer.plot_current_errors(errors_dict, total_steps_so_far)

    def run_test(self, model, limit=-1):
        """Run inference on the test split and dump predictions as .npy.

        Writes one uint8 array per image plus a pred_npy_list.txt manifest.
        """
        filepaths = list()

        for i, data_i in enumerate(self.dataloader):
            if limit > 0 and i * self.opt.batchSize >= limit:
                break
            if i % 10 == 0:
                print(
                    f"Processing batch {i} (processed {self.opt.batchSize * i} images)"
                )

            # The test file names are only 12 characters long, so we have dot to remove
            img_filename = [re.sub(r'\.', '', f) for f in data_i['filename']]

            fake, fake_resized = self.forward(model, data_i)
            for b in range(len(img_filename)):
                result_path = os.path.join(self.results_dir,
                                           img_filename[b] + ".npy")
                # Sanity check: outputs must already be in the 0–255 range.
                assert torch.min(fake_resized[b]) >= 0 and torch.max(
                    fake_resized[b]) <= 255
                np.save(result_path, np.copy(fake_resized[b]).astype(np.uint8))
                filepaths.append(result_path)

        # Write the manifest listing every prediction file produced.
        path_filepaths = os.path.join(self.results_dir, "pred_npy_list.txt")
        with open(path_filepaths, 'w') as f:
            for line in filepaths:
                f.write(line)
                f.write(os.linesep)
        print(f"Written {len(filepaths)} files. Filepath: {path_filepaths}")

    def run_partial_modes(self, model, epoch, n_steps, log, visualize_images,
                          limit):
        """Run scored validation (and optional visuals) for each mode."""
        # for mode in ['fix', 'rand']:
        for mode in ['rand']:
            self.run(model=model,
                     mode=mode,
                     epoch=epoch,
                     n_steps=n_steps,
                     log=log,
                     limit=limit)
            if visualize_images:
                self.run_visual_validation(model,
                                           mode=mode,
                                           epoch=epoch,
                                           n_steps=n_steps,
                                           limit=4)
Exemplo n.º 21
0
def train():
    """Train the video-to-video synthesis model end to end.

    Sets up the data loader and the generator/discriminator/flow networks,
    then iterates over epochs of video sequences.  Each sequence is processed
    in GPU-sized windows of ``n_frames_load`` frames; the generator receives
    the last generated frame of the previous window so generation stays
    temporally continuous, and a per-frame discriminator plus multi-scale
    temporal discriminators supply the losses.  Also handles resuming,
    logging, visualisation, checkpointing, learning-rate decay and gradual
    sequence-length growth.
    """
    opt = TrainOptions().parse()
    if opt.debug:
        # Surface problems immediately: log and display every iteration.
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    # initialize dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    if opt.dataset_mode == 'pose':
        print('#training frames = %d' % dataset_size)
    else:
        print('#training videos = %d' % dataset_size)

    # initialize models
    modelG, modelD, flowNet = create_model(opt)
    visualizer = Visualizer(opt)

    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    # if continue training, recover previous states
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except (OSError, ValueError):
            # Missing or malformed iter.txt: start from scratch.  (Was a bare
            # ``except:``, which also swallowed KeyboardInterrupt/SystemExit.)
            start_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
        # Re-apply schedule-dependent state the run would have reached by the
        # resumed epoch: LR decay, unfrozen global generator, sequence length.
        if start_epoch > opt.niter:
            modelG.module.update_learning_rate(start_epoch - 1)
            modelD.module.update_learning_rate(start_epoch - 1)
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                start_epoch > opt.niter_fix_global):
            modelG.module.update_fixed_params()
        if start_epoch > opt.niter_step:
            data_loader.dataset.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
            modelG.module.update_training_batch(
                (start_epoch - 1) // opt.niter_step)
    else:
        start_epoch, epoch_iter = 1, 0

    # set parameters
    # number of gpus used for generator for each batch
    n_gpus = opt.n_gpus_gen // opt.batchSize
    tG, tD = opt.n_frames_G, opt.n_frames_D
    tDB = tD * opt.output_nc
    s_scales = opt.n_scales_spatial
    t_scales = opt.n_scales_temporal
    input_nc = 1 if opt.label_nc != 0 else opt.input_nc
    output_nc = opt.output_nc

    # Align print_freq with the batch size so the modulo checks below fire on
    # batch boundaries, and round total_steps down to a print boundary.
    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    total_steps = total_steps // opt.print_freq * opt.print_freq

    # real training starts here
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            # Time only the iterations whose duration will be reported.
            if total_steps % opt.print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0

            # n_frames_total = n_frames_load * n_loadings + tG - 1
            _, n_frames_total, height, width = data['B'].size()
            n_frames_total = n_frames_total // opt.output_nc
            # number of total frames loaded into GPU at a time for each batch
            n_frames_load = opt.max_frames_per_gpu * n_gpus
            n_frames_load = min(n_frames_load, n_frames_total - tG + 1)
            # number of loaded frames plus previous frames
            t_len = n_frames_load + tG - 1

            # the last generated frame from previous training batch (which becomes input to the next batch)
            fake_B_last = None
            # all real/generated frames so far
            real_B_all, fake_B_all, flow_ref_all, conf_ref_all = None, None, None, None
            if opt.sparse_D:
                real_B_all, fake_B_all, flow_ref_all, conf_ref_all = [
                    None
                ] * t_scales, [None] * t_scales, [None] * t_scales, [
                    None
                ] * t_scales
            # temporally subsampled frames
            real_B_skipped, fake_B_skipped = [None] * t_scales, [None
                                                                 ] * t_scales
            flow_ref_skipped, conf_ref_skipped = [None] * t_scales, [
                None
            ] * t_scales  # temporally subsampled flows

            # Slide a window of t_len frames over the sequence, advancing by
            # n_frames_load so consecutive windows overlap by tG - 1 frames.
            for i in range(0, n_frames_total - t_len + 1, n_frames_load):
                # 5D tensor: batchSize, # of frames, # of channels, height, width
                input_A = Variable(
                    data['A'][:, i * input_nc:(i + t_len) * input_nc,
                              ...]).view(-1, t_len, input_nc, height, width)
                input_B = Variable(
                    data['B'][:, i * output_nc:(i + t_len) * output_nc,
                              ...]).view(-1, t_len, output_nc, height, width)
                inst_A = Variable(data['inst'][:, i:i + t_len, ...]).view(
                    -1, t_len, 1, height,
                    width) if len(data['inst'].size()) > 2 else None

                ###################################### Forward Pass ##########################
                # generator
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = modelG(
                    input_A, input_B, inst_A, fake_B_last)

                if i == 0:
                    # the first generated image in this sequence
                    fake_B_first = fake_B[0, 0]
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]

                # discriminator
                # individual frame discriminator
                # reference flows and confidences
                flow_ref, conf_ref = flowNet(real_B, real_B_prev)
                # Previous fakes: a real frame for the very first window,
                # otherwise the last generated frame carried over.
                fake_B_prev = real_B_prev[:, 0:
                                          1] if fake_B_last is None else fake_B_last[
                                              0][:, -1:]
                if fake_B.size()[1] > 1:
                    fake_B_prev = torch.cat(
                        [fake_B_prev, fake_B[:, :-1].detach()], dim=1)

                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                # temporal discriminator
                loss_dict_T = []
                # get skipped frames for each temporal scale
                if t_scales > 0:
                    if opt.sparse_D:
                        real_B_all, real_B_skipped = get_skipped_frames_sparse(
                            real_B_all, real_B, t_scales, tD, n_frames_load, i)
                        fake_B_all, fake_B_skipped = get_skipped_frames_sparse(
                            fake_B_all, fake_B, t_scales, tD, n_frames_load, i)
                        flow_ref_all, flow_ref_skipped = get_skipped_frames_sparse(
                            flow_ref_all,
                            flow_ref,
                            t_scales,
                            tD,
                            n_frames_load,
                            i,
                            is_flow=True)
                        conf_ref_all, conf_ref_skipped = get_skipped_frames_sparse(
                            conf_ref_all,
                            conf_ref,
                            t_scales,
                            tD,
                            n_frames_load,
                            i,
                            is_flow=True)
                    else:
                        real_B_all, real_B_skipped = get_skipped_frames(
                            real_B_all, real_B, t_scales, tD)
                        fake_B_all, fake_B_skipped = get_skipped_frames(
                            fake_B_all, fake_B, t_scales, tD)
                        flow_ref_all, conf_ref_all, flow_ref_skipped, conf_ref_skipped = get_skipped_flows(
                            flowNet, flow_ref_all, conf_ref_all,
                            real_B_skipped, flow_ref, conf_ref, t_scales, tD)

                # run discriminator for each temporal scale
                for s in range(t_scales):
                    if real_B_skipped[s] is not None:
                        losses = modelD(s + 1, [
                            real_B_skipped[s], fake_B_skipped[s],
                            flow_ref_skipped[s], conf_ref_skipped[s]
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # collect losses
                loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
                loss_G = loss_dict['G_GAN'] + \
                    loss_dict['G_GAN_Feat'] + loss_dict['G_VGG']
                loss_G += loss_dict['G_Warp'] + loss_dict['F_Flow'] + \
                    loss_dict['F_Warp'] + loss_dict['W']
                if opt.add_face_disc:
                    loss_G += loss_dict['G_f_GAN'] + loss_dict['G_f_GAN_Feat']
                    loss_D += (loss_dict['D_f_fake'] +
                               loss_dict['D_f_real']) * 0.5

                # collect temporal losses
                loss_D_T = []
                t_scales_act = min(t_scales, len(loss_dict_T))
                for s in range(t_scales_act):
                    loss_G += loss_dict_T[s]['G_T_GAN'] + \
                        loss_dict_T[s]['G_T_GAN_Feat'] + \
                        loss_dict_T[s]['G_T_Warp']
                    loss_D_T.append((loss_dict_T[s]['D_T_fake'] +
                                     loss_dict_T[s]['D_T_real']) * 0.5)

                ###################################### Backward Pass #################################
                optimizer_G = modelG.module.optimizer_G
                optimizer_D = modelD.module.optimizer_D
                # update generator weights
                optimizer_G.zero_grad()
                loss_G.backward()
                optimizer_G.step()

                # update discriminator weights
                # individual frame discriminator
                optimizer_D.zero_grad()
                loss_D.backward()
                optimizer_D.step()
                # temporal discriminator: one optimizer per active scale
                for s in range(t_scales_act):
                    optimizer_D_T = getattr(modelD.module,
                                            'optimizer_D_T' + str(s))
                    optimizer_D_T.zero_grad()
                    loss_D_T[s].backward()
                    optimizer_D_T.step()

            if opt.debug:
                # Debug aid: report GPU memory usage via nvidia-smi.
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            ############## Display results and errors ##########
            # print out errors
            if total_steps % opt.print_freq == 0:
                # Average per-step time over the print_freq window.
                t = (time.time() - iter_start_time) / opt.print_freq
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                # Suffix temporal-scale losses with their scale index.
                for s in range(len(loss_dict_T)):
                    errors.update({
                        k + str(s):
                        v.data.item() if not isinstance(v, int) else v
                        for k, v in loss_dict_T[s].items()
                    })
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            # display output images
            if save_fake:
                if opt.label_nc != 0:
                    input_image = util.tensor2label(real_A[0, -1],
                                                    opt.label_nc)
                elif opt.dataset_mode == 'pose':
                    input_image = util.tensor2im(real_A[0, -1, :3])
                    if real_A.size()[2] == 6:
                        # Overlay the second pose map where it is non-zero.
                        input_image2 = util.tensor2im(real_A[0, -1, 3:])
                        input_image[input_image2 != 0] = input_image2[
                            input_image2 != 0]
                else:
                    c = 3 if opt.input_nc == 3 else 1
                    input_image = util.tensor2im(real_A[0, -1, :c],
                                                 normalize=False)
                if opt.use_instance:
                    edges = util.tensor2im(real_A[0, -1, -1:, ...],
                                           normalize=False)
                    input_image += edges[:, :, np.newaxis]

                if opt.add_face_disc:
                    # Draw a white rectangle around the detected face region.
                    ys, ye, xs, xe = modelD.module.get_face_region(real_A[0,
                                                                          -1:])
                    if ys is not None:
                        input_image[ys, xs:xe, :] = input_image[
                            ye, xs:xe, :] = input_image[
                                ys:ye, xs, :] = input_image[ys:ye, xe, :] = 255

                visual_list = [
                    ('input_image', input_image),
                    ('fake_image', util.tensor2im(fake_B[0, -1])),
                    ('fake_first_image', util.tensor2im(fake_B_first)),
                    ('fake_raw_image', util.tensor2im(fake_B_raw[0, -1])),
                    ('real_image', util.tensor2im(real_B[0, -1])),
                    ('flow_ref', util.tensor2flow(flow_ref[0, -1])),
                    ('conf_ref',
                     util.tensor2im(conf_ref[0, -1], normalize=False))
                ]
                if flow is not None:
                    visual_list += [('flow', util.tensor2flow(flow[0, -1])),
                                    ('weight',
                                     util.tensor2im(weight[0, -1],
                                                    normalize=False))]
                visuals = OrderedDict(visual_list)
                visualizer.display_current_results(visuals, epoch, total_steps)

            # save latest model
            if total_steps % opt.save_latest_freq == 0:
                visualizer.vis_print(
                    'saving the latest model (epoch %d, total_steps %d)' %
                    (epoch, total_steps))
                modelG.module.save('latest')
                modelD.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            # End the epoch when fewer than one batch of samples remains.
            if epoch_iter > dataset_size - opt.batchSize:
                epoch_iter = 0
                break

        # end of epoch
        iter_end_time = time.time()
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' %
                             (epoch, opt.niter + opt.niter_decay,
                              time.time() - epoch_start_time))

        # save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            visualizer.vis_print(
                'saving the model at the end of epoch %d, iters %d' %
                (epoch, total_steps))
            modelG.module.save('latest')
            modelD.module.save('latest')
            modelG.module.save(epoch)
            modelD.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        # linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            modelG.module.update_learning_rate(epoch)
            modelD.module.update_learning_rate(epoch)

        # gradually grow training sequence length
        if (epoch % opt.niter_step) == 0:
            data_loader.dataset.update_training_batch(epoch // opt.niter_step)
            modelG.module.update_training_batch(epoch // opt.niter_step)

        # finetune all scales
        if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (
                epoch == opt.niter_fix_global):
            modelG.module.update_fixed_params()
Exemplo n.º 22
0
def main_worker(gpu, world_size, idx_server, opt):
    """Per-process entry point for multi-GPU (DDP-style) SPADE training.

    Spawned once per GPU. Computes this process' global rank from the
    server index and local GPU id, joins the NCCL process group, then
    runs the training loop with rank-aware dataloader/counter/visualizer.

    Args:
        gpu: local GPU index on this machine.
        world_size: per-node GPU count (immediately repurposed as
            ``ngpus_per_node``; the true world size is ``opt.world_size``).
        idx_server: index of this machine among all participating nodes.
        opt: parsed training options; ``opt.gpu`` is set here.
    """
    print('Use GPU: {} for training'.format(gpu))
    # NOTE(review): the `world_size` argument is actually the per-node GPU
    # count, while the global world size comes from opt.world_size —
    # confusing but intentional: rank = idx_server * ngpus_per_node + gpu.
    ngpus_per_node = world_size
    world_size = opt.world_size
    rank = idx_server * ngpus_per_node + gpu
    opt.gpu = gpu
    dist.init_process_group(backend='nccl', init_method=opt.dist_url, world_size=world_size, rank=rank)
    torch.cuda.set_device(opt.gpu)

    # load the dataset (rank-aware so each process sees its own shard)
    dataloader = data.create_dataloader(opt, world_size, rank)

    # create trainer for our model
    trainer = Pix2PixTrainer(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader), world_size, rank)

    # create tool for visualization
    visualizer = Visualizer(opt, rank)

    for epoch in iter_counter.training_epochs():
        # set epoch on the distributed sampler so shards reshuffle per epoch
        dataloader.sampler.set_epoch(epoch)

        iter_counter.record_epoch_start(epoch)

        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # Training
            # train generator (every step — no D_steps_per_G gating here)
            trainer.run_generator_one_step(data_i)

            # train discriminator
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                losses, iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far)

        # display once per epoch using the last batch seen (data_i leaks
        # out of the for-loop intentionally)
        visuals = OrderedDict([('input_label', data_i['label']),
                               ('synthesized_image', trainer.get_latest_generated()),
                               ('real_image', data_i['image'])])
        visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far)

        # only rank 0 checkpoints, avoiding concurrent writes to the same files
        if rank == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        if (epoch % opt.save_epoch_freq == 0 or epoch == iter_counter.total_epochs) and (rank == 0):
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save(epoch)

    print('Training was successfully finished.')
Exemplo n.º 23
0
def test_train():
    """Run the full SPADE training loop once as an end-to-end smoke test.

    Relies on a module-level ``opt`` holding the parsed training options.
    """
    # data / model / bookkeeping / logging
    dataloader = data.create_dataloader(opt)
    trainer = Pix2PixTrainer(opt)
    iter_counter = IterationCounter(opt, len(dataloader))
    visualizer = Visualizer(opt)

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)

        for step, batch in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # the discriminator trains every step; the generator only on
            # every D_steps_per_G-th step
            if step % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(batch)
            trainer.run_discriminator_one_step(batch)

            if iter_counter.needs_printing():
                latest_losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                latest_losses,
                                                iter_counter.time_per_iter)
                visualizer.plot_current_errors(latest_losses,
                                               iter_counter.total_steps_so_far)

            if iter_counter.needs_displaying():
                current_visuals = OrderedDict([
                    ('input_label', batch['label']),
                    ('synthesized_image', trainer.get_latest_generated()),
                    ('real_image', batch['image']),
                ])
                visualizer.display_current_results(
                    current_visuals, epoch, iter_counter.total_steps_so_far)

            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        # checkpoint at the save interval and always on the final epoch
        is_final_epoch = epoch == iter_counter.total_epochs
        if is_final_epoch or epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)

    print('Training was successfully finished.')
Exemplo n.º 24
0
def train():
    """Full training entry point for the EXR texture-map GAN.

    Parses options, restores epoch/iteration counters when resuming,
    runs the generator/discriminator optimization loop (optionally with
    Apex AMP fp16), validates at the end of every epoch, and
    checkpoints periodically.
    """
    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        # Only catch what np.loadtxt raises for a missing/corrupt file;
        # a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        except (OSError, ValueError):
            start_epoch, epoch_iter = 1, 0
        # compute resume lr: replay the linear decay up to start_epoch
        if start_epoch > opt.niter:
            lrd_unit = opt.lr / opt.niter_decay
            resume_lr = opt.lr - (start_epoch - opt.niter) * lrd_unit
            opt.lr = resume_lr
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0

    # make print_freq a multiple of batchSize so the modulo checks line up
    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    if opt.debug:
        # tiny schedule for quick debugging runs
        opt.display_freq = 2
        opt.print_freq = 2
        opt.niter = 3
        opt.niter_decay = 0
        opt.max_dataset_size = 1
        opt.valSize = 1

    ## Loading data
    # train data
    data_loader = CreateDataLoader(opt, isVal=False)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('# training images = %d' % dataset_size)
    # validation data
    data_loader = CreateDataLoader(opt, isVal=True)
    valset = data_loader.load_data()
    print('# validation images = %d' % len(data_loader))

    ## Loading model
    model = create_model(opt)
    visualizer = Visualizer(opt)
    if opt.fp16:
        # Apex AMP wraps the model and both optimizers, then the model is
        # re-wrapped in DataParallel so `.module` access works either way.
        from apex import amp
        model, [optimizer_G, optimizer_D
                ] = amp.initialize(model,
                                   [model.optimizer_G, model.optimizer_D],
                                   opt_level='O1')
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    else:
        optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter

    # phase offsets so periodic actions stay aligned after a resume
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            # start timing at the beginning of each print window
            if total_steps % opt.print_freq == print_delta:
                iter_start_time = time.time()

            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images this step
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            model = model.train()
            losses, generated, metrics = model(data['A'],
                                               data['B'],
                                               data['geometry'],
                                               infer=False)

            # sum per device losses and metrics (DataParallel returns one
            # entry per GPU; ints are placeholders for disabled losses)
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalars for D and G
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + opt.gan_feat_weight * loss_dict.get(
                'G_GAN_Feat', 0) + opt.vgg_weight * loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            optimizer_G.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_G, optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_G.backward()
            optimizer_G.step()

            # update discriminator weights
            optimizer_D.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_D, optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_D.backward()
            optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors (scalars only; ints pass through as-is)
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                metrics_ = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in metric_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)
                visualizer.print_current_metrics(epoch, epoch_iter, metrics_,
                                                 t)
                visualizer.plot_current_metrics(metrics_, total_steps)
                #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ### display output images; tensor2im_exr `type` selects the
            ### tone-mapping appropriate for each task's EXR maps
            if save_fake:
                if opt.task_type == 'specular':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=1)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=1))
                    ])
                elif opt.task_type == 'low':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=2)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=2))
                    ])
                elif opt.task_type == 'high':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=3)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=3))
                    ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ###########################################################################################
        # validation at the end of each epoch
        val_start_time = time.time()
        metrics_val = []
        for _, val_data in enumerate(valset):
            model = model.eval()
            # model.half()
            generated, metrics = model(val_data['A'],
                                       val_data['B'],
                                       val_data['geometry'],
                                       infer=True)
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            metrics_ = {
                k: v.data.item() if not isinstance(v, int) else v
                for k, v in metric_dict.items()
            }
            metrics_val.append(metrics_)
        # Print out averaged validation metrics
        metrics_val = visualizer.mean4dict(metrics_val)
        t = (time.time() - val_start_time) / opt.print_freq
        visualizer.print_current_metrics(epoch,
                                         epoch_iter,
                                         metrics_val,
                                         t,
                                         isVal=True)
        visualizer.plot_current_metrics(metrics_val, total_steps, isVal=True)
        # visualization of the last validation batch (same task_type
        # dispatch as in the training branch above)
        if opt.task_type == 'specular':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=1)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=1))
            ])
        elif opt.task_type == 'low':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=2)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=2))
            ])
        elif opt.task_type == 'high':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=3)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=3))
            ])
        visualizer.display_current_results(visuals, epoch, epoch, isVal=True)
        ###########################################################################################

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()
Exemplo n.º 25
0
def do_train(opt):
    """Train a SPADE generator/discriminator pair with the given options."""
    dataloader = data.create_dataloader(opt)   # e.g. a 2000-image CustomDataset
    trainer = Pix2PixTrainer(opt)              # builds SPADEGenerator, MultiscaleDiscriminator, VGG19
    iter_counter = IterationCounter(opt, len(dataloader))
    visualizer = Visualizer(opt)               # creates the checkpoints web directory

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)

        # each batch is a dict: 'label' (segmentation map), 'instance',
        # 'image' (normalized RGB tensor) and the source image 'path'
        for step, batch in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # generator updates every D_steps_per_G-th step,
            # discriminator updates every step
            if step % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(batch)
            trainer.run_discriminator_one_step(batch)

            if iter_counter.needs_printing():
                latest_losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                latest_losses,
                                                iter_counter.time_per_iter)
                visualizer.plot_current_errors(latest_losses,
                                               iter_counter.total_steps_so_far)

            if iter_counter.needs_displaying():
                current_visuals = OrderedDict([
                    ('input_label', batch['label']),
                    ('synthesized_image', trainer.get_latest_generated()),
                    ('real_image', batch['image']),
                ])
                visualizer.display_current_results(
                    current_visuals, epoch, iter_counter.total_steps_so_far)

            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        # checkpoint at the save interval and always on the final epoch
        is_final_epoch = epoch == iter_counter.total_epochs
        if is_final_epoch or epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)

    print('Training was successfully finished.')
Exemplo n.º 26
0
def run(opt):
    """Train a super-resolution GAN with periodic FID/metric evaluation.

    The loop is wrapped in try/except/finally so the latest model is
    checkpointed on Ctrl+C, on unexpected errors, and on normal exit.
    """
    print("Number of GPUs used: {}".format(torch.cuda.device_count()))
    print("Current Experiment Name: {}".format(opt.name))

    # The dataloader will yield the training samples
    dataloader = data.create_dataloader(opt)

    trainer = TrainerManager(opt)
    # evaluation-only helper: generates num_samples images for metrics,
    # without writing per-image details or saving images to disk
    inference_manager = InferenceManager(
        num_samples=opt.num_evaluation_samples,
        opt=opt,
        cuda=len(opt.gpu_ids) > 0,
        write_details=False,
        save_images=False)

    # For logging and visualizations
    iter_counter = IterationCounter(opt, len(dataloader))
    visualizer = Visualizer(opt)

    if not opt.debug:
        # We keep a copy of the current source code for each experiment
        copy_src(path_from="./",
                 path_to=os.path.join(opt.checkpoints_dir, opt.name))

    # We wrap training into a try/except clause such that the model is saved
    # when interrupting with Ctrl+C
    try:
        for epoch in iter_counter.training_epochs():
            iter_counter.record_epoch_start(epoch)
            for i, data_i in enumerate(dataloader,
                                       start=iter_counter.epoch_iter):

                # Training the generator (only every D_steps_per_G-th step)
                if i % opt.D_steps_per_G == 0:
                    trainer.run_generator_one_step(data_i)

                # Training the discriminator (every step)
                trainer.run_discriminator_one_step(data_i)

                iter_counter.record_one_iteration()

                # Logging, plotting and visualizing
                if iter_counter.needs_printing():
                    losses = trainer.get_latest_losses()
                    visualizer.print_current_errors(
                        epoch, iter_counter.epoch_iter, losses,
                        iter_counter.time_per_iter,
                        iter_counter.total_time_so_far,
                        iter_counter.total_steps_so_far)
                    visualizer.plot_current_errors(
                        losses, iter_counter.total_steps_so_far)

                if iter_counter.needs_displaying():
                    logs = trainer.get_logs()
                    visuals = [('input_label', data_i['label']),
                               ('out_train', trainer.get_latest_generated()),
                               ('real_train', data_i['image'])]
                    # optionally show the style-guidance image/label pair
                    if opt.guiding_style_image:
                        visuals.append(
                            ('guiding_image', data_i['guiding_image']))
                        visuals.append(
                            ('guiding_input_label', data_i['guiding_label']))

                    if opt.evaluate_val_set:
                        validation_output = inference_validation(
                            trainer.sr_model, inference_manager, opt)
                        visuals += validation_output
                    visuals = OrderedDict(visuals)
                    visualizer.display_current_results(
                        visuals, epoch, iter_counter.total_steps_so_far, logs)

                if iter_counter.needs_saving():
                    print(
                        'Saving the latest model (epoch %d, total_steps %d)' %
                        (epoch, iter_counter.total_steps_so_far))
                    trainer.save('latest')
                    iter_counter.record_current_iter()

                if iter_counter.needs_evaluation():
                    # Evaluate on training set
                    result_train = evaluate_training_set(
                        inference_manager, trainer.sr_model_on_one_gpu,
                        dataloader)
                    # NOTE(review): `info` is accumulated but never printed
                    # here — presumably record_fid/record_metrics also log
                    # internally; confirm before removing.
                    info = iter_counter.record_fid(
                        result_train["FID"],
                        split="train",
                        num_samples=opt.num_evaluation_samples)
                    info += os.linesep + iter_counter.record_metrics(
                        result_train, split="train")
                    visualizer.plot_current_errors(
                        result_train,
                        iter_counter.total_steps_so_far,
                        split="train/")

                    if opt.evaluate_val_set:
                        # Evaluate on validation set
                        result_val = evaluate_validation_set(
                            inference_manager, trainer.sr_model_on_one_gpu,
                            opt)
                        info += os.linesep + iter_counter.record_fid(
                            result_val["FID"],
                            split="validation",
                            num_samples=opt.num_evaluation_samples)
                        info += os.linesep + iter_counter.record_metrics(
                            result_val, split="validation")
                        visualizer.plot_current_errors(
                            result_val,
                            iter_counter.total_steps_so_far,
                            split="validation/")

            trainer.update_learning_rate(epoch)
            iter_counter.record_epoch_end()

            if epoch % opt.save_epoch_freq == 0 or \
                            epoch == iter_counter.total_epochs:
                print('Saving the model at the end of epoch %d, iters %d' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                trainer.save(epoch)
                iter_counter.record_current_iter()

        print('Training was successfully finished.')
    except (KeyboardInterrupt, SystemExit):
        print("KeyboardInterrupt. Shutting down.")
        print(traceback.format_exc())
    # NOTE(review): `e` is bound but unused — traceback.format_exc() already
    # captures the active exception. The broad catch is deliberate so that
    # `finally` can still checkpoint before the process dies.
    except Exception as e:
        print(traceback.format_exc())
    finally:
        # always persist the latest weights, even on failure/interrupt
        print('Saving the model before quitting')
        trainer.save('latest')
        iter_counter.record_current_iter()
Exemplo n.º 27
0
        model.module.optimizer_G.step()

        # update discriminator weights
        model.module.optimizer_D.zero_grad()
        loss_D.backward()
        model.module.optimizer_D.step()

        # call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

        ############## Display results and errors ##########
        # print out errors
        if total_steps % opt.print_freq == print_delta:
            errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)

        # display output images
        if save_fake:
            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                   ('synthesized_image', util.tensor2im(generated.data[0])),
                                   ('real_image', util.tensor2im(data['image'][0]))])
            visualizer.display_current_results(visuals, epoch, total_steps)

        # save latest model
        if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.module.save('latest')
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

        if epoch_iter >= dataset_size:
Exemplo n.º 28
0
def train_one_set(n_layers_frozen, opt, extra_niter=20, finetune_lr=0.00002):
    """Fine-tune the model for one curriculum set with frozen layers.

    Extends the training schedule, lowers the learning rate, freezes the
    first ``n_layers_frozen`` layers of the pix2pix model, then runs the
    standard SPADE training loop.

    Args:
        n_layers_frozen: number of leading model layers to freeze.
        opt: training options; ``opt.niter`` and ``opt.lr`` are mutated
            in place.
        extra_niter: additional training iterations for this set
            (default 20, i.e. 5000 image passes per set).
        finetune_lr: fine-tuning learning rate (default 2e-5, 1/10th of
            the original lr).
    """
    # create trainer for our model and freeze necessary model layers
    opt.niter = opt.niter + extra_niter
    opt.lr = finetune_lr
    trainer = Pix2PixTrainer(opt)
    model = next(trainer.pix2pix_model.children())
    freeze_layers(model, n_layers_frozen)

    # Proceed with training.

    # load the dataset
    dataloader = data.create_dataloader(opt)

    # create tool for counting iterations
    iter_counter = IterationCounter(opt, len(dataloader))

    # create tool for visualization
    visualizer = Visualizer(opt)

    for epoch in iter_counter.training_epochs():
        iter_counter.record_epoch_start(epoch)
        for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
            iter_counter.record_one_iteration()

            # Training
            # train generator only every D_steps_per_G-th step
            if i % opt.D_steps_per_G == 0:
                trainer.run_generator_one_step(data_i)

            # train discriminator every step
            trainer.run_discriminator_one_step(data_i)

            # Visualizations
            if iter_counter.needs_printing():
                losses = trainer.get_latest_losses()
                visualizer.print_current_errors(epoch, iter_counter.epoch_iter,
                                                losses,
                                                iter_counter.time_per_iter)
                visualizer.plot_current_errors(losses,
                                               iter_counter.total_steps_so_far)

            if iter_counter.needs_displaying():
                visuals = OrderedDict([('input_label', data_i['label']),
                                       ('synthesized_image',
                                        trainer.get_latest_generated()),
                                       ('real_image', data_i['image'])])
                visualizer.display_current_results(
                    visuals, epoch, iter_counter.total_steps_so_far)

            if iter_counter.needs_saving():
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, iter_counter.total_steps_so_far))
                trainer.save('latest')
                iter_counter.record_current_iter()

        trainer.update_learning_rate(epoch)
        iter_counter.record_epoch_end()

        # checkpoint at the save interval and always on the final epoch
        if epoch % opt.save_epoch_freq == 0 or \
           epoch == iter_counter.total_epochs:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, iter_counter.total_steps_so_far))
            trainer.save('latest')
            trainer.save(epoch)
Exemplo n.º 29
0
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
        model.set_input(data)
        model.optimize_parameters()

        if total_steps % opt.display_freq == 0:
            visualizer.display_current_results(model.get_current_visuals(),
                                               epoch)

        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            if opt.display_id > 0:
                visualizer.plot_current_errors(
                    epoch,
                    float(epoch_iter) / dataset_size, opt, errors)

        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, total_steps))
        model.save('latest')
        model.save(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
Exemplo n.º 30
0
def train():
    """Drive the full training run: parse options, build the model and data
    pipeline, then loop over epochs optimizing, visualizing and saving."""
    import time
    from options.train_options import TrainOptions
    from data import CreateDataLoader
    from models import create_model
    from util.visualizer import Visualizer
    cfg = TrainOptions().parse()
    net = create_model(cfg)
    # Build the data pipeline.
    loader = CreateDataLoader(cfg)
    batches = loader.load_data()
    num_images = len(loader)
    print('Training images = %d' % num_images)
    viz = Visualizer(cfg)
    step_count = 0  # total samples seen across all epochs
    # cfg.niter epochs at the base LR, then cfg.niter_decay decaying epochs.
    for epoch in range(cfg.epoch_count, cfg.niter + cfg.niter_decay + 1):
        epoch_t0 = time.time()
        data_t0 = time.time()
        samples_this_epoch = 0
        for batch in batches:
            iter_t0 = time.time()
            # Data-loading time is sampled only on print iterations and may
            # be stale when consumed below (throttling check runs pre-increment).
            if step_count % cfg.print_freq == 0:
                load_time = iter_t0 - data_t0
            viz.reset()
            step_count += cfg.batchSize
            samples_this_epoch += cfg.batchSize
            net.set_input(batch)
            net.optimize_parameters()
            # Save current images (real_A, real_B, fake_B)
            if samples_this_epoch % cfg.display_freq == 0:
                refresh_html = step_count % cfg.update_html_freq == 0
                viz.display_current_results(net.get_current_visuals(),
                                            epoch, samples_this_epoch,
                                            refresh_html)
            # Report losses; per-sample wall time normalizes by batch size.
            if step_count % cfg.print_freq == 0:
                loss_report = net.get_current_errors()
                per_sample = (time.time() - iter_t0) / cfg.batchSize
                viz.print_current_errors(epoch, samples_this_epoch,
                                         loss_report, per_sample, load_time)
                if cfg.display_id > 0:
                    viz.plot_current_errors(
                        epoch,
                        float(samples_this_epoch) / num_images, cfg,
                        loss_report)
            # Rolling checkpoint keyed on total samples processed.
            if step_count % cfg.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, step_count))
                net.save('latest')

            data_t0 = time.time()
        # End of epoch: keep 'latest' and an epoch-numbered snapshot.
        print(cfg.dataset_mode)
        if epoch % cfg.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, step_count))
            net.save('latest')
            net.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, cfg.niter + cfg.niter_decay,
               time.time() - epoch_t0))
        net.update_learning_rate()
Exemplo n.º 31
0
        # Apply the generator update (loss_G backward happened above this view).
        model.module.optimizer_G.step()

        # update discriminator weights
        model.module.optimizer_D.zero_grad()
        loss_D.backward()
        model.module.optimizer_D.step()

        #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 

        ############## Display results and errors ##########
        ### print out errors
        # NOTE(review): v.data[0] is the legacy (pre-0.4 PyTorch) scalar
        # access — presumably these are 0-dim loss tensors; verify version.
        if total_steps % opt.print_freq == print_delta:
            errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)

        ### display output images
        # save_fake is decided outside this fragment; batch element 0 only.
        if save_fake:
            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                   ('synthesized_image', util.tensor2im(generated.data[0])),
                                   ('real_image', util.tensor2im(data['image'][0]))])
            visualizer.display_current_results(visuals, epoch, total_steps)

        ### save latest model
        # Persist the counters too, so resume can restore epoch/epoch_iter.
        if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.module.save('latest')            
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
       
    # end of epoch 
Exemplo n.º 32
0
        # One optimization step on a single indexed sample for this stage.
        # epoch_iter is reconstructed from the global counter rather than reset.
        total_steps += 1 * opt.batchSize
        epoch_iter = total_steps - dataset_size * (epoch - 1)
        model.set_input(face_dataset[idx])
        model.optimize_stage1_parameters()
        # NOTE(review): the usual display/print frequency throttles are
        # commented out — visuals and errors are emitted every iteration.
        # if total_steps % opt.display_freq == 0:
        visualizer.display_current_results(model.get_current_visuals(stage),
                                           epoch, stage)

        # if total_steps % opt.print_freq == 0:
        errors = model.get_current_errors(stage)
        # t = (time.time() - iter_start_time) / opt.batchSize
        # t here is per-iteration wall time, not normalized by batch size.
        t = time.time() - iter_start_time
        visualizer.print_current_errors(epoch, epoch_iter, errors, t)
        if opt.display_id > 0:
            visualizer.plot_current_errors(epoch,
                                           float(epoch_iter) / dataset_size,
                                           opt, errors, stage)

        # print(total_steps)
        # Rolling checkpoint, tagged with the current training stage.
        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest', stage)

    # End of epoch: save 'latest' and an epoch-numbered snapshot per stage.
    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, total_steps))
        model.save('latest', stage)
        model.save(epoch, stage)

    print('End of epoch %d / %d \t Time Taken: %d sec' %